@@ -180,6 +180,7 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
180180
181181 self ._keys = []
182182 self .remote = False
183+ self .local = False
183184 self .cache_time = cache_time
184185 self .time_out = 0
185186 self .etag = ""
@@ -189,7 +190,8 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
189190 self .keyusage = keyusage
190191 self .imp_jwks = None
191192 self .last_updated = 0
192- self .last_remote = None
193+ self .last_remote = None # HTTP Date of last remote update
194+ self .last_local = None # UNIX timestamp of last local update
193195
194196 if httpc :
195197 self .httpc = httpc
@@ -209,13 +211,13 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True,
209211 self .do_keys (keys )
210212 else :
211213 self ._set_source (source , fileformat )
212-
213- if not self .remote and self .source : # local file
214+ if self .local :
214215 self ._do_local (kid )
215216
216217 def _set_source (self , source , fileformat ):
217218 if source .startswith ("file://" ):
218219 self .source = source [7 :]
220+ self .local = True
219221 elif source .startswith ("http://" ) or source .startswith ("https://" ):
220222 self .source = source
221223 self .remote = True
@@ -225,6 +227,7 @@ def _set_source(self, source, fileformat):
225227 if fileformat .lower () in ['rsa' , 'der' , 'jwks' ]:
226228 if os .path .isfile (source ):
227229 self .source = source
230+ self .local = True
228231 else :
229232 raise ImportError ('No such file' )
230233 else :
@@ -236,6 +239,16 @@ def _do_local(self, kid):
236239 elif self .fileformat == "der" :
237240 self .do_local_der (self .source , self .keytype , self .keyusage , kid )
238241
242+ def _local_update_required (self ) -> bool :
243+ stat = os .stat (self .source )
244+ if self .last_local and stat .st_mtime < self .last_local :
245+ LOGGER .debug ("%s not modfied" , self .source )
246+ return False
247+ else :
248+ LOGGER .debug ("%s modfied" , self .source )
249+ self .last_local = stat .st_mtime
250+ return True
251+
239252 def do_keys (self , keys ):
240253 """
241254 Go from JWK description to binary keys
@@ -291,12 +304,15 @@ def do_local_jwk(self, filename):
291304
292305 :param filename: Name of the file from which the JWKS should be loaded
293306 """
307+ LOGGER .debug ("Reading JWKS from %s" , filename )
294308 with open (filename ) as input_file :
295309 _info = json .load (input_file )
296310 if 'keys' in _info :
297311 self .do_keys (_info ["keys" ])
298312 else :
299313 self .do_keys ([_info ])
314+ self .last_local = time .time ()
315+ self .time_out = self .last_local + self .cache_time
300316
301317 def do_local_der (self , filename , keytype , keyusage = None , kid = '' ):
302318 """
@@ -306,6 +322,7 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=''):
306322 :param keytype: Presently 'rsa' and 'ec' supported
307323 :param keyusage: encryption ('enc') or signing ('sig') or both
308324 """
325+ LOGGER .debug ("Reading DER from %s" , filename )
309326 key_args = {}
310327 _kty = keytype .lower ()
311328 if _kty in ['rsa' , 'ec' ]:
@@ -325,6 +342,8 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=''):
325342 key_args ['kid' ] = kid
326343
327344 self .do_keys ([key_args ])
345+ self .last_local = time .time ()
346+ self .time_out = self .last_local + self .cache_time
328347
329348 def do_remote (self ):
330349 """
@@ -400,14 +419,12 @@ def _parse_remote_response(self, response):
400419
401420 def _uptodate (self ):
402421 res = False
403- if not self ._keys :
404- if self .remote : # verify that it's not to old
405- if time .time () > self .time_out :
406- if self .update ():
407- res = True
408- elif self .remote :
409- if self .update ():
410- res = True
422+ if self .remote or self .local :
423+ if time .time () > self .time_out :
424+ if self .local and not self ._local_update_required ():
425+ res = True
426+ elif self .update ():
427+ res = True
411428 return res
412429
413430 def update (self ):
@@ -425,13 +442,13 @@ def update(self):
425442 self ._keys = []
426443
427444 try :
428- if self .remote is False :
445+ if self .local :
429446 if self .fileformat in ["jwks" , "jwk" ]:
430447 self .do_local_jwk (self .source )
431448 elif self .fileformat == "der" :
432449 self .do_local_der (self .source , self .keytype ,
433450 self .keyusage )
434- else :
451+ elif self . remote :
435452 res = self .do_remote ()
436453 except Exception as err :
437454 LOGGER .error ('Key bundle update failed: %s' , err )
@@ -674,8 +691,11 @@ def dump(self):
674691 "keys" : _keys ,
675692 "fileformat" : self .fileformat ,
676693 "last_updated" : self .last_updated ,
694+ "last_remote" : self .last_remote ,
695+ "last_local" : self .last_local ,
677696 "httpc_params" : self .httpc_params ,
678697 "remote" : self .remote ,
698+ "local" : self .local ,
679699 "imp_jwks" : self .imp_jwks ,
680700 "time_out" : self .time_out ,
681701 "cache_time" : self .cache_time
@@ -693,7 +713,10 @@ def load(self, spec):
693713 self .source = spec .get ("source" , None )
694714 self .fileformat = spec .get ("fileformat" , "jwks" )
695715 self .last_updated = spec .get ("last_updated" , 0 )
716+ self .last_remote = spec .get ("last_remote" , None )
717+ self .last_local = spec .get ("last_local" , None )
696718 self .remote = spec .get ("remote" , False )
719+ self .local = spec .get ("local" , False )
697720 self .imp_jwks = spec .get ('imp_jwks' , None )
698721 self .time_out = spec .get ('time_out' , 0 )
699722 self .cache_time = spec .get ('cache_time' , 0 )
0 commit comments