# -*- coding: utf-8 -*-
"""The SQLAlchemy database storage module, which supports storing messages in
a SQL database using SQLAlchemy ORM (e.g., SQLite, PostgreSQL, MySQL)."""
from typing import Any
from sqlalchemy import (
Column,
String,
JSON,
BigInteger,
ForeignKey,
select,
delete,
update,
func,
)
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
)
from sqlalchemy.orm import declarative_base, relationship
from ._base import MemoryBase
from ...message import Msg
# Declarative base class that every ORM table model in this module derives
# from; `Base.metadata` is what `_create_table` passes to `create_all`.
Base: Any = declarative_base()
[docs]
class AsyncSQLAlchemyMemory(MemoryBase):
"""The SQLAlchemy memory storage class for storing messages in a SQL
database using SQLAlchemy ORM, such as SQLite, PostgreSQL, MySQL, etc.
.. note:: All the operations in this class are within a specific session
and user context, identified by `session_id` and `user_id`. Cross-session
or cross-user operations are not supported. For example, the
`delete` method will only remove messages that belong to the
specified `session_id` and `user_id`.
"""
[docs]
class MessageTable(Base):
    """ORM model of the ``message`` table: one row per stored message."""

    __tablename__ = "message"

    #: Primary key. Holds the composite string
    #: "{user_id}-{session_id}-{message_id}" so that ids stay unique across
    #: users and sessions.
    id = Column(String(255), primary_key=True)

    #: The message content, stored as JSON.
    msg = Column(JSON, nullable=False)

    #: ORM relationship to the owning ``SessionTable`` row.
    session = relationship("SessionTable", back_populates="messages")

    #: Foreign key referencing ``session.id``.
    session_id = Column(String(255), ForeignKey("session.id"), nullable=False)

    #: Ordering column (indexed) used to return messages in the order they
    #: were added.
    index = Column(BigInteger, nullable=False, index=True)
[docs]
class MessageMarkTable(Base):
    """ORM model of the ``message_mark`` table: one row per (message, mark)."""

    __tablename__ = "message_mark"

    #: Id of the marked message. References ``message.id`` with
    #: ``ondelete="CASCADE"`` and is the first half of the composite
    #: primary key.
    msg_id = Column(
        String(255),
        ForeignKey("message.id", ondelete="CASCADE"),
        primary_key=True,
    )

    #: The mark label; second half of the composite primary key.
    mark = Column(String(255), primary_key=True)
[docs]
class SessionTable(Base):
    """ORM model of the ``session`` table: one row per conversation session."""

    __tablename__ = "session"

    #: Primary key: the session id.
    id = Column(String(255), primary_key=True)

    #: ORM relationship to the owning ``UserTable`` row.
    user = relationship("UserTable", back_populates="sessions")

    #: Foreign key referencing ``users.id``.
    user_id = Column(String(255), ForeignKey("users.id"), nullable=False)

    #: ORM relationship to the ``MessageTable`` rows of this session.
    messages = relationship("MessageTable", back_populates="session")
[docs]
class UserTable(Base):
    """ORM model of the ``users`` table: one row per user."""

    __tablename__ = "users"

    #: Primary key: the user id.
    id = Column(String(255), primary_key=True)

    #: ORM relationship to the ``SessionTable`` rows owned by this user.
    sessions = relationship("SessionTable", back_populates="user")
[docs]
def __init__(
    self,
    engine_or_session: AsyncEngine | AsyncSession,
    session_id: str | None = None,
    user_id: str | None = None,
) -> None:
    """Initialize the memory with a SQLAlchemy async engine or session.

    Args:
        engine_or_session (`AsyncEngine | AsyncSession`):
            The SQLAlchemy asynchronous engine or session to use for
            database operations. When an engine is given, sessions are
            created internally through an `async_sessionmaker`; when a
            session is given, it is used as-is.
        session_id (`str | None`, optional):
            The session ID for the messages. If `None` (or empty),
            `"default_session"` is used.
        user_id (`str | None`, optional):
            The user ID for the messages. If `None` (or empty),
            `"default_user"` is used.

    Raises:
        `ValueError`:
            If `engine_or_session` is neither a
            `sqlalchemy.ext.asyncio.AsyncEngine` nor a
            `sqlalchemy.ext.asyncio.AsyncSession`.
    """
    super().__init__()

    self._db_session: AsyncSession | None = None

    if isinstance(engine_or_session, AsyncEngine):
        # Engine given: sessions are created lazily by the `session`
        # property through this factory.
        self._session_factory = async_sessionmaker(
            bind=engine_or_session,
            expire_on_commit=False,
        )
    elif isinstance(engine_or_session, AsyncSession):
        # Session given: no factory, the provided session is reused.
        self._session_factory = None
        self._db_session = engine_or_session
    else:
        # The message names both accepted types (it previously mentioned
        # only AsyncEngine although AsyncSession is accepted as well).
        raise ValueError(
            "The 'engine_or_session' parameter must be an instance of "
            "sqlalchemy.ext.asyncio.AsyncEngine or "
            "sqlalchemy.ext.asyncio.AsyncSession, but got "
            f"{type(engine_or_session)}.",
        )

    self.session_id = session_id or "default_session"
    self.user_id = user_id or "default_user"

    # Flag to track if tables and user/session records have been created
    self._initialized = False
def _make_message_id(self, msg_id: str) -> str:
"""Generate a composite primary key for a message.
Args:
msg_id (`str`):
The original message ID.
Returns:
`str`:
The composite primary key in the format
"{user_id}-{session_id}-{message_id}".
"""
return f"{self.user_id}-{self.session_id}-{msg_id}"
@property
def session(self) -> AsyncSession:
    """Return the database session to use, creating one when required.

    Returns:
        `AsyncSession`:
            The current database session.

    Note:
        - Without a session factory (an external session was supplied),
          the stored session is returned unchanged.
        - With a session factory, the stored session is reused while it
          exists and is active; otherwise a new one is created from the
          factory and `_initialized` is reset to `False` so that
          initialization runs again.
    """
    factory = self._session_factory

    # External session: owned by the caller, hand it back untouched
    if factory is None:
        return self._db_session

    current = self._db_session
    if current is not None and current.is_active:
        return current

    # Missing or inactive internal session: replace it and force
    # re-initialization on next use
    self._db_session = factory()
    self._initialized = False
    return self._db_session
async def _create_table(self) -> None:
    """Create the tables and the user/session rows if not done yet.

    Does nothing when `_initialized` is already `True`. Otherwise it runs
    `Base.metadata.create_all` through the engine bound to the session,
    inserts the `UserTable` row for `user_id` and the `SessionTable` row
    for `session_id` when they are missing, commits once if anything was
    inserted, and sets `_initialized` to `True`.
    """
    if self._initialized:
        return

    # The tables are created through the engine bound to the session
    bound_engine: AsyncEngine = self.session.bind
    async with bound_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    # (table, primary key value, constructor fields); the user row is
    # handled before the session row that references it.
    required_rows = (
        (self.UserTable, self.user_id, {"id": self.user_id}),
        (
            self.SessionTable,
            self.session_id,
            {"id": self.session_id, "user_id": self.user_id},
        ),
    )

    inserted_any = False
    for table, key, fields in required_rows:
        lookup = await self.session.execute(
            select(table).filter(table.id == key),
        )
        if lookup.scalar_one_or_none() is None:
            self.session.add(table(**fields))
            inserted_any = True

    # A single commit covers both inserts
    if inserted_any:
        await self.session.commit()

    self._initialized = True
[docs]
async def get_memory(
    self,
    mark: str | None = None,
    exclude_mark: str | None = None,
    prepend_summary: bool = True,
    **kwargs: Any,
) -> list[Msg]:
    """Return the messages of the current session, optionally filtered by
    mark, ordered by their insertion index.

    .. note:: If `mark` and `exclude_mark` are both provided, the messages
     will be filtered by both arguments.

    .. note:: `mark` and `exclude_mark` should not overlap.

    Args:
        mark (`str | None`, optional):
            Only messages carrying this mark are returned. If `None` (or
            empty), no mark filter is applied.
        exclude_mark (`str | None`, optional):
            Messages carrying this mark are left out of the result.
        prepend_summary (`bool`, defaults to True):
            Whether to put the compressed summary, wrapped in a message,
            in front of the result when a summary is set.

    Raises:
        `TypeError`:
            If `mark` or `exclude_mark` is neither a string nor `None`.

    Returns:
        `list[Msg]`:
            The list of messages retrieved from the storage.
    """
    # Argument validation
    if not (mark is None or isinstance(mark, str)):
        raise TypeError(
            f"The mark should be a string or None, but got {type(mark)}.",
        )
    if not (exclude_mark is None or isinstance(exclude_mark, str)):
        raise TypeError(
            f"The exclude_mark should be a string or None, but got "
            f"{type(exclude_mark)}.",
        )

    await self._create_table()

    message_table = self.MessageTable
    mark_table = self.MessageMarkTable

    # Start from the messages of the current session only
    statement = select(message_table).filter(
        message_table.session_id == self.session_id,
    )

    # Keep only the messages carrying `mark`
    if mark:
        statement = statement.join(
            mark_table,
            message_table.id == mark_table.msg_id,
        ).filter(
            mark_table.mark == mark,
        )

    # Drop the messages carrying `exclude_mark`
    if exclude_mark:
        session_message_ids = select(message_table.id).filter(
            message_table.session_id == self.session_id,
        )
        excluded_ids = (
            select(mark_table.msg_id)
            .filter(
                mark_table.msg_id.in_(session_message_ids),
                mark_table.mark == exclude_mark,
            )
            .scalar_subquery()
        )
        statement = statement.filter(
            message_table.id.notin_(excluded_ids),
        )

    # Return the messages in insertion order
    statement = statement.order_by(message_table.index)
    fetched = await self.session.execute(statement)
    history = [Msg.from_dict(row.msg) for row in fetched.scalars().all()]

    if not (prepend_summary and self._compressed_summary):
        return history

    summary_msg = Msg(
        "user",
        self._compressed_summary,
        "user",
    )
    return [summary_msg, *history]
[docs]
async def add(
    self,
    memories: Msg | list[Msg] | None,
    marks: str | list[str] | None = None,
    skip_duplicated: bool = True,
    **kwargs: Any,
) -> None:
    """Add message(s) into the storage with the given mark(s) (if provided).

    Args:
        memories (`Msg | list[Msg] | None`):
            The message(s) to be added. `None` is a no-op.
        marks (`str | list[str] | None`, optional):
            The mark(s) to associate with the message(s). If `None`, no
            mark is associated. Repeated marks are applied once.
        skip_duplicated (`bool`, defaults to `True`):
            If `True`, skip messages whose IDs already exist in the
            storage or are repeated within `memories` (the first
            occurrence is kept). If `False`, an `IntegrityError` is raised
            when attempting to add a message with an existing ID.

    Raises:
        `TypeError`:
            If `memories` is not a `Msg` or a list of `Msg`, or `marks` is
            not a string or a list of strings.
        `IntegrityError`:
            If a message with the same ID already exists in the storage
            and `skip_duplicated` is set to `False`.
    """
    if memories is None:
        return

    # Type checking
    if isinstance(memories, Msg):
        memories = [memories]
    elif not (
        isinstance(memories, list)
        and all(isinstance(_, Msg) for _ in memories)
    ):
        raise TypeError(
            "The 'memories' parameter must be a Msg instance or a list of "
            f"Msg instances, but got {type(memories)}.",
        )

    if isinstance(marks, str):
        marks = [marks]
    elif marks is not None and not (
        isinstance(marks, list) and all(isinstance(m, str) for m in marks)
    ):
        raise TypeError(
            "The 'marks' parameter must be a string or a list of strings, "
            f"but got {type(marks)}.",
        )

    # Create tables and user/session records if not done yet
    await self._create_table()

    messages_to_add = memories
    if skip_duplicated:
        # Collapse repeats inside the batch itself (first occurrence wins),
        # otherwise two entries with the same id would collide on insert.
        unique_by_key: dict[str, Msg] = {}
        for m in memories:
            unique_by_key.setdefault(self._make_message_id(m.id), m)

        # Then drop the messages that are already stored
        result = await self.session.execute(
            select(self.MessageTable.id).filter(
                self.MessageTable.id.in_(list(unique_by_key)),
            ),
        )
        existing_msg_ids = set(result.scalars().all())
        messages_to_add = [
            m
            for key, m in unique_by_key.items()
            if key not in existing_msg_ids
        ]

    # Nothing left to store
    if not messages_to_add:
        return

    # Fetch the starting index once and number the batch contiguously.
    # Nothing here guards against concurrent writers to the same session.
    start_index = await self._get_next_index()

    new_ids: list[str] = []
    for offset, m in enumerate(messages_to_add):
        key = self._make_message_id(m.id)
        new_ids.append(key)
        self.session.add(
            self.MessageTable(
                id=key,
                msg=m.to_dict(),
                session_id=self.session_id,
                index=start_index + offset,
            ),
        )

    if marks:
        # De-duplicate marks: (msg_id, mark) is the primary key
        unique_marks = list(dict.fromkeys(marks))
        mark_records = [
            {"msg_id": key, "mark": mark}
            for key in new_ids
            for mark in unique_marks
        ]

        if skip_duplicated:
            # Only look at marks of the messages being added instead of
            # loading the whole mark table into memory.
            result = await self.session.execute(
                select(
                    self.MessageMarkTable.msg_id,
                    self.MessageMarkTable.mark,
                ).filter(
                    self.MessageMarkTable.msg_id.in_(new_ids),
                ),
            )
            existing_marks = {(row[0], row[1]) for row in result.all()}
            mark_records = [
                r
                for r in mark_records
                if (r["msg_id"], r["mark"]) not in existing_marks
            ]

        if mark_records:
            # Flush the pending message rows first so that they exist
            # before the mark rows referencing them are bulk-inserted.
            await self.session.flush()
            await self.session.run_sync(
                lambda session: session.bulk_insert_mappings(
                    self.MessageMarkTable,
                    mark_records,
                ),
            )

    await self.session.commit()
async def _get_next_index(self) -> int:
    """Compute the index to assign to the next message of this session.

    Returns:
        `int`:
            One more than the largest stored index of the current session,
            or 0 when the session has no messages.
    """
    table = self.MessageTable
    # Highest index first, only one row needed
    latest_first = (
        select(table.index)
        .filter(table.session_id == self.session_id)
        .order_by(table.index.desc())
        .limit(1)
    )
    fetched = await self.session.execute(latest_first)
    highest = fetched.scalar_one_or_none()
    if highest is None:
        return 0
    return highest + 1
[docs]
async def size(self) -> int:
    """Get the number of messages stored for the current session.

    Returns:
        `int`:
            The count of rows in the message table whose `session_id`
            matches this instance's `session_id`.
    """
    # Make sure the tables exist, as `add` and `get_memory` do; otherwise
    # calling this on a fresh database fails on the missing table.
    await self._create_table()

    result = await self.session.execute(
        select(func.count(self.MessageTable.id)).filter(
            self.MessageTable.session_id == self.session_id,
        ),
    )
    return result.scalar_one()
[docs]
async def clear(self) -> None:
    """Remove all messages of the current session, together with their
    marks, and commit."""
    # Make sure the tables exist, as `add` and `get_memory` do; otherwise
    # calling this on a fresh database fails on the missing table.
    await self._create_table()

    session_message_ids = select(self.MessageTable.id).filter(
        self.MessageTable.session_id == self.session_id,
    )

    # Delete the marks first: they reference the message rows
    await self.session.execute(
        delete(self.MessageMarkTable).where(
            self.MessageMarkTable.msg_id.in_(session_message_ids),
        ),
    )

    # Then delete the messages themselves
    await self.session.execute(
        delete(self.MessageTable).filter(
            self.MessageTable.session_id == self.session_id,
        ),
    )

    await self.session.commit()
[docs]
async def delete_by_mark(
    self,
    mark: str | list[str],
    **kwargs: Any,
) -> int:
    """Remove the messages of the current session that carry any of the
    given marks.

    Args:
        mark (`str | list[str]`):
            The mark(s) of the messages to be removed.

    Returns:
        `int`:
            The number of messages removed. A message carrying several of
            the given marks is counted once.
    """
    if isinstance(mark, str):
        mark = [mark]

    # Make sure the tables exist, as `add` and `get_memory` do
    await self._create_table()

    # The join yields one row per (message, mark) pair, so DISTINCT is
    # needed to get each matching message id exactly once.
    query = (
        select(self.MessageTable.id)
        .join(
            self.MessageMarkTable,
            self.MessageTable.id == self.MessageMarkTable.msg_id,
        )
        .filter(
            self.MessageTable.session_id == self.session_id,
            self.MessageMarkTable.mark.in_(mark),
        )
        .distinct()
    )
    result = await self.session.execute(query)
    msg_ids = list(result.scalars().all())

    if not msg_ids:
        return 0

    # Delete all marks of the matched messages first
    await self.session.execute(
        delete(self.MessageMarkTable).filter(
            self.MessageMarkTable.msg_id.in_(msg_ids),
        ),
    )

    # Then delete the messages
    await self.session.execute(
        delete(self.MessageTable).filter(
            self.MessageTable.session_id == self.session_id,
            self.MessageTable.id.in_(msg_ids),
        ),
    )

    await self.session.commit()
    return len(msg_ids)
[docs]
async def delete(
    self,
    msg_ids: list[str],
    **kwargs: Any,
) -> int:
    """Remove message(s) from the storage by their IDs.

    .. note:: Although MessageMarkTable has CASCADE delete on foreign key,
     we explicitly delete marks first for reliability across all database
     engines and configurations. SQLAlchemy's bulk delete bypasses
     ORM-level cascades, and SQLite requires foreign keys to be
     explicitly enabled.

    Args:
        msg_ids (`list[str]`):
            The list of message IDs to be removed.

    Returns:
        `int`:
            The number of messages removed, as reported by the row count
            of the DELETE statement on the message table. IDs that are not
            stored, or are repeated in `msg_ids`, do not inflate the count.
    """
    # Convert to composite keys, dropping repeated ids
    composite_ids = list(
        dict.fromkeys(self._make_message_id(msg_id) for msg_id in msg_ids),
    )

    if not composite_ids:
        return 0

    # Make sure the tables exist, as `add` and `get_memory` do
    await self._create_table()

    # Delete related marks first (explicit cleanup for reliability)
    await self.session.execute(
        delete(self.MessageMarkTable).filter(
            self.MessageMarkTable.msg_id.in_(composite_ids),
        ),
    )

    # Then delete the messages and keep the result for its row count
    result = await self.session.execute(
        delete(self.MessageTable).filter(
            self.MessageTable.session_id == self.session_id,
            self.MessageTable.id.in_(composite_ids),
        ),
    )
    deleted_count = result.rowcount

    await self.session.commit()
    return deleted_count
[docs]
async def update_messages_mark(
    self,
    new_mark: str | None,
    old_mark: str | None = None,
    msg_ids: list[str] | None = None,
) -> int:
    """A unified method to update marks of messages in the storage (add,
    remove, or change marks).

    - If `msg_ids` is provided, the update will be applied to the messages
      with the specified IDs.
    - If `old_mark` is provided, the update will be applied to the
      messages with the specified old mark. Otherwise, the `new_mark` will
      be added to all messages (or those filtered by `msg_ids`).
    - If `new_mark` is `None`, the mark will be removed from the messages
      (`old_mark` only when given, otherwise all their marks).

    Args:
        new_mark (`str | None`, optional):
            The new mark to set for the messages. If `None`, the mark
            will be removed.
        old_mark (`str | None`, optional):
            The old mark to filter messages. If `None`, this constraint
            is ignored.
        msg_ids (`list[str] | None`, optional):
            The list of message IDs to be updated. If `None`, this
            constraint is ignored.

    Raises:
        `ValueError`:
            If `new_mark` or `old_mark` is neither a string nor `None`, or
            `msg_ids` is neither a list of strings nor `None`.

    Returns:
        `int`:
            The number of messages selected for the update.
    """
    # Type checking
    if new_mark is not None and not isinstance(new_mark, str):
        raise ValueError(
            f"The 'new_mark' parameter must be a string or None, "
            f"but got {type(new_mark)}.",
        )

    if old_mark is not None and not isinstance(old_mark, str):
        raise ValueError(
            f"The 'old_mark' parameter must be a string or None, "
            f"but got {type(old_mark)}.",
        )

    if msg_ids is not None and not (
        isinstance(msg_ids, list)
        and all(isinstance(_, str) for _ in msg_ids)
    ):
        raise ValueError(
            f"The 'msg_ids' parameter must be a list of strings or None, "
            f"but got {type(msg_ids)}.",
        )

    # Make sure the tables exist, as `add` and `get_memory` do
    await self._create_table()

    # Select only the ids of the messages of this session; loading the
    # full rows (including the JSON content) is not needed here.
    id_query = select(self.MessageTable.id).filter(
        self.MessageTable.session_id == self.session_id,
    )

    # Filter by msg_ids if provided
    if msg_ids is not None:
        # Convert to composite keys
        composite_ids = [
            self._make_message_id(msg_id) for msg_id in msg_ids
        ]
        id_query = id_query.filter(self.MessageTable.id.in_(composite_ids))

    # Filter by old_mark if provided
    if old_mark is not None:
        id_query = id_query.join(
            self.MessageMarkTable,
            self.MessageTable.id == self.MessageMarkTable.msg_id,
        ).filter(self.MessageMarkTable.mark == old_mark)

    result = await self.session.execute(id_query)
    target_ids = [str(_) for _ in result.scalars().all()]

    # Return early if no messages found
    if not target_ids:
        return 0

    if new_mark:
        if old_mark:
            # Replace old_mark with new_mark
            return await self._replace_message_mark(
                msg_ids=target_ids,
                old_mark=old_mark,
                new_mark=new_mark,
            )

        # Add new_mark to the messages
        return await self._add_message_mark(
            msg_ids=target_ids,
            mark=new_mark,
        )

    # No new mark: remove old_mark (or every mark) from the messages
    return await self._remove_message_mark(
        msg_ids=target_ids,
        old_mark=old_mark,
    )
async def _replace_message_mark(
    self,
    msg_ids: list[str],
    old_mark: str,
    new_mark: str,
) -> int:
    """Replace the old mark with the new mark for the given messages by
    updating records in the message_mark table.

    Messages that already carry `new_mark` only get their `old_mark` row
    deleted, because rewriting it would duplicate the (msg_id, mark)
    primary key.

    Args:
        msg_ids (`list[str]`):
            The list of message IDs (primary keys of the message table)
            to be updated.
        old_mark (`str`):
            The old mark to be replaced.
        new_mark (`str`):
            The new mark to be set.

    Returns:
        `int`:
            The number of messages passed in `msg_ids`.
    """
    # Replacing a mark with itself changes nothing
    if old_mark == new_mark:
        return len(msg_ids)

    mark_table = self.MessageMarkTable

    # Messages that already have the new mark
    result = await self.session.execute(
        select(mark_table.msg_id).filter(
            mark_table.msg_id.in_(msg_ids),
            mark_table.mark == new_mark,
        ),
    )
    already_marked = set(result.scalars().all())

    if already_marked:
        # The new mark is present already: just drop the old one
        await self.session.execute(
            delete(mark_table).filter(
                mark_table.msg_id.in_(list(already_marked)),
                mark_table.mark == old_mark,
            ),
        )

    remaining_ids = [_ for _ in msg_ids if _ not in already_marked]
    if remaining_ids:
        await self.session.execute(
            update(mark_table)
            .filter(
                mark_table.msg_id.in_(remaining_ids),
                mark_table.mark == old_mark,
            )
            .values(mark=new_mark),
        )

    await self.session.commit()
    return len(msg_ids)
async def _add_message_mark(self, msg_ids: list[str], mark: str) -> int:
    """Mark the messages with the given mark by adding records to the
    message_mark table.

    Messages that already carry the mark are skipped, because inserting
    the same (msg_id, mark) pair again would violate the primary key.

    Args:
        msg_ids (`list[str]`):
            The list of message IDs (primary keys of the message table)
            to be marked.
        mark (`str`):
            The mark to be added to the messages.

    Returns:
        `int`:
            The number of messages passed in `msg_ids`.
    """
    already_marked: set[str] = set()
    if msg_ids:
        result = await self.session.execute(
            select(self.MessageMarkTable.msg_id).filter(
                self.MessageMarkTable.msg_id.in_(msg_ids),
                self.MessageMarkTable.mark == mark,
            ),
        )
        already_marked = set(result.scalars().all())

    # One record per distinct message id that is not marked yet
    mark_records = [
        {"msg_id": msg_id, "mark": mark}
        for msg_id in dict.fromkeys(msg_ids)
        if msg_id not in already_marked
    ]

    if mark_records:
        # Use bulk insert for better performance
        await self.session.run_sync(
            lambda session: session.bulk_insert_mappings(
                self.MessageMarkTable,
                mark_records,
            ),
        )

    await self.session.commit()
    return len(msg_ids)
async def _remove_message_mark(
    self,
    msg_ids: list[str],
    old_mark: str | None,
) -> int:
    """Delete mark records of the given messages from the message_mark
    table and commit.

    Args:
        msg_ids (`list[str]`):
            The list of message IDs to be unmarked.
        old_mark (`str | None`):
            The mark to remove. If `None` (or empty), every mark of the
            given messages is removed.

    Returns:
        `int`:
            The number of messages passed in `msg_ids`.
    """
    mark_table = self.MessageMarkTable

    # Always restrict to the given messages; narrow to one mark on demand
    conditions = [mark_table.msg_id.in_(msg_ids)]
    if old_mark:
        conditions.append(mark_table.mark == old_mark)

    await self.session.execute(delete(mark_table).filter(*conditions))
    await self.session.commit()
    return len(msg_ids)
[docs]
async def close(self) -> None:
    """Close the database session held by this instance, if any.

    The stored session is closed whatever its state, the reference to it
    is dropped and `_initialized` is reset to `False`. This applies to a
    session supplied by the caller as well as to one created internally.
    """
    db_session = self._db_session
    if db_session is not None:
        # Close even when the session is not active; skipping the close in
        # that state would drop the reference without releasing it.
        await db_session.close()

    self._db_session = None
    self._initialized = False
async def __aenter__(self) -> "AsyncSQLAlchemyMemory":
"""Enter the async context manager.
Returns:
`AsyncSQLAlchemyMemory`:
The memory instance itself.
"""
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: Any,
) -> None:
"""Exit the async context manager and close the session.
Args:
exc_type (`type[BaseException] | None`):
The exception type if an exception was raised.
exc_value (`BaseException | None`):
The exception instance if an exception was raised.
traceback (`Any`):
The traceback object if an exception was raised.
"""
await self.close()