# -*- coding: utf-8 -*-
"""A file embedding cache implementation for storing and retrieving
embeddings in binary files."""
import hashlib
import json
import os
from typing import Any, List
import numpy as np
from ._cache_base import EmbeddingCacheBase
from .._logging import logger
from ..types import (
Embedding,
JSONSerializableObject,
)
[docs]
class FileEmbeddingCache(EmbeddingCacheBase):
    """The embedding cache class that stores each embeddings vector in
    binary files.

    Each cache entry is one ``.npy`` file inside :attr:`cache_dir`, named
    after the SHA-256 hex digest of the JSON-serialized identifier. After
    every write, the optional file-count and directory-size limits are
    enforced by evicting the files with the oldest modification time first.
    """

    def __init__(
        self,
        cache_dir: str = "./.cache/embeddings",
        max_file_number: int | None = None,
        max_cache_size: int | None = None,
    ) -> None:
        """Initialize the file embedding cache class.

        Args:
            cache_dir (`str`, defaults to `"./.cache/embeddings"`):
                The directory to store the embedding files. It is stored
                as an absolute path and only created on first access.
            max_file_number (`int | None`, defaults to `None`):
                The maximum number of files to keep in the cache directory.
                If exceeded, the oldest files will be removed. `None`
                disables the limit.
            max_cache_size (`int | None`, defaults to `None`):
                The maximum size of the cache directory in MB. If exceeded,
                the oldest files will be removed until the size is within
                the limit. `None` disables the limit.
        """
        self._cache_dir = os.path.abspath(cache_dir)
        self.max_file_number = max_file_number
        self.max_cache_size = max_cache_size

    @property
    def cache_dir(self) -> str:
        """The cache directory where the embedding files are stored. The
        directory is created (including parents) if it does not exist."""
        if not os.path.exists(self._cache_dir):
            os.makedirs(self._cache_dir, exist_ok=True)
        return self._cache_dir

    async def store(
        self,
        embeddings: List[Embedding],
        identifier: JSONSerializableObject,
        overwrite: bool = False,
        **kwargs: Any,
    ) -> None:
        """Store the embeddings with the given identifier.

        Args:
            embeddings (`List[Embedding]`):
                The embeddings to store.
            identifier (`JSONSerializableObject`):
                The identifier to distinguish the embeddings, which will be
                used to generate a hashable filename, so it should be
                JSON serializable (e.g. a string, number, list, dict).
            overwrite (`bool`, defaults to `False`):
                Whether to overwrite existing embeddings with the same
                identifier. If `False` and an entry already exists, nothing
                is written and the cache directory is not maintained.
            **kwargs (`Any`):
                Unused by this implementation.

        Raises:
            `RuntimeError`:
                If the target path exists but is not a regular file.
        """
        path_file = os.path.join(self.cache_dir, self._get_filename(identifier))

        if os.path.exists(path_file):
            if not os.path.isfile(path_file):
                raise RuntimeError(
                    f"Path {path_file} exists but is not a file.",
                )
            if not overwrite:
                # Keep the existing entry untouched.
                return

        # The filename already ends with ".npy", so np.save writes to
        # exactly this path without appending another suffix.
        np.save(path_file, embeddings)
        await self._maintain_cache_dir()

    async def retrieve(
        self,
        identifier: JSONSerializableObject,
    ) -> List[Embedding] | None:
        """Retrieve the embeddings with the given identifier. If not found,
        return `None`.

        Args:
            identifier (`JSONSerializableObject`):
                The identifier to retrieve the embeddings, which will be
                used to generate a hashable filename, so it should be
                JSON serializable (e.g. a string, number, list, dict).

        Returns:
            `List[Embedding] | None`:
                The stored array converted back to nested Python lists, or
                `None` if no file exists for the identifier.
        """
        path_file = os.path.join(self.cache_dir, self._get_filename(identifier))
        if not os.path.exists(path_file):
            return None
        return np.load(path_file).tolist()

    async def remove(self, identifier: JSONSerializableObject) -> None:
        """Remove the embeddings with the given identifier.

        Args:
            identifier (`JSONSerializableObject`):
                The identifier of the embeddings to remove, which will be
                used to generate a hashable filename, so it should be
                JSON serializable (e.g. a string, number, list, dict).

        Raises:
            `FileNotFoundError`:
                If no cache file exists for the identifier.
        """
        path_file = os.path.join(self.cache_dir, self._get_filename(identifier))
        if not os.path.exists(path_file):
            raise FileNotFoundError(f"File {path_file} does not exist.")
        os.remove(path_file)

    async def clear(self) -> None:
        """Clear the cache directory by removing all ``.npy`` files. Other
        files in the directory are left in place."""
        for filename in os.listdir(self.cache_dir):
            if filename.endswith(".npy"):
                os.remove(os.path.join(self.cache_dir, filename))

    def _scan_cache_files(self) -> list[tuple[str, float, int]]:
        """List the ``.npy`` regular files in the cache directory.

        Returns:
            `list[tuple[str, float, int]]`:
                One ``(name, mtime, size_in_bytes)`` tuple per file, in
                directory-listing order.
        """
        with os.scandir(self.cache_dir) as entries:
            return [
                (entry.name, entry.stat().st_mtime, entry.stat().st_size)
                for entry in entries
                if entry.is_file() and entry.name.endswith(".npy")
            ]

    def _get_cache_size(self) -> float:
        """Get the current size of the cache directory in MB (only ``.npy``
        regular files are counted)."""
        total_bytes = sum(size for _, _, size in self._scan_cache_files())
        return total_bytes / (1024.0 * 1024.0)

    @staticmethod
    def _get_filename(identifier: JSONSerializableObject) -> str:
        """Generate a filename based on the identifier: the SHA-256 hex
        digest of its JSON (UTF-8, non-ASCII preserved) plus ``.npy``."""
        json_str = json.dumps(identifier, ensure_ascii=False)
        return hashlib.sha256(json_str.encode("utf-8")).hexdigest() + ".npy"

    async def _maintain_cache_dir(self) -> None:
        """Maintain the cache directory by removing old files if the number of
        files exceeds the maximum limit or if the cache size exceeds the
        maximum size. Files are evicted oldest-modification-time first."""
        bytes_per_mb = 1024.0 * 1024.0

        # Scan once and reuse the stat results: re-measuring the directory
        # after every eviction would be quadratic in the number of files.
        files = self._scan_cache_files()
        files.sort(key=lambda entry: entry[1])

        if (
            self.max_file_number is not None
            and len(files) > self.max_file_number
        ):
            # Explicit count instead of negative slicing, which would
            # silently select nothing when the limit is 0.
            excess = len(files) - self.max_file_number
            for file_name, _, _ in files[:excess]:
                os.remove(os.path.join(self.cache_dir, file_name))
                logger.info(
                    "Remove cached embedding file %s for limited number "
                    "of files (%d).",
                    file_name,
                    self.max_file_number,
                )
            files = files[excess:]

        if self.max_cache_size is None:
            return

        total_bytes = sum(size for _, _, size in files)
        removed_files = []
        for filename, _, size in files:
            if total_bytes / bytes_per_mb <= self.max_cache_size:
                break
            os.remove(os.path.join(self.cache_dir, filename))
            removed_files.append(filename)
            total_bytes -= size

        if removed_files:
            logger.info(
                "Remove %d cached embedding file(s) for limited "
                "cache size (%d MB).",
                len(removed_files),
                self.max_cache_size,
            )