@@ -252,15 +252,16 @@ def __init__(
252252 self .source = None
253253 if isinstance (keys , dict ):
254254 if "keys" in keys :
255- self . _add_jwk_dicts ( keys ["keys" ])
255+ initial_keys = keys ["keys" ]
256256 else :
257- self . _add_jwk_dicts ( [keys ])
257+ initial_keys = [keys ]
258258 else :
259- self ._add_jwk_dicts (keys )
259+ initial_keys = keys
260+ self ._keys = self .jwk_dicts_as_keys (initial_keys )
260261 else :
261262 self ._set_source (source , fileformat )
262263 if self .local :
263- self ._do_local (kid )
264+ self ._keys = self . _do_local (kid )
264265
265266 def _set_source (self , source , fileformat ):
266267 if source .startswith ("file://" ):
@@ -283,9 +284,10 @@ def _set_source(self, source, fileformat):
283284
284285 def _do_local (self , kid ):
285286 if self .fileformat in ["jwks" , "jwk" ]:
286- self ._do_local_jwk (self .source )
287+ updated , keys = self ._do_local_jwk (self .source )
287288 elif self .fileformat == "der" :
288- self ._do_local_der (self .source , self .keytype , self .keyusage , kid )
289+ updated , keys = self ._do_local_der (self .source , self .keytype , self .keyusage , kid )
290+ return keys
289291
290292 def _local_update_required (self ) -> bool :
291293 stat = os .stat (self .source )
@@ -309,13 +311,8 @@ def add_jwk_dicts(self, keys):
309311 :param keys: List of JWK dictionaries
310312 :return:
311313 """
312- self ._add_jwk_dicts (keys )
313-
314- def _add_jwk_dicts (self , keys ):
315- _new_keys = self .jwk_dicts_as_keys (keys )
316- if _new_keys :
317- self ._keys .extend (_new_keys )
318- self .last_updated = time .time ()
314+ self ._keys .extend (self .jwk_dicts_as_keys (keys ))
315+ self .last_updated = time .time ()
319316
320317 def jwk_dicts_as_keys (self , keys ):
321318 """
@@ -384,18 +381,19 @@ def _do_local_jwk(self, filename):
384381 :return: True if load was successful or False if file hasn't been modified
385382 """
386383 if not self ._local_update_required ():
387- return False
384+ return False , None
388385
389386 LOGGER .info ("Reading local JWKS from %s" , filename )
390387 with open (filename ) as input_file :
391388 _info = json .load (input_file )
392389 if "keys" in _info :
393- self ._add_jwk_dicts (_info ["keys" ])
390+ new_keys = self .jwk_dicts_as_keys (_info ["keys" ])
394391 else :
395- self ._add_jwk_dicts ([_info ])
392+ new_keys = self .jwk_dicts_as_keys ([_info ])
393+
396394 self .last_local = time .time ()
397395 self .time_out = self .last_local + self .cache_time
398- return True
396+ return True , new_keys
399397
400398 def _do_local_der (self , filename , keytype , keyusage = None , kid = "" ):
401399 """
@@ -407,7 +405,7 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
407405 :return: True if load was successful or False if file hasn't been modified
408406 """
409407 if not self ._local_update_required ():
410- return False
408+ return False , None
411409
412410 LOGGER .info ("Reading local DER from %s" , filename )
413411 key_args = {}
@@ -428,12 +426,12 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
428426 if kid :
429427 key_args ["kid" ] = kid
430428
431- self ._add_jwk_dicts ([key_args ])
429+ new_keys = self .jwk_dicts_as_keys ([key_args ])
432430 self .last_local = time .time ()
433431 self .time_out = self .last_local + self .cache_time
434- return True
432+ return True , new_keys
435433
436- def _do_remote (self ):
434+ def _do_remote (self , set_keys = True ):
437435 """
438436 Load a JWKS from a webpage.
439437
@@ -448,7 +446,7 @@ def _do_remote(self):
448446 self .source ,
449447 datetime .fromtimestamp (self .ignore_errors_until ),
450448 )
451- return False
449+ return False , None
452450
453451 LOGGER .info ("Reading remote JWKS from %s" , self .source )
454452 try :
@@ -497,11 +495,12 @@ def _do_remote(self):
497495 self .ignore_errors_until = time .time () + self .ignore_errors_period
498496 raise UpdateFailed (REMOTE_FAILED .format (self .source , _http_resp .status_code ))
499497
500- if new_keys is not None :
498+ if set_keys and new_keys :
501499 self ._keys = new_keys
500+
502501 self .last_updated = time .time ()
503502 self .ignore_errors_until = None
504- return load_successful
503+ return load_successful , new_keys
505504
506505 def _parse_remote_response (self , response ):
507506 """
@@ -542,34 +541,31 @@ def update(self):
542541 :return: True if update was ok or False if we encountered an error during update.
543542 """
544543 if self .source :
545- _old_keys = self ._keys # just in case
546-
547- # reread everything
548- self ._keys = []
544+ new_keys = []
549545 updated = None
550546
551547 try :
552548 if self .local :
553549 if self .fileformat in ["jwks" , "jwk" ]:
554- updated = self ._do_local_jwk (self .source )
550+ updated , k = self ._do_local_jwk (self .source )
555551 elif self .fileformat == "der" :
556- updated = self ._do_local_der (self .source , self .keytype , self .keyusage )
552+ updated , k = self ._do_local_der (self .source , self .keytype , self .keyusage )
557553 elif self .remote :
558- updated = self ._do_remote ()
554+ updated , k = self ._do_remote (set_keys = False )
555+ if k :
556+ new_keys .extend (k )
559557 except Exception as err :
560558 LOGGER .error ("Key bundle update failed: %s" , err )
561- self ._keys = _old_keys # restore
562559 return False
563560
564561 if updated :
565562 now = time .time ()
566- for _key in _old_keys :
567- if _key not in self . _keys :
563+ for _key in self . _keys :
564+ if _key not in new_keys :
568565 if not _key .inactive_since : # If already marked don't mess
569566 _key .inactive_since = now
570- self ._keys .append (_key )
571- else :
572- self ._keys = _old_keys
567+ new_keys .append (_key )
568+ self ._keys = new_keys
573569
574570 return True
575571
@@ -585,9 +581,9 @@ def get(self, typ="", only_active=True):
585581
586582 if typ :
587583 _typs = [typ .lower (), typ .upper ()]
588- _keys = [k for k in self ._keys [:] if k .kty in _typs ]
584+ _keys = [k for k in self ._keys if k .kty in _typs ]
589585 else :
590- _keys = self ._keys [:]
586+ _keys = self ._keys
591587
592588 if only_active :
593589 return [k for k in _keys if not k .inactive_since ]
@@ -602,7 +598,7 @@ def keys(self, update: bool = True):
602598 """
603599 if update :
604600 self ._uptodate ()
605- return self ._keys [:]
601+ return self ._keys
606602
607603 def active_keys (self ):
608604 """Return the set of active keys."""
@@ -829,9 +825,11 @@ def load(self, spec):
829825 :param spec: Dictionary with attributes and value to populate the instance with
830826 :return: The instance itself
831827 """
828+
832829 _keys = spec .get ("keys" , [])
833830 if _keys :
834- self ._add_jwk_dicts (_keys )
831+ self ._keys .extend (self .jwk_dicts_as_keys (_keys ))
832+ self .last_updated = time .time ()
835833
836834 for attr , default in self .params .items ():
837835 val = spec .get (attr )
0 commit comments