Source code for agentscope.rpc.retry_strategy

# -*- coding: utf-8 -*-
"""
Timeout retry strategies
"""
from __future__ import annotations
import time
import random
import inspect
from abc import ABC, abstractmethod
from typing import Callable, Any
from functools import partial
from loguru import logger


class RetryBase(ABC):
    """The base class for all retry strategies.

    Subclasses implement :meth:`retry`. Instances are also callable;
    calling one simply forwards to :meth:`retry`.
    """

    @abstractmethod
    def retry(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
        """Retry the func when any exception occurs.

        Args:
            func (`Callable`): The function to invoke.
            *args (`Any`): Positional arguments forwarded to ``func``.
            **kwargs (`Any`): Keyword arguments forwarded to ``func``.

        Returns:
            `Any`: The value produced by the concrete strategy.
        """

    def __call__(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
        """Call the retry method with the same arguments."""
        return self.retry(func, *args, **kwargs)

    @classmethod
    def load_dict(cls, data: dict) -> RetryBase:
        """Load the retry strategy from a dict.

        The ``"type"`` entry selects the strategy; every other entry is
        passed as a keyword argument to that strategy's constructor.

        Args:
            data (`dict`): The strategy config. It is not modified.

        Returns:
            `RetryBase`: A `RetryFixedTimes` instance for type ``"fixed"``,
            or a `RetryExpential` instance for type ``"expential"``
            (``"exponential"`` is accepted as an alias).

        Raises:
            NotImplementedError: If the type is missing or unknown.
        """
        # Work on a shallow copy so the caller's dict keeps its "type" key
        # and can be reused to build another strategy.
        options = dict(data)
        retry_type = options.pop("type", None)
        if retry_type == "fixed":
            return RetryFixedTimes(**options)
        # "expential" is the historical spelling of the key; the correctly
        # spelled form is accepted as well.
        if retry_type in ("expential", "exponential"):
            return RetryExpential(**options)
        raise NotImplementedError(
            f"Unknown retry strategy type: {retry_type}",
        )
class RetryFixedTimes(RetryBase):
    """
    Retry a fixed number of times, and wait a fixed delay time between each
    attempt.

    Init dict format:

    - type: 'fixed'
    - max_retries (`int`): The max retry times
    - delay (`float`): The delay time between each attempt

    .. code-block:: python

        retry = RetryBase.load_dict({
            "type": "fixed",
            "max_retries": 10,
            "delay": 5,
        })
    """

    def __init__(self, max_retries: int = 10, delay: float = 5) -> None:
        """Initialize the retry strategy

        Args:
            max_retries (`int`): The max retry times
            delay (`float`): The delay time between each attempt
        """
        self.max_retries = max_retries
        self.delay = delay

    @staticmethod
    def _caller_location(frame: Any) -> str:
        """Format the caller of ``frame`` as ``file:line`` for log output."""
        # inspect.currentframe() may return None on interpreters without
        # frame support, and the outermost frame has no f_back.
        caller = frame.f_back if frame is not None else None
        if caller is None:
            return "<unknown>:0"
        info = inspect.getframeinfo(caller)
        return f"{info.filename}:{info.lineno}"

    def retry(  # pylint: disable=R1710
        self,
        func: Callable,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Call ``func`` until it succeeds or the retries are used up.

        ``func`` is attempted at most ``max_retries + 1`` times. After each
        failed attempt except the last, the call sleeps for a jittered
        delay of ``(random.random() + 0.5) * delay`` seconds.

        Args:
            func (`Callable`): The function to call.
            *args (`Any`): Positional arguments forwarded to ``func``.
            **kwargs (`Any`): Keyword arguments forwarded to ``func``. The
                optional key ``expect_exception_type`` is removed first and
                selects which exception type triggers a retry (defaults to
                `Exception`); other exceptions propagate immediately.

        Returns:
            `Any`: The return value of the first successful ``func`` call.

        Raises:
            TimeoutError: If every attempt failed (chained to the last
                exception), or if ``max_retries`` is negative so that no
                attempt could be made.
        """
        exception_type = kwargs.pop("expect_exception_type", Exception)
        func = partial(func, *args, **kwargs)
        for attempt in range(self.max_retries + 1):
            try:
                return func()
            except exception_type as e:
                location = self._caller_location(inspect.currentframe())
                if attempt == self.max_retries:
                    # Last attempt failed: report it before giving up.
                    logger.error(f"Max timeout exceeded at [{location}].")
                    raise TimeoutError("Max timeout exceeded.") from e
                random_delay = (random.random() + 0.5) * self.delay
                logger.debug(
                    f"Attempt {attempt + 1} at [{location}] failed:"
                    f"\n{e}.\nRetrying in {random_delay:.2f} seconds...",
                )
                time.sleep(random_delay)
        # Only reachable when max_retries is negative, i.e. the loop body
        # never ran and func was never called.
        location = self._caller_location(inspect.currentframe())
        logger.error(f"Max timeout exceeded at [{location}].")
        raise TimeoutError("Max retry exceeded.")
class RetryExpential(RetryBase):
    """
    Retry with exponential backoff, which means the delay time will increase
    exponentially.

    Init dict format:

    - type: 'expential'
    - max_retries (`int`): The max retry times
    - base_delay (`float`): The base delay time
    - max_delay (`float`): The max delay time, which will be used if the
      calculated delay time exceeds it.

    .. code-block:: python

        retry = RetryBase.load_dict({
            "type": "expential",
            "max_retries": 10,
            "base_delay": 5,
            "max_delay": 300,
        })
    """

    def __init__(
        self,
        max_retries: int = 10,
        base_delay: float = 5,
        max_delay: float = 300,
    ) -> None:
        """Initialize the retry strategy

        Args:
            max_retries (`int`): The max retry times
            base_delay (`float`): The base delay time
            max_delay (`float`): The max delay time
        """
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.max_delay = max_delay

    @staticmethod
    def _caller_location(frame: Any) -> str:
        """Format the caller of ``frame`` as ``file:line`` for log output."""
        # inspect.currentframe() may return None on interpreters without
        # frame support, and the outermost frame has no f_back.
        caller = frame.f_back if frame is not None else None
        if caller is None:
            return "<unknown>:0"
        info = inspect.getframeinfo(caller)
        return f"{info.filename}:{info.lineno}"

    def retry(  # pylint: disable=R1710
        self,
        func: Callable,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Call ``func`` until it succeeds, backing off exponentially.

        ``func`` is attempted at most ``max_retries + 1`` times. The
        nominal delay starts at ``base_delay`` and doubles after every
        failed attempt; the actual sleep is
        ``min((random.random() + 0.5) * delay, max_delay)`` seconds.

        Args:
            func (`Callable`): The function to call.
            *args (`Any`): Positional arguments forwarded to ``func``.
            **kwargs (`Any`): Keyword arguments forwarded to ``func``. The
                optional key ``expect_exception_type`` is removed first and
                selects which exception type triggers a retry (defaults to
                `Exception`); other exceptions propagate immediately.

        Returns:
            `Any`: The return value of the first successful ``func`` call.

        Raises:
            TimeoutError: If every attempt failed (chained to the last
                exception), or if ``max_retries`` is negative so that no
                attempt could be made.
        """
        exception_type = kwargs.pop("expect_exception_type", Exception)
        func = partial(func, *args, **kwargs)
        delay = self.base_delay
        for attempt in range(self.max_retries + 1):
            try:
                return func()
            except exception_type as e:
                location = self._caller_location(inspect.currentframe())
                if attempt == self.max_retries:
                    # Last attempt failed: report it before giving up.
                    logger.error(f"Max timeout exceeded at [{location}].")
                    raise TimeoutError("Max timeout exceeded.") from e
                random_delay = min(
                    (random.random() + 0.5) * delay,
                    self.max_delay,
                )
                logger.debug(
                    f"Attempt {attempt + 1} at [{location}] failed:"
                    f"\n{e}.\nRetrying in {random_delay:.2f} seconds...",
                )
                time.sleep(random_delay)
                # The jitter factor is at least 0.5, so once the nominal
                # delay reaches 2 * max_delay the sleep is always max_delay.
                # Capping there leaves the sleep time unchanged while
                # keeping the value from growing without bound (which can
                # overflow after very many retries).
                delay = min(delay * 2, 2 * self.max_delay)
        # Only reachable when max_retries is negative, i.e. the loop body
        # never ran and func was never called.
        location = self._caller_location(inspect.currentframe())
        logger.error(f"Max timeout exceeded at [{location}].")
        raise TimeoutError("Max retry exceeded.")
# Module-level default strategy: a fixed-delay retry configured with
# max_retries=10 and delay=5.
# NOTE(review): the identifier is misspelled ("DEAFULT"); it is left as-is
# because renaming could break code that refers to it -- confirm usages
# before correcting.
_DEAFULT_RETRY_STRATEGY = RetryFixedTimes(max_retries=10, delay=5)