diff --git a/promptify/models/nlp/text2text/base_model.py b/promptify/models/nlp/text2text/base_model.py index 45b725e..fa802de 100644 --- a/promptify/models/nlp/text2text/base_model.py +++ b/promptify/models/nlp/text2text/base_model.py @@ -1,5 +1,6 @@ from abc import ABCMeta, abstractmethod -from typing import List, Optional, Union, Dict +from typing import Dict, List, Optional, Union + import tenacity @@ -396,6 +397,7 @@ def _retry_decorator(self): multiplier=0.3, exp_base=3, max=self.api_wait ), stop=tenacity.stop_after_attempt(self.api_retry), + reraise=True, ) def execute_with_retry(self, *args, **kwargs): diff --git a/promptify/models/nlp/text2text/openai_complete.py b/promptify/models/nlp/text2text/openai_complete.py index 49b3ee4..f3b286c 100644 --- a/promptify/models/nlp/text2text/openai_complete.py +++ b/promptify/models/nlp/text2text/openai_complete.py @@ -1,9 +1,12 @@ from typing import Dict, List, Optional, Tuple, Union -import openai + import json +import openai +import tenacity import tiktoken -from promptify.parser.parser import Parser + from promptify.models.nlp.text2text.base_model import Model +from promptify.parser.parser import Parser class OpenAI(Model): @@ -329,6 +332,32 @@ def model_output(self, response: Dict, max_completion_length: int) -> Dict: return data + def _retry_decorator(self): + """ + Decorator function for retrying API requests if they fail. + + Returns + ------- + tenacity.Retrying + A decorator function for retrying API requests. + + Notes + ----- + This method is a decorator function for retrying API requests using tenacity. + """ + + return tenacity.retry( + wait=tenacity.wait_random_exponential( + multiplier=0.3, exp_base=3, max=self.api_wait + ), + stop=tenacity.stop_after_attempt(self.api_retry), + retry=tenacity.retry_if_exception_type( + (openai.error.APIError, openai.error.TryAgain, openai.error.Timeout, + openai.error.APIConnectionError, openai.error.RateLimitError, + openai.error.ServiceUnavailableError, )), + reraise=True, + ) + def _store_session(self, session_identifier: str): import json import os