@@ -180,6 +180,7 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
180180
181181 self ._keys = []
182182 self .remote = False
183+ self .local = False
183184 self .cache_time = cache_time
184185 self .time_out = 0
185186 self .etag = ""
@@ -189,6 +190,8 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
189190 self .keyusage = keyusage
190191 self .imp_jwks = None
191192 self .last_updated = 0
193+ self .last_remote = None # HTTP Date of last remote update
194+ self .last_local = None # UNIX timestamp of last local update
192195
193196 if httpc :
194197 self .httpc = httpc
@@ -208,13 +211,13 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
208211 self .do_keys (keys )
209212 else :
210213 self ._set_source (source , fileformat )
211-
212- if not self .remote and self .source : # local file
214+ if self .local :
213215 self ._do_local (kid )
214216
215217 def _set_source (self , source , fileformat ):
216218 if source .startswith ("file://" ):
217219 self .source = source [7 :]
220+ self .local = True
218221 elif source .startswith ("http://" ) or source .startswith ("https://" ):
219222 self .source = source
220223 self .remote = True
@@ -224,6 +227,7 @@ def _set_source(self, source, fileformat):
224227 if fileformat .lower () in ['rsa' , 'der' , 'jwks' ]:
225228 if os .path .isfile (source ):
226229 self .source = source
230+ self .local = True
227231 else :
228232 raise ImportError ('No such file' )
229233 else :
@@ -235,6 +239,16 @@ def _do_local(self, kid):
235239 elif self .fileformat == "der" :
236240 self .do_local_der (self .source , self .keytype , self .keyusage , kid )
237241
242+ def _local_update_required (self ) -> bool :
243+ stat = os .stat (self .source )
244+ if self .last_local and stat .st_mtime < self .last_local :
245+ LOGGER .debug ("%s not modfied" , self .source )
246+ return False
247+ else :
248+ LOGGER .debug ("%s modfied" , self .source )
249+ self .last_local = stat .st_mtime
250+ return True
251+
238252 def do_keys (self , keys ):
239253 """
240254 Go from JWK description to binary keys
@@ -290,12 +304,15 @@ def do_local_jwk(self, filename):
290304
291305 :param filename: Name of the file from which the JWKS should be loaded
292306 """
307+ LOGGER .debug ("Reading JWKS from %s" , filename )
293308 with open (filename ) as input_file :
294309 _info = json .load (input_file )
295310 if 'keys' in _info :
296311 self .do_keys (_info ["keys" ])
297312 else :
298313 self .do_keys ([_info ])
314+ self .last_local = time .time ()
315+ self .time_out = self .last_local + self .cache_time
299316
300317 def do_local_der (self , filename , keytype , keyusage = None , kid = '' ):
301318 """
@@ -305,6 +322,7 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=''):
305322 :param keytype: Presently 'rsa' and 'ec' supported
306323 :param keyusage: encryption ('enc') or signing ('sig') or both
307324 """
325+ LOGGER .debug ("Reading DER from %s" , filename )
308326 key_args = {}
309327 _kty = keytype .lower ()
310328 if _kty in ['rsa' , 'ec' ]:
@@ -324,6 +342,8 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=''):
324342 key_args ['kid' ] = kid
325343
326344 self .do_keys ([key_args ])
345+ self .last_local = time .time ()
346+ self .time_out = self .last_local + self .cache_time
327347
328348 def do_remote (self ):
329349 """
@@ -336,6 +356,10 @@ def do_remote(self):
336356
337357 try :
338358 LOGGER .debug ('KeyBundle fetch keys from: %s' , self .source )
359+ if self .last_remote is not None :
360+ if "headers" not in self .httpc_params :
361+ self .httpc_params ["headers" ] = {}
362+ self .httpc_params ["headers" ]["If-Modified-Since" ] = self .last_remote
339363 _http_resp = self .httpc ('GET' , self .source , ** self .httpc_params )
340364 except Exception as err :
341365 LOGGER .error (err )
@@ -357,6 +381,14 @@ def do_remote(self):
357381 LOGGER .error ("No 'keys' keyword in JWKS" )
358382 raise UpdateFailed (MALFORMED .format (self .source ))
359383
384+ if hasattr (_http_resp , "headers" ):
385+ headers = getattr (_http_resp , "headers" )
386+ self .last_remote = headers .get ("last-modified" ) or headers .get ("date" )
387+
388+ elif _http_resp .status_code == 304 : # Not modified
389+ LOGGER .debug ("%s not modified since %s" , self .source , self .last_remote )
390+ pass
391+
360392 else :
361393 raise UpdateFailed (
362394 REMOTE_FAILED .format (self .source , _http_resp .status_code ))
@@ -387,14 +419,12 @@ def _parse_remote_response(self, response):
387419
388420 def _uptodate (self ):
389421 res = False
390- if not self ._keys :
391- if self .remote : # verify that it's not to old
392- if time .time () > self .time_out :
393- if self .update ():
394- res = True
395- elif self .remote :
396- if self .update ():
397- res = True
422+ if self .remote or self .local :
423+ if time .time () > self .time_out :
424+ if self .local and not self ._local_update_required ():
425+ res = True
426+ elif self .update ():
427+ res = True
398428 return res
399429
400430 def update (self ):
@@ -412,13 +442,13 @@ def update(self):
412442 self ._keys = []
413443
414444 try :
415- if self .remote is False :
445+ if self .local :
416446 if self .fileformat in ["jwks" , "jwk" ]:
417447 self .do_local_jwk (self .source )
418448 elif self .fileformat == "der" :
419449 self .do_local_der (self .source , self .keytype ,
420450 self .keyusage )
421- else :
451+ elif self . remote :
422452 res = self .do_remote ()
423453 except Exception as err :
424454 LOGGER .error ('Key bundle update failed: %s' , err )
@@ -661,8 +691,11 @@ def dump(self):
661691 "keys" : _keys ,
662692 "fileformat" : self .fileformat ,
663693 "last_updated" : self .last_updated ,
694+ "last_remote" : self .last_remote ,
695+ "last_local" : self .last_local ,
664696 "httpc_params" : self .httpc_params ,
665697 "remote" : self .remote ,
698+ "local" : self .local ,
666699 "imp_jwks" : self .imp_jwks ,
667700 "time_out" : self .time_out ,
668701 "cache_time" : self .cache_time
@@ -680,7 +713,10 @@ def load(self, spec):
680713 self .source = spec .get ("source" , None )
681714 self .fileformat = spec .get ("fileformat" , "jwks" )
682715 self .last_updated = spec .get ("last_updated" , 0 )
716+ self .last_remote = spec .get ("last_remote" , None )
717+ self .last_local = spec .get ("last_local" , None )
683718 self .remote = spec .get ("remote" , False )
719+ self .local = spec .get ("local" , False )
684720 self .imp_jwks = spec .get ('imp_jwks' , None )
685721 self .time_out = spec .get ('time_out' , 0 )
686722 self .cache_time = spec .get ('cache_time' , 0 )
0 commit comments