Codebase list python-jwcrypto / ddff0b0
Cache pub/pri keys on retrieval Pyca rightfully performs consistency checks when importing keys and these operations are rather expensive. So cache keys once generated so that repeated uses of the same JWK do not incur undue cost of reloading the keys from scratch for each subsequent operation. with a simple test by hand: $ python >>> from jwcrypto import jwk >>> def test(): ... key = jwk.JWK.generate(kty='RSA', size=2048) ... for i in range(1000): ... k = key._get_private_key() ... >>> import timeit Before the patch: >>> print(timeit.timeit("test()", setup="from __main__ import test", number=10)) 35.80328264506534 After the patch: >>> print(timeit.timeit("test()", setup="from __main__ import test", number=10)) 0.9109518649056554 Resolves #243 Signed-off-by: Simo Sorce <simo@redhat.com> Simo Sorce authored 2 years ago Simo Sorce committed 2 years ago
1 changed file(s) with 76 addition(s) and 27 deletion(s). Raw diff Collapse all Expand all
300300 are provided.
301301 """
302302 super(JWK, self).__init__()
303 self._cache_pub_k = None
304 self._cache_pri_k = None
303305
304306 if 'generate' in kwargs:
305307 self.generate_key(**kwargs)
484486 def import_key(self, **kwargs):
485487 newkey = {}
486488 key_vals = 0
489 self._cache_pub_k = None
490 self._cache_pri_k = None
487491
488492 names = list(kwargs.keys())
489493
729733 def _decode_int(self, n):
730734 return int(hexlify(base64url_decode(n)), 16)
731735
732 def _rsa_pub(self):
736 def _rsa_pub_n(self):
733737 e = self._decode_int(self.get('e'))
734738 n = self._decode_int(self.get('n'))
735739 return rsa.RSAPublicNumbers(e, n)
736740
737 def _rsa_pri(self):
741 def _rsa_pri_n(self):
738742 p = self._decode_int(self.get('p'))
739743 q = self._decode_int(self.get('q'))
740744 d = self._decode_int(self.get('d'))
741745 dp = self._decode_int(self.get('dp'))
742746 dq = self._decode_int(self.get('dq'))
743747 qi = self._decode_int(self.get('qi'))
744 return rsa.RSAPrivateNumbers(p, q, d, dp, dq, qi, self._rsa_pub())
745
746 def _ec_pub(self, curve):
748 return rsa.RSAPrivateNumbers(p, q, d, dp, dq, qi, self._rsa_pub_n())
749
750 def _rsa_pub(self):
751 k = self._cache_pub_k
752 if k is None:
753 k = self._rsa_pub_n().public_key(default_backend())
754 self._cache_pub_k = k
755 return k
756
757 def _rsa_pri(self):
758 k = self._cache_pri_k
759 if k is None:
760 k = self._rsa_pri_n().private_key(default_backend())
761 self._cache_pri_k = k
762 return k
763
764 def _ec_pub_n(self, curve):
747765 x = self._decode_int(self.get('x'))
748766 y = self._decode_int(self.get('y'))
749767 return ec.EllipticCurvePublicNumbers(x, y, self.get_curve(curve))
750768
769 def _ec_pri_n(self, curve):
770 d = self._decode_int(self.get('d'))
771 return ec.EllipticCurvePrivateNumbers(d, self._ec_pub_n(curve))
772
773 def _ec_pub(self, curve):
774 k = self._cache_pub_k
775 if k is None:
776 k = self._ec_pub_n(curve).public_key(default_backend())
777 self._cache_pub_k = k
778 return k
779
751780 def _ec_pri(self, curve):
752 d = self._decode_int(self.get('d'))
753 return ec.EllipticCurvePrivateNumbers(d, self._ec_pub(curve))
781 k = self._cache_pri_k
782 if k is None:
783 k = self._ec_pri_n(curve).private_key(default_backend())
784 self._cache_pri_k = k
785 return k
754786
755787 def _okp_pub(self):
756 crv = self.get('crv')
757 try:
758 pubkey = _OKP_CURVES_TABLE[crv].pubkey
759 except KeyError as e:
760 raise InvalidJWKValue('Unknown curve "%s"' % crv) from e
761
762 x = base64url_decode(self.get('x'))
763 return pubkey.from_public_bytes(x)
788 k = self._cache_pub_k
789 if k is None:
790 crv = self.get('crv')
791 try:
792 pubkey = _OKP_CURVES_TABLE[crv].pubkey
793 except KeyError as e:
794 raise InvalidJWKValue('Unknown curve "%s"' % crv) from e
795
796 x = base64url_decode(self.get('x'))
797 k = pubkey.from_public_bytes(x)
798 self._cache_pub_k = k
799 return k
764800
765801 def _okp_pri(self):
766 crv = self.get('crv')
767 try:
768 privkey = _OKP_CURVES_TABLE[crv].privkey
769 except KeyError as e:
770 raise InvalidJWKValue('Unknown curve "%s"' % crv) from e
771
772 d = base64url_decode(self.get('d'))
773 return privkey.from_private_bytes(d)
802 k = self._cache_pri_k
803 if k is None:
804 crv = self.get('crv')
805 try:
806 privkey = _OKP_CURVES_TABLE[crv].privkey
807 except KeyError as e:
808 raise InvalidJWKValue('Unknown curve "%s"' % crv) from e
809
810 d = base64url_decode(self.get('d'))
811 k = privkey.from_private_bytes(d)
812 self._cache_pri_k = k
813 return k
774814
775815 def _get_public_key(self, arg=None):
776816 ktype = self.get('kty')
777817 if ktype == 'oct':
778818 return self.get('k')
779819 elif ktype == 'RSA':
780 return self._rsa_pub().public_key(default_backend())
820 return self._rsa_pub()
781821 elif ktype == 'EC':
782 return self._ec_pub(arg).public_key(default_backend())
822 return self._ec_pub(arg)
783823 elif ktype == 'OKP':
784824 return self._okp_pub()
785825 else:
790830 if ktype == 'oct':
791831 return self.get('k')
792832 elif ktype == 'RSA':
793 return self._rsa_pri().private_key(default_backend())
833 return self._rsa_pri()
794834 elif ktype == 'EC':
795 return self._ec_pri(arg).private_key(default_backend())
835 return self._ec_pri(arg)
796836 elif ktype == 'OKP':
797837 return self._okp_pri()
798838 else:
9681008
9691009 # Check if item is a key value and verify its format
9701010 if item in list(JWKValuesRegistry[kty].keys()):
1011 # Invalidate cached keys if any
1012 self._cache_pub_k = None
1013 self._cache_pri_k = None
9711014 if JWKValuesRegistry[kty][item].type == ParmType.b64:
9721015 try:
9731016 v = base64url_decode(value)
10271070 if self.get(name) is not None:
10281071 raise KeyError("Cannot remove 'kty', values present")
10291072
1073 kty = self.get('kty')
1074 if kty is not None and item in list(JWKValuesRegistry[kty].keys()):
1075 # Invalidate cached keys if any
1076 self._cache_pub_k = None
1077 self._cache_pri_k = None
1078
10301079 super(JWK, self).__delitem__(item)
10311080
10321081 def __eq__(self, other):