<html><head><meta name="color-scheme" content="light dark"></head><body><pre style="word-wrap: break-word; white-space: pre-wrap;">import asyncio
import logging
import sys
from abc import abstractmethod
from dataclasses import dataclass
from types import TracebackType
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    Generator,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

from aiohttp import ClientResponse, ClientSession, hdrs
from aiohttp.typedefs import StrOrURL
from yarl import URL as YARL_URL

from .retry_options import ExponentialRetry, RetryOptionsBase

if sys.version_info &gt;= (3, 8):
    from typing import Protocol
else:
    from typing_extensions import Protocol


class _Logger(Protocol):
    """
    _Logger defines which methods logger object should have
    """

    @abstractmethod
    def debug(self, msg: str, *args: Any, **kwargs: Any) -&gt; None: pass

    @abstractmethod
    def warning(self, msg: str, *args: Any, **kwargs: Any) -&gt; None: pass

    @abstractmethod
    def exception(self, msg: str, *args: Any, **kwargs: Any) -&gt; None: pass


# A single URL: aiohttp's ``StrOrURL`` or a ``yarl.URL``.
_RAW_URL_TYPE = Union[StrOrURL, YARL_URL]
# Either one URL, or a list/tuple of URLs to switch between on successive
# retry attempts.
_URL_TYPE = Union[
    _RAW_URL_TYPE,
    List[_RAW_URL_TYPE],
    Tuple[_RAW_URL_TYPE, ...],
]
# Loggers accepted by the client: anything matching _Logger, or a stdlib logger.
_LoggerType = Union[_Logger, logging.Logger]

# Signature of the coroutine that actually sends a request
# (the client passes its session's ``request`` method here).
RequestFunc = Callable[..., Awaitable[ClientResponse]]


@dataclass
class RequestParams:
    """Parameters for a single request attempt made by the retry client."""

    # HTTP method name passed positionally to the request function.
    method: str
    # Target URL for this attempt.
    url: _RAW_URL_TYPE
    # Request headers forwarded as ``headers=``; None sends none explicitly.
    headers: Optional[Dict[str, Any]] = None
    # Extra trace context; merged with ``current_attempt`` on every attempt.
    trace_request_ctx: Optional[Dict[str, Any]] = None
    # Any further keyword arguments, forwarded verbatim to the request call.
    kwargs: Optional[Dict[str, Any]] = None


class _RequestContext:
    """Awaitable / async context manager that performs one request with retries.

    Attempt ``n`` (1-based) uses ``params_list[n - 1]``; once the list is
    exhausted the last entry is reused for all remaining attempts.  The object
    can be awaited directly (``await ctx``) or used as ``async with ctx as resp``;
    in the latter form the returned response is closed on exit.
    """

    def __init__(
        self,
        request_func: RequestFunc,
        params_list: List[RequestParams],
        logger: _LoggerType,
        retry_options: RetryOptionsBase,
        raise_for_status: bool = False,
    ) -> None:
        """
        Args:
            request_func: coroutine function that sends one request, called as
                ``request_func(method, url, headers=..., trace_request_ctx=..., **kwargs)``.
            params_list: per-attempt request parameters; must be non-empty.
            logger: receives debug/exception messages about retries.
            retry_options: retry policy (attempt count, statuses, exceptions, timeouts).
            raise_for_status: call ``raise_for_status()`` on a response before accepting it.
        """
        assert len(params_list) > 0, "params_list must contain at least one RequestParams"

        self._request_func = request_func
        self._params_list = params_list
        self._logger = logger
        self._retry_options = retry_options
        self._raise_for_status = raise_for_status

        # The response finally handed to the caller; closed in __aexit__.
        self._response: Optional[ClientResponse] = None

    def _is_status_code_ok(self, code: int) -> bool:
        """Return False when ``code`` should trigger a retry, True otherwise."""
        if code >= 500 and self._retry_options.retry_all_server_errors:
            return False
        return code not in self._retry_options.statuses

    def _is_last_attempt(self, attempt: int) -> bool:
        """True when ``attempt`` (1-based) has used up the attempt budget."""
        # ">=" (not "==") so a non-positive ``attempts`` value cannot loop forever.
        return attempt >= self._retry_options.attempts

    def _params_for_attempt(self, attempt: int) -> RequestParams:
        """Params for 1-based ``attempt``; the last entry is reused once exhausted."""
        return self._params_list[min(attempt, len(self._params_list)) - 1]

    def _is_retryable_exception(self, exc: Exception) -> bool:
        """True when ``exc`` is an instance of one of the configured retry exceptions."""
        return isinstance(exc, tuple(self._retry_options.exceptions))

    async def _evaluate_response(self, response: ClientResponse) -> bool:
        """Run the optional user callback on ``response``.

        Without a callback the response is accepted.  A callback that raises
        is logged and treated as rejecting the response.
        """
        callback = self._retry_options.evaluate_response_callback
        if callback is None:
            return True
        try:
            return await callback(response)
        except Exception:
            self._logger.exception('while evaluating response an exception occurred')
            return False

    async def _do_request(self) -> ClientResponse:
        """Send attempts until one is accepted or the attempt budget is spent.

        Returns the accepted response, which is also stored in ``self._response``.
        On the final attempt, a response is returned even if its status or the
        callback would otherwise trigger a retry.

        Raises:
            Exception: the exception from an attempt, re-raised when it is not
                one of the retry exceptions or occurred on the final attempt.
        """
        current_attempt = 0
        while True:
            self._logger.debug(f"Attempt {current_attempt+1} out of {self._retry_options.attempts}")

            current_attempt += 1
            try:
                params = self._params_for_attempt(current_attempt)

                response: ClientResponse = await self._request_func(
                    params.method,
                    params.url,
                    headers=params.headers,
                    trace_request_ctx={
                        'current_attempt': current_attempt,
                        **(params.trace_request_ctx or {}),
                    },
                    **(params.kwargs or {}),
                )

                if self._is_status_code_ok(response.status) or self._is_last_attempt(current_attempt):
                    if self._raise_for_status:
                        response.raise_for_status()

                    is_response_correct = await self._evaluate_response(response)
                    if is_response_correct or self._is_last_attempt(current_attempt):
                        self._response = response
                        return response
                    self._logger.debug("Retrying after evaluate response callback check")
                else:
                    self._logger.debug(f"Retrying after response code: {response.status}")

                # Compute the wait while the response is still intact, since
                # get_timeout receives it.
                retry_wait = self._retry_options.get_timeout(attempt=current_attempt, response=response)
                # This response is being discarded for another attempt: release it so its
                # connection is not held until garbage collection.
                response.release()
            except Exception as e:
                if self._is_last_attempt(current_attempt) or not self._is_retryable_exception(e):
                    raise

                self._logger.debug(f"Retrying after exception: {repr(e)}")
                retry_wait = self._retry_options.get_timeout(attempt=current_attempt, response=None)

            await asyncio.sleep(retry_wait)

    def __await__(self) -> Generator[Any, None, ClientResponse]:
        """Allow ``await ctx``; the caller then owns closing the response."""
        return self.__aenter__().__await__()

    async def __aenter__(self) -> ClientResponse:
        return await self._do_request()

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Close the returned response, if any and not already closed."""
        if self._response is not None:
            if not self._response.closed:
                self._response.close()


def _url_to_urls(url: _URL_TYPE) -> Tuple[StrOrURL, ...]:
    """Normalise ``url`` into a non-empty tuple of URLs, one per attempt slot.

    A single ``str`` / ``yarl.URL`` becomes a 1-tuple, a list is converted to
    a tuple, and a tuple is returned as-is.

    Raises:
        ValueError: if ``url`` is of any other type, or is an empty list/tuple.
    """
    if isinstance(url, (str, YARL_URL)):
        return (url,)

    if isinstance(url, tuple):
        urls = url
    elif isinstance(url, list):
        urls = tuple(url)
    else:
        raise ValueError("you can pass url only by str or list/tuple")

    if not urls:
        raise ValueError("you can pass url by str or list/tuple with attempts count size")

    return urls


class RetryClient:
    """Retrying front-end for an ``aiohttp.ClientSession``.

    Every request method returns a ``_RequestContext`` that can be awaited
    directly or used with ``async with``.  Per-call ``retry_options`` and
    ``raise_for_status`` arguments override the client-wide defaults; passing
    ``None`` (the default) means "use the client-wide value".
    """

    def __init__(
        self,
        client_session: Optional[ClientSession] = None,
        logger: Optional[_LoggerType] = None,
        retry_options: Optional[RetryOptionsBase] = None,
        raise_for_status: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            client_session: session to send requests through.  When omitted a
                new ``ClientSession(*args, **kwargs)`` is created and owned here.
            logger: destination for retry diagnostics; defaults to the stdlib
                logger named ``"aiohttp_retry"``.
            retry_options: default retry policy; defaults to ``ExponentialRetry()``.
            raise_for_status: default for calling ``raise_for_status()`` on the
                accepted response.
        """
        if client_session is not None:
            client = client_session
            # None: closure of a caller-supplied session is not tracked, so
            # __del__ never warns about it.
            closed = None
        else:
            client = ClientSession(*args, **kwargs)
            closed = False

        self._client = client
        self._closed = closed

        self._logger: _LoggerType = logger or logging.getLogger("aiohttp_retry")
        self._retry_options: RetryOptionsBase = retry_options or ExponentialRetry()
        self._raise_for_status = raise_for_status

    @property
    def retry_options(self) -> RetryOptionsBase:
        """The client-wide default retry policy."""
        return self._retry_options

    def requests(
        self,
        params_list: List[RequestParams],
        retry_options: Optional[RetryOptionsBase] = None,
        raise_for_status: Optional[bool] = None,
    ) -> _RequestContext:
        """Retry a request described by explicit per-attempt parameters.

        Attempt ``n`` uses ``params_list[n - 1]``; the last entry is reused once
        the list is exhausted.
        """
        return self._make_requests(
            params_list=params_list,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
        )

    def request(
        self,
        method: str,
        url: _URL_TYPE,
        retry_options: Optional[RetryOptionsBase] = None,
        raise_for_status: Optional[bool] = None,
        **kwargs: Any,
    ) -> _RequestContext:
        """Send a ``method`` request with retries.

        ``url`` may be one URL or a list/tuple of URLs used on successive
        attempts.  Remaining ``kwargs`` are forwarded to the session's
        ``request`` call.
        """
        return self._make_request(
            method=method,
            url=url,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def get(
        self,
        url: _URL_TYPE,
        retry_options: Optional[RetryOptionsBase] = None,
        raise_for_status: Optional[bool] = None,
        **kwargs: Any,
    ) -> _RequestContext:
        """Send a GET request with retries (arguments as for ``request``)."""
        return self._make_request(
            method=hdrs.METH_GET,
            url=url,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def options(
        self,
        url: _URL_TYPE,
        retry_options: Optional[RetryOptionsBase] = None,
        raise_for_status: Optional[bool] = None,
        **kwargs: Any,
    ) -> _RequestContext:
        """Send an OPTIONS request with retries (arguments as for ``request``)."""
        return self._make_request(
            method=hdrs.METH_OPTIONS,
            url=url,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def head(
        self,
        url: _URL_TYPE,
        retry_options: Optional[RetryOptionsBase] = None,
        raise_for_status: Optional[bool] = None,
        **kwargs: Any,
    ) -> _RequestContext:
        """Send a HEAD request with retries (arguments as for ``request``)."""
        return self._make_request(
            method=hdrs.METH_HEAD,
            url=url,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def post(
        self,
        url: _URL_TYPE,
        retry_options: Optional[RetryOptionsBase] = None,
        raise_for_status: Optional[bool] = None,
        **kwargs: Any,
    ) -> _RequestContext:
        """Send a POST request with retries (arguments as for ``request``)."""
        return self._make_request(
            method=hdrs.METH_POST,
            url=url,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def put(
        self,
        url: _URL_TYPE,
        retry_options: Optional[RetryOptionsBase] = None,
        raise_for_status: Optional[bool] = None,
        **kwargs: Any,
    ) -> _RequestContext:
        """Send a PUT request with retries (arguments as for ``request``)."""
        return self._make_request(
            method=hdrs.METH_PUT,
            url=url,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def patch(
        self,
        url: _URL_TYPE,
        retry_options: Optional[RetryOptionsBase] = None,
        raise_for_status: Optional[bool] = None,
        **kwargs: Any,
    ) -> _RequestContext:
        """Send a PATCH request with retries (arguments as for ``request``)."""
        return self._make_request(
            method=hdrs.METH_PATCH,
            url=url,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def delete(
        self,
        url: _URL_TYPE,
        retry_options: Optional[RetryOptionsBase] = None,
        raise_for_status: Optional[bool] = None,
        **kwargs: Any,
    ) -> _RequestContext:
        """Send a DELETE request with retries (arguments as for ``request``)."""
        return self._make_request(
            method=hdrs.METH_DELETE,
            url=url,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    async def close(self) -> None:
        """Close the underlying session (including a caller-supplied one) and mark this client closed."""
        await self._client.close()
        self._closed = True

    def _make_request(
        self,
        method: str,
        url: _URL_TYPE,
        retry_options: Optional[RetryOptionsBase] = None,
        raise_for_status: Optional[bool] = None,
        **kwargs: Any,
    ) -> _RequestContext:
        """Build one ``RequestParams`` per URL and wrap them in a ``_RequestContext``.

        ``headers`` and ``trace_request_ctx`` are taken out of ``kwargs`` and
        applied to every URL; the remaining ``kwargs`` are forwarded unchanged.

        Raises:
            ValueError: if ``url`` is not a str/URL or a non-empty list/tuple.
        """
        url_list = _url_to_urls(url)
        # Pop once, before building the list: popping inside the comprehension
        # would hand the caller's headers / trace context to the first URL only.
        headers = kwargs.pop('headers', {})
        trace_request_ctx = kwargs.pop('trace_request_ctx', None)
        params_list = [
            RequestParams(
                method=method,
                url=single_url,
                headers=headers,
                trace_request_ctx=trace_request_ctx,
                kwargs=kwargs,
            )
            for single_url in url_list
        ]

        return self._make_requests(
            params_list=params_list,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
        )

    def _make_requests(
        self,
        params_list: List[RequestParams],
        retry_options: Optional[RetryOptionsBase] = None,
        raise_for_status: Optional[bool] = None,
    ) -> _RequestContext:
        """Create a ``_RequestContext``, filling ``None`` overrides from the client defaults."""
        if retry_options is None:
            retry_options = self._retry_options
        if raise_for_status is None:
            raise_for_status = self._raise_for_status
        return _RequestContext(
            request_func=self._client.request,
            params_list=params_list,
            logger=self._logger,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
        )

    async def __aenter__(self) -> 'RetryClient':
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Close the client (and its session) when leaving ``async with``."""
        await self.close()

    def __del__(self) -> None:
        """Warn if a client-owned session was never closed."""
        if getattr(self, '_closed', None) is None:
            # Either __init__ did not finish (e.g. it raised), or the session was
            # supplied by the caller and its closure is not tracked here.
            return

        if not self._closed:
            self._logger.warning("Aiohttp retry client was not closed")
</pre></body></html>