|
1 | 1 | """Basic JSON Web Token implementation."""
|
2 | 2 | import json
|
| 3 | +from json import JSONDecodeError |
3 | 4 | import logging
|
4 | 5 | import time
|
5 |
| -import uuid |
6 |
| -from json import JSONDecodeError |
7 | 6 | from typing import Dict
|
| 7 | +from typing import List |
| 8 | +from typing import MutableMapping |
8 | 9 | from typing import Optional
|
| 10 | +import uuid |
9 | 11 |
|
10 | 12 | from .exception import HeaderError
|
11 | 13 | from .exception import VerificationError
|
@@ -79,23 +81,24 @@ class JWT:
|
79 | 81 | """The basic JSON Web Token class."""
|
80 | 82 |
|
81 | 83 | def __init__(
|
82 |
| - self, |
83 |
| - key_jar=None, |
84 |
| - iss="", |
85 |
| - lifetime=0, |
86 |
| - sign=True, |
87 |
| - sign_alg="RS256", |
88 |
| - encrypt=False, |
89 |
| - enc_enc="A128GCM", |
90 |
| - enc_alg="RSA-OAEP-256", |
91 |
| - msg_cls=None, |
92 |
| - iss2msg_cls=None, |
93 |
| - skew=15, |
94 |
| - allowed_sign_algs=None, |
95 |
| - allowed_enc_algs=None, |
96 |
| - allowed_enc_encs=None, |
97 |
| - allowed_max_lifetime=None, |
98 |
| - zip="", |
| 84 | + self, |
| 85 | + key_jar=None, |
| 86 | + iss: str="", |
| 87 | + lifetime: int = 0, |
| 88 | + sign: bool = True, |
| 89 | + sign_alg: str = "RS256", |
| 90 | + encrypt: bool = False, |
| 91 | + enc_enc: str = "A128GCM", |
| 92 | + enc_alg: str = "RSA-OAEP-256", |
| 93 | + msg_cls: MutableMapping = None, |
| 94 | + iss2msg_cls: Dict[str, str] = None, |
| 95 | + skew: int = 15, |
| 96 | + allowed_sign_algs: List[str] = None, |
| 97 | + allowed_enc_algs: List[str] = None, |
| 98 | + allowed_enc_encs: List[str] = None, |
| 99 | + allowed_max_lifetime: int = None, |
| 100 | + zip: str = "", |
| 101 | + typ2msg_cls: Dict[str, str] = None |
99 | 102 | ):
|
100 | 103 | self.key_jar = key_jar # KeyJar instance
|
101 | 104 | self.iss = iss # My identifier
|
@@ -212,15 +215,15 @@ def message(self, signing_key, **kwargs):
|
212 | 215 | return json.dumps(kwargs)
|
213 | 216 |
|
214 | 217 | def pack(
|
215 |
| - self, |
216 |
| - payload: Optional[dict] = None, |
217 |
| - kid: Optional[str] = "", |
218 |
| - issuer_id: Optional[str] = "", |
219 |
| - recv: Optional[str] = "", |
220 |
| - aud: Optional[str] = None, |
221 |
| - iat: Optional[int] = None, |
222 |
| - jws_headers: Dict[str, str] = None, |
223 |
| - **kwargs |
| 218 | + self, |
| 219 | + payload: Optional[dict] = None, |
| 220 | + kid: Optional[str] = "", |
| 221 | + issuer_id: Optional[str] = "", |
| 222 | + recv: Optional[str] = "", |
| 223 | + aud: Optional[str] = None, |
| 224 | + iat: Optional[int] = None, |
| 225 | + jws_headers: Dict[str, str] = None, |
| 226 | + **kwargs |
224 | 227 | ) -> str:
|
225 | 228 | """
|
226 | 229 |
|
@@ -319,8 +322,7 @@ def verify_profile(msg_cls, info, **kwargs):
|
319 | 322 | :return: The verified message as a msg_cls instance.
|
320 | 323 | """
|
321 | 324 | _msg = msg_cls(**info)
|
322 |
| - if not _msg.verify(**kwargs): |
323 |
| - raise VerificationError() |
| 325 | + _msg.verify(**kwargs) |
324 | 326 | return _msg
|
325 | 327 |
|
326 | 328 | def unpack(self, token, timestamp=None):
|
@@ -392,11 +394,10 @@ def unpack(self, token, timestamp=None):
|
392 | 394 | if self.msg_cls:
|
393 | 395 | _msg_cls = self.msg_cls
|
394 | 396 | else:
|
395 |
| - try: |
396 |
| - # try to find a issuer specific message class |
397 |
| - _msg_cls = self.iss2msg_cls[_info["iss"]] |
398 |
| - except KeyError: |
399 |
| - _msg_cls = None |
| 397 | + # try to find an issuer specific message class |
| 398 | + _msg_cls = self.iss2msg_cls.get(_info["iss"]) |
| 399 | + if not _msg_cls: |
| 400 | + _msg_cls = self.typ2msg_cls.get(_jws_header['typ']) |
400 | 401 |
|
401 | 402 | timestamp = timestamp or utc_time_sans_frac()
|
402 | 403 |
|
|
0 commit comments