# -*- coding: utf-8 -*-
"""The redis based memory storage implementation."""
import json
from typing import Any, TYPE_CHECKING
from ._base import MemoryBase
from ...message import Msg
# At type-check time import the real redis classes for annotations; at
# runtime fall back to `Any` so the module imports cleanly even when the
# optional `redis` package is not installed.
if TYPE_CHECKING:
    from redis.asyncio import ConnectionPool, Redis
else:
    ConnectionPool = Any
    Redis = Any
class RedisMemory(MemoryBase):
    """Redis memory storage implementation, which supports session and user
    context.

    .. note:: All the operations in this class are within a specific session
     and user context, identified by `session_id` and `user_id`. Cross-session
     or cross-user operations are not supported. For example, the
     `remove_messages` method will only remove messages that belong to the
     specified `session_id` and `user_id`.

    .. note:: All Redis keys used by this class will be prefixed by `prefix`
     (if provided) to support multi-tenant / multi-app isolation.
    """

    SESSION_KEY = "user_id:{user_id}:session:{session_id}:messages"
    """Redis key pattern (without prefix) for storing message IDs (ordered) for
    a specific session."""

    SESSION_PATTERN = "user_id:{user_id}:session:{session_id}:*"
    """Redis key pattern (without prefix) for scanning all keys belong to
    a specific user and session."""

    MARK_KEY = "user_id:{user_id}:session:{session_id}:mark:{mark}"
    """Redis key pattern (without prefix) for storing message IDs that belong
    to a specific mark."""

    MESSAGE_KEY = "user_id:{user_id}:session:{session_id}:msg:{msg_id}"
    """Redis key pattern (without prefix) for storing message payload data."""

    def __init__(
        self,
        session_id: str = "default_session",
        user_id: str = "default_user",
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: str | None = None,
        connection_pool: ConnectionPool | None = None,
        key_prefix: str = "",
        key_ttl: int | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Redis based storage by connecting to the Redis
        server. You can provide either the connection parameters or an
        existing connection pool.

        Args:
            session_id (`str`, default to `"default_session"`):
                The session ID for the storage.
            user_id (`str`, default to `"default_user"`):
                The user ID for the storage.
            host (`str`, default to `"localhost"`):
                The Redis server host.
            port (`int`, default to `6379`):
                The Redis server port.
            db (`int`, default to `0`):
                The Redis database index.
            password (`str | None`, optional):
                The password for the Redis server, if required.
            connection_pool (`ConnectionPool | None`, optional):
                An optional Redis connection pool. If provided, it will be
                used instead of creating a new connection.
            key_prefix (`str`, default to `""`):
                Optional Redis key prefix prepended to every key generated by
                this storage. Useful for isolating keys across
                apps/environments (e.g. `"prod"`, `"staging"`, `"myapp"`).
            key_ttl (`int | None`, default to `None`):
                The expired time in seconds for each key. If provided, the
                expiration will be refreshed on every access (sliding TTL).
                If `None`, the keys will not expire. **Note** the ttl will
                update all session related keys, so it's not recommended to
                set for large sessions.
            **kwargs (`Any`):
                Additional keyword arguments to pass to the Redis client.
        """
        # Import lazily so that the module can be imported without the
        # optional `redis` dependency installed.
        try:
            import redis.asyncio as redis
        except ImportError as e:
            raise ImportError(
                "The 'redis' package is required for RedisStorage. "
                "Please install it via 'pip install redis[async]'.",
            ) from e

        super().__init__()

        self.session_id = session_id
        self.user_id = user_id
        self.key_prefix = key_prefix or ""
        self.key_ttl = key_ttl

        # NOTE(review): when `connection_pool` is supplied, redis-py sources
        # connection settings (including decode_responses) from the pool, so
        # the `decode_responses=True` below may be ignored for external
        # pools — confirm against the redis-py docs / caller's pool config.
        self._client = redis.Redis(
            host=host,
            port=port,
            db=db,
            password=password,
            connection_pool=connection_pool,
            decode_responses=True,
            **kwargs,
        )

    def get_client(self) -> Redis:
        """Get the underlying Redis client.

        Returns:
            `Redis`:
                The Redis client instance.
        """
        return self._client

    def _get_session_key(self) -> str:
        """Get the Redis key for the current session.

        Returns:
            `str`:
                The Redis key for storing messages in the current session.
        """
        return self.key_prefix + self.SESSION_KEY.format(
            user_id=self.user_id,
            session_id=self.session_id,
        )

    def _get_session_pattern(self) -> str:
        """Get the Redis key pattern for all keys in the current session.

        Returns:
            `str`:
                The Redis key pattern for all keys in the current session.
        """
        return self.key_prefix + self.SESSION_PATTERN.format(
            user_id=self.user_id,
            session_id=self.session_id,
        )

    def _get_mark_key(self, mark: str) -> str:
        """Get the Redis key for a specific mark.

        Args:
            mark (`str`):
                The mark name.

        Returns:
            `str`:
                The Redis key for storing message IDs with the given mark.
        """
        return self.key_prefix + self.MARK_KEY.format(
            user_id=self.user_id,
            session_id=self.session_id,
            mark=mark,
        )

    def _get_mark_pattern(self) -> str:
        """Get the Redis key pattern for all marks in the current session.

        Returns:
            `str`:
                The Redis key pattern for all mark keys.
        """
        # Reuse MARK_KEY with a glob wildcard in the mark slot.
        return self.key_prefix + self.MARK_KEY.format(
            user_id=self.user_id,
            session_id=self.session_id,
            mark="*",
        )

    def _get_message_key(self, msg_id: str) -> str:
        """Get the Redis key for a specific message.

        Args:
            msg_id (`str`):
                The message ID.

        Returns:
            `str`:
                The Redis key for storing the message data.
        """
        return self.key_prefix + self.MESSAGE_KEY.format(
            user_id=self.user_id,
            session_id=self.session_id,
            msg_id=msg_id,
        )

    async def _refresh_session_ttl(
        self,
        pipe: Any | None = None,
    ) -> None:
        """Refresh the TTL for the session keys (if `key_ttl` is set).

        Args:
            pipe (`Any | None`, optional):
                An optional Redis pipeline to use. If `None`, a new pipeline
                will be created and executed immediately. If provided, the
                expired commands will be added to the pipeline without
                executing it.
        """
        if self.key_ttl is None:
            return

        # Create a new pipeline if not provided; only execute it ourselves
        # when we created it (the caller owns an externally supplied pipe).
        should_execute = pipe is None
        if pipe is None:
            pipe = self._client.pipeline()

        # SCAN (cursor-based, non-blocking) over every key of this session.
        cursor = 0
        while True:
            cursor, keys = await self._client.scan(
                cursor,
                match=self._get_session_pattern(),
                count=100,
            )
            for key in keys:
                await pipe.expire(key, self.key_ttl)
            if cursor == 0:
                break

        if should_execute:
            await pipe.execute()

    async def get_memory(
        self,
        mark: str | None = None,
        exclude_mark: str | None = None,
        prepend_summary: bool = True,
        **kwargs: Any,
    ) -> list[Msg]:
        """Get the messages from the memory by mark (if provided). Otherwise,
        get all messages.

        .. note:: If `mark` and `exclude_mark` are both provided, the messages
         will be filtered by both arguments.

        .. note:: `mark` and `exclude_mark` should not overlap.

        Args:
            mark (`str | None`, optional):
                The mark to filter messages. If `None`, return all messages.
            exclude_mark (`str | None`, optional):
                The mark to exclude messages. If provided, messages with
                this mark will be excluded from the results.
            prepend_summary (`bool`, defaults to True):
                Whether to prepend the compressed summary as a message

        Returns:
            `list[Msg]`:
                The list of messages retrieved from the storage.
        """
        # Type checks
        if not (mark is None or isinstance(mark, str)):
            raise TypeError(
                f"The mark should be a string or None, but got {type(mark)}.",
            )

        if not (exclude_mark is None or isinstance(exclude_mark, str)):
            raise TypeError(
                f"The exclude_mark should be a string or None, but got "
                f"{type(exclude_mark)}.",
            )

        if mark is None:
            # Obtain the message IDs from the session list
            msg_ids = await self._client.lrange(self._get_session_key(), 0, -1)
        else:
            # Obtain the message IDs from the mark list
            msg_ids = await self._client.lrange(
                self._get_mark_key(mark),
                0,
                -1,
            )

        # Exclude messages by exclude_mark. Use a set so the filter below is
        # O(n + m) rather than O(n * m) list-membership tests.
        if exclude_mark:
            exclude_msg_ids = set(
                await self._client.lrange(
                    self._get_mark_key(exclude_mark),
                    0,
                    -1,
                ),
            )
            msg_ids = [_ for _ in msg_ids if _ not in exclude_msg_ids]

        # Use mget for batch retrieval to avoid N+1 queries
        messages: list[Msg] = []
        if msg_ids:
            msg_keys = [self._get_message_key(msg_id) for msg_id in msg_ids]
            msg_data_list = await self._client.mget(msg_keys)
            for msg_data in msg_data_list:
                # A None entry means the message payload key has expired or
                # been deleted; skip it silently.
                if msg_data is not None:
                    # Defensive decode in case the client/pool was configured
                    # without decode_responses.
                    if isinstance(msg_data, (bytes, bytearray)):
                        msg_data = msg_data.decode("utf-8")
                    msg_dict = json.loads(msg_data)
                    messages.append(Msg.from_dict(msg_dict))

        # Refresh TTLs
        await self._refresh_session_ttl()

        # NOTE(review): assumes `_compressed_summary` is provided by
        # MemoryBase — confirm against the base class.
        if prepend_summary and self._compressed_summary:
            return [
                Msg(
                    "user",
                    self._compressed_summary,
                    "user",
                ),
                *messages,
            ]

        return messages

    async def add(
        self,
        memories: Msg | list[Msg] | None,
        marks: str | list[str] | None = None,
        skip_duplicated: bool = True,
        **kwargs: Any,
    ) -> None:
        """Add message into the storage with the given mark (if provided).

        Args:
            memories (`Msg | list[Msg]`):
                The message(s) to be added.
            marks (`str | list[str] | None`, optional):
                The mark(s) to associate with the message(s). If `None`, no
                mark is associated.
            skip_duplicated (`bool`, defaults to `True`):
                If `True`, skip messages with duplicate IDs that already exist
                in the storage. If `False`, allow duplicate message IDs to be
                added to the session list (though the message data will be
                overwritten).
        """
        if memories is None:
            return

        if isinstance(memories, Msg):
            memories = [memories]

        # Nothing to do for an empty list; avoid a pointless pipeline
        # round-trip.
        if not memories:
            return

        # Normalize marks to a list
        if marks is None:
            mark_list = []
        elif isinstance(marks, str):
            mark_list = [marks]
        else:
            mark_list = marks

        # Filter out existing messages if skip_duplicated is True
        messages_to_add = memories
        if skip_duplicated:
            # Get all existing message IDs in the current session
            existing_msg_ids = await self._client.lrange(
                self._get_session_key(),
                0,
                -1,
            )
            existing_msg_ids_set = set(existing_msg_ids)

            # Filter out messages that already exist
            messages_to_add = [
                m for m in memories if m.id not in existing_msg_ids_set
            ]

            # If all messages are duplicates, return early
            if not messages_to_add:
                return

        # Use pipeline for atomic operations
        pipe = self._client.pipeline()

        # Push message ids into the session list
        if messages_to_add:
            await pipe.rpush(
                self._get_session_key(),
                *[m.id for m in messages_to_add],
            )

        # Store message data and marks
        for m in messages_to_add:
            # Record the message data
            await pipe.set(
                self._get_message_key(m.id),
                json.dumps(m.to_dict(), ensure_ascii=False),
            )

            # Record the marks if provided
            for mark in mark_list:
                await pipe.rpush(self._get_mark_key(mark), m.id)

        # Refresh TTLs
        await self._refresh_session_ttl(pipe=pipe)

        await pipe.execute()

    async def delete(
        self,
        msg_ids: list[str],
        **kwargs: Any,
    ) -> int:
        """Remove message(s) from the storage by their IDs.

        Args:
            msg_ids (`list[str]`):
                The list of message IDs to be removed.

        Returns:
            `int`:
                The number of messages removed.
        """
        if not msg_ids:
            return 0

        # Get all mark keys once before the pipeline
        mark_keys = []
        cursor = 0
        while True:
            cursor, keys = await self._client.scan(
                cursor,
                match=self._get_mark_pattern(),
                count=50,
            )
            mark_keys.extend(keys)
            if cursor == 0:
                break

        pipe = self._client.pipeline()
        for msg_id in msg_ids:
            # Remove from the session (0 means remove all occurrences)
            await pipe.lrem(self._get_session_key(), 0, msg_id)

            # Remove the message data
            await pipe.delete(self._get_message_key(msg_id))

            # Remove from all marks
            for mark_key in mark_keys:
                await pipe.lrem(mark_key, 0, msg_id)

        # Refresh TTLs
        await self._refresh_session_ttl(pipe=pipe)

        results = await pipe.execute()

        # Each msg_id contributed `stride` pipeline commands (session LREM,
        # payload DELETE, one LREM per mark key). The session LREM result sits
        # first in every group; a positive reply means the id was present.
        stride = 2 + len(mark_keys)
        removed_count = sum(
            1
            for i in range(0, len(msg_ids) * stride, stride)
            if results[i] > 0
        )
        return removed_count

    async def delete_by_mark(
        self,
        mark: str | list[str],
        **kwargs: Any,
    ) -> int:
        """Remove messages from the storage by their marks.

        Args:
            mark (`str | list[str]`):
                The mark(s) of the messages to be removed.

        Returns:
            `int`:
                The number of messages removed.
        """
        if isinstance(mark, str):
            mark = [mark]

        total_removed = 0
        for m in mark:
            mark_key = self._get_mark_key(m)
            msg_ids = await self._client.lrange(mark_key, 0, -1)

            if not msg_ids:
                continue

            # Remove messages by IDs (this also refreshes the session TTL)
            removed_count = await self.delete(
                msg_ids,
            )
            total_removed += removed_count

            # Delete the mark list
            await self._client.delete(mark_key)

        # Refresh TTLs
        await self._refresh_session_ttl()
        return total_removed

    async def clear(self) -> None:
        """Clear all messages belong to this session from the storage."""
        msg_ids = await self._client.lrange(self._get_session_key(), 0, -1)

        # Get all mark keys using SCAN
        mark_keys = []
        cursor = 0
        while True:
            cursor, keys = await self._client.scan(
                cursor,
                match=self._get_mark_pattern(),
                count=50,
            )
            mark_keys.extend(keys)
            if cursor == 0:
                break

        pipe = self._client.pipeline()
        for msg_id in msg_ids:
            # Remove the message data
            await pipe.delete(self._get_message_key(msg_id))

        # Delete the session list
        await pipe.delete(self._get_session_key())

        # Delete all mark lists
        for mark_key in mark_keys:
            await pipe.delete(mark_key)

        await pipe.execute()

    async def size(self) -> int:
        """Get the number of messages in the storage.

        Returns:
            `int`:
                The number of messages in the storage.
        """
        size = await self._client.llen(self._get_session_key())
        await self._refresh_session_ttl()
        return size

    async def update_messages_mark(
        self,
        new_mark: str | None,
        old_mark: str | None = None,
        msg_ids: list[str] | None = None,
    ) -> int:
        """A unified method to update marks of messages in the storage (add,
        remove, or change marks).

        - If `msg_ids` is provided, the update will be applied to the messages
          with the specified IDs.
        - If `old_mark` is provided, the update will be applied to the
          messages with the specified old mark. Otherwise, the `new_mark` will
          be added to all messages (or those filtered by `msg_ids`).
        - If `new_mark` is `None`, the mark will be removed from the messages.

        Args:
            new_mark (`str | None`, optional):
                The new mark to set for the messages. If `None`, the mark
                will be removed.
            old_mark (`str | None`, optional):
                The old mark to filter messages. If `None`, this constraint
                is ignored.
            msg_ids (`list[str] | None`, optional):
                The list of message IDs to be updated. If `None`, this
                constraint is ignored.

        Returns:
            `int`:
                The number of messages updated.
        """
        # Determine which message IDs to update
        if old_mark is not None:
            # Get message IDs from the old mark list
            mark_msg_ids = await self._client.lrange(
                self._get_mark_key(old_mark),
                0,
                -1,
            )
        else:
            # Get all message IDs from the session
            mark_msg_ids = await self._client.lrange(
                self._get_session_key(),
                0,
                -1,
            )

        # Filter by msg_ids if provided
        if msg_ids is not None:
            msg_ids_set = set(msg_ids)
            mark_msg_ids = [mid for mid in mark_msg_ids if mid in msg_ids_set]

        if not mark_msg_ids:
            return 0

        # Get existing IDs in new_mark list once (if needed)
        existing_ids_set = set()
        new_mark_key = None
        if new_mark is not None:
            new_mark_key = self._get_mark_key(new_mark)
            existing_ids = await self._client.lrange(new_mark_key, 0, -1)
            existing_ids_set = set(existing_ids)

        # Use pipeline for batch operations
        pipe = self._client.pipeline()
        updated_count = 0
        for msg_id in mark_msg_ids:
            # If new_mark is None, remove the old_mark
            if new_mark is None:
                if old_mark is not None:
                    await pipe.lrem(
                        self._get_mark_key(old_mark),
                        0,
                        msg_id,
                    )
                    updated_count += 1
            else:
                # Remove from old_mark list if applicable
                if old_mark is not None:
                    await pipe.lrem(
                        self._get_mark_key(old_mark),
                        0,
                        msg_id,
                    )

                # Add to new_mark list only if not already present
                if msg_id not in existing_ids_set and new_mark_key is not None:
                    await pipe.rpush(new_mark_key, msg_id)
                    existing_ids_set.add(msg_id)
                updated_count += 1

        await self._refresh_session_ttl(pipe=pipe)
        await pipe.execute()
        return updated_count

    async def close(self, close_connection_pool: bool | None = None) -> None:
        """Close the Redis client connection.

        Args:
            close_connection_pool (`bool | None`, optional):
                Decides whether to close the connection pool used by this
                Redis client, overriding Redis.auto_close_connection_pool.
                By default, let Redis.auto_close_connection_pool decide
                whether to close the connection pool
        """
        await self._client.aclose(close_connection_pool=close_connection_pool)

    async def __aenter__(self) -> "RedisMemory":
        """Enter the async context manager.

        Returns:
            `RedisMemory`:
                The memory instance itself.
        """
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: Any,
    ) -> None:
        """Exit the async context manager and close the session.

        Args:
            exc_type (`type[BaseException] | None`):
                The exception type if an exception was raised.
            exc_value (`BaseException | None`):
                The exception instance if an exception was raised.
            traceback (`Any`):
                The traceback object if an exception was raised.
        """
        await self.close()