4
4
import time
5
5
import uuid
6
6
from json import JSONDecodeError
7
+ from typing import Dict
8
+ from typing import List
9
+ from typing import MutableMapping
10
+ from typing import Optional
7
11
8
12
from .exception import HeaderError
9
13
from .exception import VerificationError
@@ -79,25 +83,26 @@ class JWT:
79
83
def __init__ (
80
84
self ,
81
85
key_jar = None ,
82
- iss = "" ,
83
- lifetime = 0 ,
84
- sign = True ,
85
- sign_alg = "RS256" ,
86
- encrypt = False ,
87
- enc_enc = "A128GCM" ,
88
- enc_alg = "RSA-OAEP-256" ,
89
- msg_cls = None ,
90
- iss2msg_cls = None ,
91
- skew = 15 ,
92
- allowed_sign_algs = None ,
93
- allowed_enc_algs = None ,
94
- allowed_enc_encs = None ,
95
- allowed_max_lifetime = None ,
96
- zip = "" ,
86
+ iss : str = "" ,
87
+ lifetime : int = 0 ,
88
+ sign : bool = True ,
89
+ sign_alg : str = "RS256" ,
90
+ encrypt : bool = False ,
91
+ enc_enc : str = "A128GCM" ,
92
+ enc_alg : str = "RSA-OAEP-256" ,
93
+ msg_cls : Optional [MutableMapping ] = None ,
94
+ iss2msg_cls : Optional [Dict [str , str ]] = None ,
95
+ skew : Optional [int ] = 15 ,
96
+ allowed_sign_algs : Optional [List [str ]] = None ,
97
+ allowed_enc_algs : Optional [List [str ]] = None ,
98
+ allowed_enc_encs : Optional [List [str ]] = None ,
99
+ allowed_max_lifetime : Optional [int ] = None ,
100
+ zip : Optional [str ] = "" ,
101
+ typ2msg_cls : Optional [Dict ] = None ,
97
102
):
98
103
self .key_jar = key_jar # KeyJar instance
99
104
self .iss = iss # My identifier
100
- self .lifetime = lifetime # default life time of the signature
105
+ self .lifetime = lifetime # default lifetime of the signature
101
106
self .sign = sign # default signing or not
102
107
self .alg = sign_alg # default signing algorithm
103
108
self .encrypt = encrypt # default encrypting or not
@@ -107,6 +112,7 @@ def __init__(
107
112
self .with_jti = False # If a jti should be added
108
113
# A map between issuers and the message classes they use
109
114
self .iss2msg_cls = iss2msg_cls or {}
115
+ self .typ2msg_cls = typ2msg_cls or {}
110
116
# Allowed time skew
111
117
self .skew = skew
112
118
# When verifying/decrypting
@@ -206,16 +212,30 @@ def pack_key(self, issuer_id="", kid=""):
206
212
207
213
return keys [0 ] # Might be more then one if kid == ''
208
214
209
- def pack (self , payload = None , kid = "" , issuer_id = "" , recv = "" , aud = None , iat = None , ** kwargs ):
215
+ def message (self , signing_key , ** kwargs ):
216
+ return json .dumps (kwargs )
217
+
218
+ def pack (
219
+ self ,
220
+ payload : Optional [dict ] = None ,
221
+ kid : Optional [str ] = "" ,
222
+ issuer_id : Optional [str ] = "" ,
223
+ recv : Optional [str ] = "" ,
224
+ aud : Optional [str ] = None ,
225
+ iat : Optional [int ] = None ,
226
+ jws_headers : Optional [Dict [str , str ]] = None ,
227
+ ** kwargs
228
+ ) -> str :
210
229
"""
211
230
212
231
:param payload: Information to be carried as payload in the JWT
213
232
:param kid: Key ID
214
- :param issuer_id: The owner of the the keys that are to be used for signing
233
+ :param issuer_id: The owner of the keys that are to be used for signing
215
234
:param recv: The intended immediate receiver
216
235
:param aud: Intended audience for this JWS/JWE, not expected to
217
236
contain the recipient.
218
237
:param iat: Override issued at (default current timestamp)
238
+ :param jws_headers: JWS headers
219
239
:param kwargs: Extra keyword arguments
220
240
:return: A signed or signed and encrypted Json Web Token
221
241
"""
@@ -249,10 +269,12 @@ def pack(self, payload=None, kid="", issuer_id="", recv="", aud=None, iat=None,
249
269
else :
250
270
_key = None
251
271
252
- _jws = JWS (json .dumps (_args ), alg = self .alg )
253
- _sjwt = _jws .sign_compact ([_key ])
272
+ jws_headers = jws_headers or {}
273
+
274
+ _jws = JWS (self .message (signing_key = _key , ** _args ), alg = self .alg )
275
+ _sjwt = _jws .sign_compact ([_key ], protected = jws_headers )
254
276
else :
255
- _sjwt = json . dumps ( _args )
277
+ _sjwt = self . message ( signing_key = None , ** _args )
256
278
257
279
if _encrypt :
258
280
if not self .sign :
@@ -300,8 +322,7 @@ def verify_profile(msg_cls, info, **kwargs):
300
322
:return: The verified message as a msg_cls instance.
301
323
"""
302
324
_msg = msg_cls (** info )
303
- if not _msg .verify (** kwargs ):
304
- raise VerificationError ()
325
+ _msg .verify (** kwargs )
305
326
return _msg
306
327
307
328
def unpack (self , token , timestamp = None ):
@@ -373,11 +394,12 @@ def unpack(self, token, timestamp=None):
373
394
if self .msg_cls :
374
395
_msg_cls = self .msg_cls
375
396
else :
376
- try :
377
- # try to find a issuer specific message class
378
- _msg_cls = self .iss2msg_cls [_info ["iss" ]]
379
- except KeyError :
380
- _msg_cls = None
397
+ _msg_cls = None
398
+ # try to find an issuer specific message class
399
+ if "iss" in _info :
400
+ _msg_cls = self .iss2msg_cls .get (_info ["iss" ])
401
+ if not _msg_cls and _jws_header and "typ" in _jws_header :
402
+ _msg_cls = self .typ2msg_cls .get (_jws_header ["typ" ])
381
403
382
404
timestamp = timestamp or utc_time_sans_frac ()
383
405
0 commit comments