# Source code for agentscope.rag._store._alibabacloud_mysql_store

# -*- coding: utf-8 -*-
"""The AlibabaCloud MySQL vector store implementation."""
import json
from typing import Any, Literal, TYPE_CHECKING

from .._reader import Document
from ._store_base import VDBStoreBase
from .._document import DocMetadata

from ..._utils._common import _map_text_to_uuid
from ...types import Embedding

if TYPE_CHECKING:
    from mysql.connector import MySQLConnection
else:
    MySQLConnection = "mysql.connector.MySQLConnection"


class AlibabaCloudMySQLStore(VDBStoreBase):
    """The AlibabaCloud MySQL vector store implementation, supporting vector
    search operations using MySQL's native vector functions.

    .. note:: AlibabaCloud MySQL vector search requires MySQL 8.0+. This
     implementation uses MySQL's native vector functions
     (VEC_DISTANCE_COSINE, VEC_DISTANCE_EUCLIDEAN, VEC_FROMTEXT) for
     efficient vector similarity search with ORDER BY in SQL. Only COSINE
     and EUCLIDEAN distance metrics are supported.

    .. note:: Requires mysql-connector-python package. Install with:
     `pip install mysql-connector-python`

    .. note:: For AlibabaCloud MySQL instances, ensure vector search plugin
     is enabled. Contact AlibabaCloud support if needed.

    .. warning:: ``table_name`` and the ``filter`` expressions accepted by
     ``search`` and ``delete`` are interpolated into SQL verbatim (they
     cannot be bound as parameters). They must come from trusted code,
     never from end-user input.
    """

    # Supported distance metrics mapped to the MySQL native distance
    # function used to score them. Shared by __init__ (validation) and
    # _get_distance_function (query building) so the two cannot drift.
    _DISTANCE_FUNCTIONS: dict[str, str] = {
        "COSINE": "VEC_DISTANCE_COSINE",
        "EUCLIDEAN": "VEC_DISTANCE_EUCLIDEAN",
    }

    def __init__(
        self,
        host: str,
        port: int,
        user: str,
        password: str,
        database: str,
        table_name: str,
        dimensions: int,
        distance: Literal["COSINE", "EUCLIDEAN"] = "COSINE",
        hnsw_m: int = 16,
        connection_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Initialize the AlibabaCloud MySQL vector store.

        Args:
            host (`str`):
                The hostname of the AlibabaCloud MySQL server.
                Example: "rm-xxxxx.mysql.rds.aliyuncs.com"
            port (`int`):
                The port number of the MySQL server (typically 3306).
            user (`str`):
                The username for authentication.
            password (`str`):
                The password for authentication.
            database (`str`):
                The database name to use.
            table_name (`str`):
                The name of the table to store the embeddings. It is
                interpolated into SQL as-is, so it must be a trusted
                identifier.
            dimensions (`int`):
                The dimension of the embeddings.
            distance (`Literal["COSINE", "EUCLIDEAN"]`, default to \
            "COSINE"):
                The distance metric to use for similarity search. Can be one
                of "COSINE" (cosine similarity) or "EUCLIDEAN" (Euclidean
                distance). Defaults to "COSINE".
            hnsw_m (`int`, default to 16):
                The M parameter for HNSW vector index, which controls the
                number of bi-directional links created for each node during
                construction. Higher values create denser graphs with better
                recall but use more memory. Typical values range from 4 to
                64. Defaults to 16.
            connection_kwargs (`dict[str, Any] | None`, optional):
                Other keyword arguments for the MySQL connector.
                Example: {"ssl_ca": "/path/to/ca.pem", "charset": "utf8mb4"}

        Raises:
            `ImportError`:
                If mysql-connector-python is not installed.
            `ValueError`:
                If `distance` is not a supported metric.
        """
        try:
            import mysql.connector
        except ImportError as e:
            raise ImportError(
                "MySQL connector is not installed. Please install it with "
                "`pip install mysql-connector-python`.",
            ) from e

        # Fail fast on an unsupported metric: otherwise it would only be
        # detected at search time, after _validate_table had already baked
        # it into the table's vector index definition.
        if distance not in self._DISTANCE_FUNCTIONS:
            raise ValueError(
                f"Unsupported distance metric: {distance}. "
                f"AlibabaCloud MySQL only supports 'COSINE' and 'EUCLIDEAN'.",
            )

        connection_kwargs = connection_kwargs or {}

        # Initialize connection parameters; connection_kwargs may override
        # any of the explicit ones.
        self.connection_params = {
            "host": host,
            "port": port,
            "user": user,
            "password": password,
            "database": database,
            **connection_kwargs,
        }

        self.table_name = table_name
        self.dimensions = dimensions
        self.distance = distance
        self.hnsw_m = hnsw_m

        # Initialize connection; rows are fetched as dicts keyed by column
        self._conn = mysql.connector.connect(**self.connection_params)
        self._cursor = self._conn.cursor(dictionary=True)

    def _get_distance_function(self) -> str:
        """Get the MySQL native vector distance function name based on the
        distance metric.

        Returns:
            `str`:
                The SQL vector distance function name.

        Raises:
            `ValueError`:
                If the configured distance metric is not supported.
        """
        try:
            return self._DISTANCE_FUNCTIONS[self.distance]
        except KeyError:
            raise ValueError(
                f"Unsupported distance metric: {self.distance}. "
                f"AlibabaCloud MySQL only supports 'COSINE' and 'EUCLIDEAN'.",
            ) from None

    def _format_vector_for_sql(self, vector: list[float]) -> str:
        """Format a vector as a string for MySQL VEC_FROMTEXT function.

        Args:
            vector (`list[float]`):
                The vector to format.

        Returns:
            `str`:
                The formatted vector string like "[1,2,3,4]".
        """
        return "[" + ",".join(map(str, vector)) + "]"

    async def _validate_table(self) -> None:
        """Validate the table exists, if not, create it.

        Creates a table with VECTOR type columns and automatically creates a
        vector index based on the specified distance metric using HNSW
        algorithm.
        """
        # Get distance metric in lowercase for SQL
        distance_metric = self.distance.lower()

        # Create table with VECTOR INDEX in a single statement
        # VECTOR(dimensions) type is available in AlibabaCloud MySQL 8.0+
        # VECTOR INDEX uses HNSW algorithm with M parameter for
        # graph connectivity
        # IF NOT EXISTS prevents errors if table already exists
        create_table_sql = f"""
            CREATE TABLE IF NOT EXISTS {self.table_name} (
                id VARCHAR(255) PRIMARY KEY,
                embedding VECTOR({self.dimensions}) NOT NULL,
                doc_id VARCHAR(255) NOT NULL,
                chunk_id INT NOT NULL,
                content TEXT NOT NULL,
                total_chunks INT NOT NULL,
                INDEX idx_doc_id (doc_id),
                INDEX idx_chunk_id (chunk_id),
                VECTOR INDEX (embedding) M={self.hnsw_m}
                    DISTANCE={distance_metric}
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
              COLLATE=utf8mb4_unicode_ci
        """

        self._cursor.execute(create_table_sql)
        self._conn.commit()

    async def add(self, documents: list[Document], **kwargs: Any) -> None:
        """Add embeddings to the AlibabaCloud MySQL vector store.

        Each document is upserted under an ID derived from its doc_id,
        chunk_id and content, so re-adding the same chunk overwrites it.
        All documents are written in one transaction: if any of them fails,
        none of them are kept.

        Args:
            documents (`list[Document]`):
                A list of embedding records to be recorded in the MySQL
                store.
            **kwargs (`Any`):
                Additional arguments for the insert operation.

        Raises:
            `ValueError`:
                If a document has no embedding.
        """
        await self._validate_table()

        # The statement is the same for every document, so build it once.
        # VEC_FROMTEXT converts the "[x,y,...]" text into a vector; the
        # vector text is bound a second time for the UPDATE branch.
        insert_sql = f"""
            INSERT INTO {self.table_name}
            (id, embedding, doc_id, chunk_id, content, total_chunks)
            VALUES (%s, VEC_FROMTEXT(%s), %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                embedding = VEC_FROMTEXT(%s),
                doc_id = VALUES(doc_id),
                chunk_id = VALUES(chunk_id),
                content = VALUES(content),
                total_chunks = VALUES(total_chunks)
        """

        try:
            for doc in documents:
                # Generate a unique, deterministic ID for this chunk
                unique_string = json.dumps(
                    {
                        "doc_id": doc.metadata.doc_id,
                        "chunk_id": doc.metadata.chunk_id,
                        "content": doc.metadata.content,
                    },
                    ensure_ascii=False,
                )
                unique_id = _map_text_to_uuid(unique_string)

                if doc.embedding is None:
                    raise ValueError(
                        f"Document embedding cannot be None for doc_id: "
                        f"{doc.metadata.doc_id}",
                    )
                vector_text = self._format_vector_for_sql(doc.embedding)

                # Serialize content to JSON if it's not a string
                content_str = (
                    doc.metadata.content
                    if isinstance(doc.metadata.content, str)
                    else json.dumps(doc.metadata.content, ensure_ascii=False)
                )

                self._cursor.execute(
                    insert_sql,
                    (
                        unique_id,
                        vector_text,
                        doc.metadata.doc_id,
                        doc.metadata.chunk_id,
                        content_str,
                        doc.metadata.total_chunks,
                        vector_text,  # For ON DUPLICATE KEY UPDATE
                    ),
                )
        except Exception:
            # Discard rows already inserted from this batch. Otherwise they
            # would stay pending on the shared connection and be silently
            # persisted by the next unrelated commit().
            self._conn.rollback()
            raise

        self._conn.commit()

    @staticmethod
    def _decode_content(raw: Any) -> Any:
        """Reverse the serialization applied to content in `add`.

        Non-string content is stored as JSON, so JSON objects/arrays are
        decoded back. Anything else (including text that happens to parse as
        a JSON scalar such as "42" or "true") is returned unchanged, so that
        plain string content keeps its type.

        .. note:: String content that is itself a JSON object/array is
         indistinguishable from serialized content and will be decoded.
        """
        try:
            decoded = json.loads(raw)
        except (ValueError, TypeError):
            # Not valid JSON: keep it as is (plain string)
            return raw
        if isinstance(decoded, (dict, list)):
            return decoded
        return raw

    async def search(
        self,
        query_embedding: Embedding,
        limit: int,
        score_threshold: float | None = None,
        **kwargs: Any,
    ) -> list[Document]:
        """Search relevant documents from the AlibabaCloud MySQL vector
        store.

        Args:
            query_embedding (`Embedding`):
                The embedding of the query text.
            limit (`int`):
                The number of relevant documents to retrieve.
            score_threshold (`float | None`, optional):
                The minimum similarity score threshold to filter the
                results. Score is calculated as 1 - distance, where higher
                scores indicate higher similarity. Only documents with
                score >= score_threshold will be returned.
            **kwargs (`Any`):
                Additional arguments for the search operation.
                - filter (`str`): WHERE clause to filter the search results.
                  It is inserted verbatim (wrapped in parentheses so that it
                  combines correctly with the score threshold) and must be
                  trusted SQL.

        Returns:
            `list[Document]`:
                Matching documents ordered by ascending distance, with
                `score` set to 1 - distance and `embedding` set to None.
        """
        # Format query vector for MySQL VEC_FROMTEXT
        query_vector_text = self._format_vector_for_sql(query_embedding)

        # Get the distance function
        distance_func = self._get_distance_function()

        # Parameters are bound in placeholder order: the SELECT distance
        # expression first, then the WHERE threshold, then LIMIT.
        params: list[str | float | int] = [query_vector_text]
        where_conditions: list[str] = []

        if kwargs.get("filter"):
            # Parenthesize so a filter containing OR cannot escape the AND
            # with the score-threshold condition below.
            where_conditions.append(f"({kwargs['filter']})")

        if score_threshold is not None:
            # score = 1 - distance, so score >= threshold is equivalent to
            # distance <= 1 - threshold. The select-list alias cannot be
            # referenced in WHERE, hence the repeated distance expression.
            where_conditions.append(
                f"{distance_func}(embedding, VEC_FROMTEXT(%s)) <= %s",
            )
            params.extend([query_vector_text, 1.0 - score_threshold])

        params.append(limit)

        where_clause = (
            "WHERE " + " AND ".join(where_conditions)
            if where_conditions
            else ""
        )

        # Build and execute the search query using MySQL native vector
        # functions with ORDER BY for efficient sorting
        search_sql = f"""
            SELECT
                id,
                doc_id,
                chunk_id,
                content,
                total_chunks,
                {distance_func}(embedding, VEC_FROMTEXT(%s)) as distance
            FROM {self.table_name}
            {where_clause}
            ORDER BY distance ASC
            LIMIT %s
        """

        self._cursor.execute(search_sql, params)
        rows = self._cursor.fetchall()

        # End the implicit read transaction. With autocommit off (the
        # connector default) and InnoDB REPEATABLE READ, leaving it open
        # would pin this connection to the first search's snapshot and
        # hide rows committed later by other connections.
        self._conn.commit()

        collected_res = []
        for row in rows:
            doc_metadata = DocMetadata(
                content=self._decode_content(row["content"]),
                doc_id=row["doc_id"],
                chunk_id=row["chunk_id"],
                total_chunks=row["total_chunks"],
            )
            # Convert distance to score: score = 1 - distance
            # Higher scores indicate higher similarity
            collected_res.append(
                Document(
                    embedding=None,  # Vector not returned for efficiency
                    score=1.0 - row["distance"],
                    metadata=doc_metadata,
                ),
            )

        return collected_res

    async def delete(
        self,
        ids: list[str] | None = None,
        filter: str | None = None,  # pylint: disable=redefined-builtin
        **kwargs: Any,
    ) -> None:
        """Delete documents from the AlibabaCloud MySQL vector store.

        Args:
            ids (`list[str] | None`, optional):
                List of entity IDs to delete. Takes precedence over
                `filter` when both are given. An empty list deletes nothing.
            filter (`str | None`, optional):
                WHERE clause expression to filter documents to delete. It is
                inserted verbatim and must be trusted SQL.
            **kwargs (`Any`):
                Additional arguments for the delete operation.

        Raises:
            `ValueError`:
                If neither `ids` nor `filter` is provided.
        """
        if ids is None and filter is None:
            raise ValueError(
                "Either ids or filter must be provided for deletion.",
            )

        if ids is not None:
            if not ids:
                # "IN ()" is a SQL syntax error; nothing to delete anyway
                return
            # Delete by IDs
            placeholders = ",".join(["%s"] * len(ids))
            delete_sql = (
                f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})"
            )
            self._cursor.execute(delete_sql, list(ids))
        else:
            # Delete by filter
            delete_sql = f"DELETE FROM {self.table_name} WHERE {filter}"
            self._cursor.execute(delete_sql)

        self._conn.commit()

    def get_client(self) -> MySQLConnection:
        """Get the underlying MySQL connection, so that developers can access
        the full functionality of AlibabaCloud MySQL.

        Returns:
            `MySQLConnection`:
                The underlying MySQL connection.
        """
        return self._conn

    def close(self) -> None:
        """Close the database cursor and connection.

        Safe to call on a partially-initialized instance (e.g. when the
        connection failed in `__init__`). The connection is closed even if
        closing the cursor raises.
        """
        cursor = getattr(self, "_cursor", None)
        conn = getattr(self, "_conn", None)
        try:
            if cursor is not None:
                cursor.close()
        finally:
            if conn is not None:
                conn.close()

    def __del__(self) -> None:
        """Destructor to ensure connection is closed."""
        try:
            self.close()
        except Exception:  # pylint: disable=broad-except
            # Best effort only: never raise from a finalizer
            pass