Source code for agentscope.rag.knowledge

# -*- coding: utf-8 -*-
"""
Base class module for retrieval augmented generation (RAG).
To accommodate the RAG processes of different packages,
we abstract the RAG process into four stages:
- data loading: loading data into memory for subsequent processing;
- data indexing and storage: chunking documents, generating embeddings,
and offloading the data into a vector database (VDB);
- data retrieval: taking a query and returning a batch of documents or
document chunks;
- post-processing of the retrieved data: using the retrieved data to
generate an answer.
"""

import importlib
from abc import ABC, abstractmethod
from typing import Any, Optional
from loguru import logger
from agentscope.models import ModelWrapperBase


class Knowledge(ABC):
    """
    Base class for RAG, CANNOT be instantiated directly.

    Subclasses must implement :meth:`_init_rag` and :meth:`retrieve`.
    """

    def __init__(
        self,
        knowledge_id: str,
        emb_model: Any = None,
        knowledge_config: Optional[dict] = None,
        model: Optional[ModelWrapperBase] = None,
        **kwargs: Any,
    ) -> None:
        # pylint: disable=unused-argument
        """
        Initialize the knowledge component.

        Args:
            knowledge_id (str):
                The id of the knowledge unit.
            emb_model (ModelWrapperBase):
                The embedding model used for generating embeddings.
            knowledge_config (dict):
                The configuration used to generate or load the index.
                An empty dict is used when ``None`` is given.
            model (ModelWrapperBase):
                Optional model used by :meth:`post_processing` to
                generate an answer from the retrieved documents.
            **kwargs:
                Accepted for subclass compatibility; ignored here.
        """
        self.knowledge_id = knowledge_id
        self.emb_model = emb_model
        self.knowledge_config = knowledge_config or {}
        self.postprocessing_model = model

    @abstractmethod
    def _init_rag(
        self,
        **kwargs: Any,
    ) -> Any:
        """
        Initiate the RAG module.
        """

    @abstractmethod
    def retrieve(
        self,
        query: Any,
        similarity_top_k: Optional[int] = None,
        to_list_strs: bool = False,
        **kwargs: Any,
    ) -> list[Any]:
        """
        Retrieve a list of content from the database (vector-store index)
        into memory.

        Args:
            query (Any):
                Query for retrieval.
            similarity_top_k (int):
                The number of most similar data returned by the retriever.
            to_list_strs (bool):
                Whether to return a list of str.

        Returns:
            A list with the retrieved documents (in strings).
        """

    def post_processing(
        self,
        retrieved_docs: list[str],
        prompt: str,
        **kwargs: Any,
    ) -> Any:
        """
        A default post-processing function that generates an answer based
        on the retrieved documents.

        The retrieved documents are joined with newlines and substituted
        into ``prompt`` via ``str.format``; the result is sent to the
        post-processing model and the ``text`` of its response is returned.

        Args:
            retrieved_docs (list[str]):
                List of retrieved documents.
            prompt (str):
                Prompt for the LLM; it must contain one positional
                placeholder (``{}``) for the joined documents.
            **kwargs:
                Forwarded to the post-processing model call.

        Returns:
            Any: A synthesized answer from the LLM using the retrieved
            documents.

        Raises:
            ValueError: If no post-processing model was given at
                construction time.

        Example:
            self.postprocessing_model(prompt.format(retrieved_docs))
        """
        # Explicit check instead of `assert`, which is stripped under -O
        # and would turn a misconfiguration into an obscure TypeError.
        if self.postprocessing_model is None:
            raise ValueError(
                "post_processing requires a post-processing model; "
                "pass `model` when constructing the knowledge instance.",
            )
        prompt = prompt.format("\n".join(retrieved_docs))
        return self.postprocessing_model(prompt, **kwargs).text

    def _prepare_args_from_config(self, config: Any) -> Any:
        """
        Helper function to build objects in RAG classes.

        A dict containing the key ``"create_object"`` (its value is not
        inspected) is treated as an object spec: the class named by
        ``"class"`` is imported from ``"module"`` and instantiated with the
        recursively prepared ``"init_args"``. Any other dict is processed
        value by value; list values are processed element by element.
        Non-dict inputs are returned unchanged.

        Args:
            config (dict):
                A dictionary containing configurations.

        Returns:
            Any: An object that is parsed/built to be an element of input
            to the function of the RAG module.

        Raises:
            ImportError: If the module cannot be imported or does not
                define the requested class.
        """
        if not isinstance(config, dict):
            return config

        if "create_object" in config:
            module_name = config.get("module", "")
            class_name = config.get("class", "")
            init_args = config.get("init_args", {})
            # Keep the try narrow: only module import and class lookup are
            # "load" failures. AttributeError covers a missing class and
            # ValueError covers an empty module name.
            try:
                cur_module = importlib.import_module(module_name)
                cur_class = getattr(cur_module, class_name)
            except (ImportError, AttributeError, ValueError) as exc_inner:
                logger.error(
                    f"Fail to load class {class_name} "
                    f"from module {module_name}",
                )
                raise ImportError(
                    f"Fail to load class {class_name} "
                    f"from module {module_name}",
                ) from exc_inner
            # Constructor args may themselves contain nested object specs.
            init_args = self._prepare_args_from_config(init_args)
            logger.info(
                f"load and build object: {class_name}",
            )
            return cur_class(**init_args)

        # Plain argument dict: non-dict, non-list values pass through the
        # recursive call unchanged, so only lists need special handling.
        return {
            key: (
                [self._prepare_args_from_config(item) for item in value]
                if isinstance(value, list)
                else self._prepare_args_from_config(value)
            )
            for key, value in config.items()
        }