Source code for agentscope.service.retrieval.retrieval_from_list

# -*- coding: utf-8 -*-
"""Retrieve service working with memory specially."""
from typing import Callable, Optional, Any, Sequence
from loguru import logger

from agentscope.service.service_response import ServiceResponse
from agentscope.service.service_status import ServiceExecStatus
from agentscope.models import ModelWrapperBase


def retrieve_from_list(
    query: Any,
    knowledge: Sequence,  # TODO: rename
    score_func: Callable[[Any, Any], float],
    top_k: Optional[int] = None,
    embedding_model: Optional[ModelWrapperBase] = None,
    preserve_order: bool = True,
) -> ServiceResponse:
    """
    Retrieve the top-k elements with the HIGHEST scores from a list.

    Every element of `knowledge` is scored against `query` with the
    user-defined `score_func`, and the `top_k` best-scoring elements are
    returned. Elements with equal scores keep their original relative
    order when ranked.

    If `query` is a dict without an "embedding" key and an
    `embedding_model` is given, the model is called on the query and the
    result is stored in `query["embedding"]`. Note that this modifies the
    caller's dict in place. If no model is given, a warning is logged and
    the query is scored as-is.

    Args:
        query (`Any`):
            A message to be retrieved.
        knowledge (`Sequence`):
            Data/knowledge to be retrieved from.
        score_func (`Callable[[Any, Any], float]`):
            User-defined function that takes `query` and one element of
            `knowledge` and returns a score; higher means a better match.
        top_k (`Optional[int]`, defaults to `None`):
            Maximum number of elements returned. `None` returns all
            elements; zero or a negative value returns none.
        embedding_model (`Optional[ModelWrapperBase]`, defaults to `None`):
            A model to embed the query/message.
        preserve_order (`bool`, defaults to `True`):
            If True, the selected elements are returned in their original
            order in `knowledge`; otherwise they are ordered by
            descending score.

    Returns:
        `ServiceResponse`: A successful response whose content is a list
        of `(score, index, object)` triples for the top-k elements, where
        `index` is the element's position in `knowledge`.
    """
    if isinstance(query, dict) and "embedding" not in query:
        if embedding_model is not None:
            query["embedding"] = embedding_model(
                [query],
                return_embedding_only=True,
            )
        else:
            logger.warning(
                "The input query has no embedding and no embedding model "
                "is provided, so the query is scored without an embedding.",
            )

    # (score, index, object)
    scores = [
        (score_func(query, msg), i, msg) for i, msg in enumerate(knowledge)
    ]

    # `None` means "no limit". Negative values are clamped to 0 because a
    # negative slice bound would drop the lowest-scoring items and return
    # the rest instead of limiting the result size.
    limit = len(scores) if top_k is None else max(top_k, 0)

    # Rank by score only, so ties fall back to the original order
    # (sorted() is stable, also with reverse=True).
    ordered_top_k_scores = sorted(
        scores,
        key=lambda x: x[0],
        reverse=True,
    )[:limit]

    if preserve_order:
        # Restore the original ordering of the selected items by index.
        content = sorted(ordered_top_k_scores, key=lambda x: x[1])
    else:
        content = ordered_top_k_scores

    # The returned content is a list of (score, index, object) triples.
    return ServiceResponse(
        status=ServiceExecStatus.SUCCESS,
        content=content,
    )