1313# under the License.
1414from __future__ import annotations
1515
16+ import asyncio
17+
1618from json import JSONDecodeError
1719from os import environ
18- from typing import TYPE_CHECKING , Any , Awaitable , Callable , Dict
20+ from typing import TYPE_CHECKING , Any , Awaitable , Callable , Dict , Optional
1921
2022from httpx import AsyncClient , ConnectTimeout , NetworkError , Response
2123
2527 API_VERSION_HEADER ,
2628 RID_KEY_HEADER ,
2729 SUPPORTED_CDI_VERSIONS ,
30+ RATE_LIMIT_STATUS_CODE ,
2831)
2932from .normalised_url_path import NormalisedURLPath
3033
@@ -42,7 +45,7 @@ class Querier:
4245 __init_called = False
4346 __hosts : List [Host ] = []
4447 __api_key : Union [None , str ] = None
45- __api_version = None
48+ api_version = None
4649 __last_tried_index : int = 0
4750 __hosts_alive_for_testing : Set [str ] = set ()
4851
@@ -69,8 +72,8 @@ def get_hosts_alive_for_testing():
6972 return Querier .__hosts_alive_for_testing
7073
7174 async def get_api_version (self ):
72- if Querier .__api_version is not None :
73- return Querier .__api_version
75+ if Querier .api_version is not None :
76+ return Querier .api_version
7477
7578 ProcessState .get_instance ().add_state (
7679 AllowedProcessStates .CALLING_SERVICE_IN_GET_API_VERSION
@@ -96,8 +99,8 @@ async def f(url: str) -> Response:
9699 "to find the right versions"
97100 )
98101
99- Querier .__api_version = api_version
100- return Querier .__api_version
102+ Querier .api_version = api_version
103+ return Querier .api_version
101104
102105 @staticmethod
103106 def get_instance (rid_to_core : Union [str , None ] = None ):
@@ -113,7 +116,7 @@ def init(hosts: List[Host], api_key: Union[str, None] = None):
113116 Querier .__init_called = True
114117 Querier .__hosts = hosts
115118 Querier .__api_key = api_key
116- Querier .__api_version = None
119+ Querier .api_version = None
117120 Querier .__last_tried_index = 0
118121 Querier .__hosts_alive_for_testing = set ()
119122
@@ -196,6 +199,7 @@ async def __send_request_helper(
196199 method : str ,
197200 http_function : Callable [[str ], Awaitable [Response ]],
198201 no_of_tries : int ,
202+ retry_info_map : Optional [Dict [str , int ]] = None ,
199203 ) -> Any :
200204 if no_of_tries == 0 :
201205 raise_general_exception ("No SuperTokens core available to query" )
@@ -212,6 +216,14 @@ async def __send_request_helper(
212216 Querier .__last_tried_index %= len (self .__hosts )
213217 url = current_host + path .get_as_string_dangerous ()
214218
219+ max_retries = 5
220+
221+ if retry_info_map is None :
222+ retry_info_map = {}
223+
224+ if retry_info_map .get (url ) is None :
225+ retry_info_map [url ] = max_retries
226+
215227 ProcessState .get_instance ().add_state (
216228 AllowedProcessStates .CALLING_SERVICE_IN_REQUEST_HELPER
217229 )
@@ -221,6 +233,20 @@ async def __send_request_helper(
221233 ):
222234 Querier .__hosts_alive_for_testing .add (current_host )
223235
236+ if response .status_code == RATE_LIMIT_STATUS_CODE :
237+ retries_left = retry_info_map [url ]
238+
239+ if retries_left > 0 :
240+ retry_info_map [url ] = retries_left - 1
241+
242+ attempts_made = max_retries - retries_left
243+ delay = (10 + attempts_made * 250 ) / 1000
244+
245+ await asyncio .sleep (delay )
246+ return await self .__send_request_helper (
247+ path , method , http_function , no_of_tries , retry_info_map
248+ )
249+
224250 if is_4xx_error (response .status_code ) or is_5xx_error (response .status_code ): # type: ignore
225251 raise_general_exception (
226252 "SuperTokens core threw an error for a "
@@ -238,9 +264,9 @@ async def __send_request_helper(
238264 except JSONDecodeError :
239265 return response .text
240266
241- except (ConnectionError , NetworkError , ConnectTimeout ):
267+ except (ConnectionError , NetworkError , ConnectTimeout ) as _ :
242268 return await self .__send_request_helper (
243- path , method , http_function , no_of_tries - 1
269+ path , method , http_function , no_of_tries - 1 , retry_info_map
244270 )
245271 except Exception as e :
246272 raise_general_exception (e )
0 commit comments