44import time
55import uuid
66from json import JSONDecodeError
7+ from typing import Dict
8+ from typing import List
9+ from typing import MutableMapping
10+ from typing import Optional
711
812from .exception import HeaderError
913from .exception import VerificationError
@@ -79,25 +83,26 @@ class JWT:
7983 def __init__ (
8084 self ,
8185 key_jar = None ,
82- iss = "" ,
83- lifetime = 0 ,
84- sign = True ,
85- sign_alg = "RS256" ,
86- encrypt = False ,
87- enc_enc = "A128GCM" ,
88- enc_alg = "RSA-OAEP-256" ,
89- msg_cls = None ,
90- iss2msg_cls = None ,
91- skew = 15 ,
92- allowed_sign_algs = None ,
93- allowed_enc_algs = None ,
94- allowed_enc_encs = None ,
95- allowed_max_lifetime = None ,
96- zip = "" ,
86+ iss : str = "" ,
87+ lifetime : int = 0 ,
88+ sign : bool = True ,
89+ sign_alg : str = "RS256" ,
90+ encrypt : bool = False ,
91+ enc_enc : str = "A128GCM" ,
92+ enc_alg : str = "RSA-OAEP-256" ,
93+ msg_cls : Optional [MutableMapping ] = None ,
94+ iss2msg_cls : Optional [Dict [str , str ]] = None ,
95+ skew : Optional [int ] = 15 ,
96+ allowed_sign_algs : Optional [List [str ]] = None ,
97+ allowed_enc_algs : Optional [List [str ]] = None ,
98+ allowed_enc_encs : Optional [List [str ]] = None ,
99+ allowed_max_lifetime : Optional [int ] = None ,
100+ zip : Optional [str ] = "" ,
101+ typ2msg_cls : Optional [Dict ] = None ,
97102 ):
98103 self .key_jar = key_jar # KeyJar instance
99104 self .iss = iss # My identifier
100- self .lifetime = lifetime # default life time of the signature
105+ self .lifetime = lifetime # default lifetime of the signature
101106 self .sign = sign # default signing or not
102107 self .alg = sign_alg # default signing algorithm
103108 self .encrypt = encrypt # default encrypting or not
@@ -107,6 +112,7 @@ def __init__(
107112 self .with_jti = False # If a jti should be added
108113 # A map between issuers and the message classes they use
109114 self .iss2msg_cls = iss2msg_cls or {}
115+ self .typ2msg_cls = typ2msg_cls or {}
110116 # Allowed time skew
111117 self .skew = skew
112118 # When verifying/decrypting
@@ -206,16 +212,30 @@ def pack_key(self, issuer_id="", kid=""):
206212
207213 return keys [0 ] # Might be more then one if kid == ''
208214
209- def pack (self , payload = None , kid = "" , issuer_id = "" , recv = "" , aud = None , iat = None , ** kwargs ):
215+ def message (self , signing_key , ** kwargs ):
216+ return json .dumps (kwargs )
217+
218+ def pack (
219+ self ,
220+ payload : Optional [dict ] = None ,
221+ kid : Optional [str ] = "" ,
222+ issuer_id : Optional [str ] = "" ,
223+ recv : Optional [str ] = "" ,
224+ aud : Optional [str ] = None ,
225+ iat : Optional [int ] = None ,
226+ jws_headers : Optional [Dict [str , str ]] = None ,
227+ ** kwargs
228+ ) -> str :
210229 """
211230
212231 :param payload: Information to be carried as payload in the JWT
213232 :param kid: Key ID
214- :param issuer_id: The owner of the the keys that are to be used for signing
233+ :param issuer_id: The owner of the keys that are to be used for signing
215234 :param recv: The intended immediate receiver
216235 :param aud: Intended audience for this JWS/JWE, not expected to
217236 contain the recipient.
218237 :param iat: Override issued at (default current timestamp)
238+ :param jws_headers: JWS headers
219239 :param kwargs: Extra keyword arguments
220240 :return: A signed or signed and encrypted Json Web Token
221241 """
@@ -249,10 +269,12 @@ def pack(self, payload=None, kid="", issuer_id="", recv="", aud=None, iat=None,
249269 else :
250270 _key = None
251271
252- _jws = JWS (json .dumps (_args ), alg = self .alg )
253- _sjwt = _jws .sign_compact ([_key ])
272+ jws_headers = jws_headers or {}
273+
274+ _jws = JWS (self .message (signing_key = _key , ** _args ), alg = self .alg )
275+ _sjwt = _jws .sign_compact ([_key ], protected = jws_headers )
254276 else :
255- _sjwt = json . dumps ( _args )
277+ _sjwt = self . message ( signing_key = None , ** _args )
256278
257279 if _encrypt :
258280 if not self .sign :
@@ -300,8 +322,7 @@ def verify_profile(msg_cls, info, **kwargs):
300322 :return: The verified message as a msg_cls instance.
301323 """
302324 _msg = msg_cls (** info )
303- if not _msg .verify (** kwargs ):
304- raise VerificationError ()
325+ _msg .verify (** kwargs )
305326 return _msg
306327
307328 def unpack (self , token , timestamp = None ):
@@ -373,11 +394,12 @@ def unpack(self, token, timestamp=None):
373394 if self .msg_cls :
374395 _msg_cls = self .msg_cls
375396 else :
376- try :
377- # try to find a issuer specific message class
378- _msg_cls = self .iss2msg_cls [_info ["iss" ]]
379- except KeyError :
380- _msg_cls = None
397+ _msg_cls = None
398+ # try to find an issuer specific message class
399+ if "iss" in _info :
400+ _msg_cls = self .iss2msg_cls .get (_info ["iss" ])
401+ if not _msg_cls and _jws_header and "typ" in _jws_header :
402+ _msg_cls = self .typ2msg_cls .get (_jws_header ["typ" ])
381403
382404 timestamp = timestamp or utc_time_sans_frac ()
383405
0 commit comments