Cache pub/pri keys on retrieval
Pyca rightfully performs consistency checks when importing keys and
these operations are rather expensive. So cache keys once generated so
that repeated uses of the same JWK do not incur undue cost of reloading
the keys from scratch for each subsequent operation.
with a simple test by hand:
$ python
>>> from jwcrypto import jwk
>>> def test():
... key = jwk.JWK.generate(kty='RSA', size=2048)
... for i in range(1000):
... k = key._get_private_key()
...
>>> import timeit
Before the patch:
>>> print(timeit.timeit("test()", setup="from __main__ import test", number=10))
35.80328264506534
After the patch:
>>> print(timeit.timeit("test()", setup="from __main__ import test", number=10))
0.9109518649056554
Resolves #243
Signed-off-by: Simo Sorce <simo@redhat.com>
Simo Sorce authored 2 years ago
Simo Sorce committed 2 years ago
300 | 300 | are provided. |
301 | 301 | """ |
302 | 302 | super(JWK, self).__init__() |
303 | self._cache_pub_k = None | |
304 | self._cache_pri_k = None | |
303 | 305 | |
304 | 306 | if 'generate' in kwargs: |
305 | 307 | self.generate_key(**kwargs) |
484 | 486 | def import_key(self, **kwargs): |
485 | 487 | newkey = {} |
486 | 488 | key_vals = 0 |
489 | self._cache_pub_k = None | |
490 | self._cache_pri_k = None | |
487 | 491 | |
488 | 492 | names = list(kwargs.keys()) |
489 | 493 | |
729 | 733 | def _decode_int(self, n): |
730 | 734 | return int(hexlify(base64url_decode(n)), 16) |
731 | 735 | |
732 | def _rsa_pub(self): | |
736 | def _rsa_pub_n(self): | |
733 | 737 | e = self._decode_int(self.get('e')) |
734 | 738 | n = self._decode_int(self.get('n')) |
735 | 739 | return rsa.RSAPublicNumbers(e, n) |
736 | 740 | |
737 | def _rsa_pri(self): | |
741 | def _rsa_pri_n(self): | |
738 | 742 | p = self._decode_int(self.get('p')) |
739 | 743 | q = self._decode_int(self.get('q')) |
740 | 744 | d = self._decode_int(self.get('d')) |
741 | 745 | dp = self._decode_int(self.get('dp')) |
742 | 746 | dq = self._decode_int(self.get('dq')) |
743 | 747 | qi = self._decode_int(self.get('qi')) |
744 | return rsa.RSAPrivateNumbers(p, q, d, dp, dq, qi, self._rsa_pub()) | |
745 | ||
746 | def _ec_pub(self, curve): | |
748 | return rsa.RSAPrivateNumbers(p, q, d, dp, dq, qi, self._rsa_pub_n()) | |
749 | ||
750 | def _rsa_pub(self): | |
751 | k = self._cache_pub_k | |
752 | if k is None: | |
753 | k = self._rsa_pub_n().public_key(default_backend()) | |
754 | self._cache_pub_k = k | |
755 | return k | |
756 | ||
757 | def _rsa_pri(self): | |
758 | k = self._cache_pri_k | |
759 | if k is None: | |
760 | k = self._rsa_pri_n().private_key(default_backend()) | |
761 | self._cache_pri_k = k | |
762 | return k | |
763 | ||
764 | def _ec_pub_n(self, curve): | |
747 | 765 | x = self._decode_int(self.get('x')) |
748 | 766 | y = self._decode_int(self.get('y')) |
749 | 767 | return ec.EllipticCurvePublicNumbers(x, y, self.get_curve(curve)) |
750 | 768 | |
769 | def _ec_pri_n(self, curve): | |
770 | d = self._decode_int(self.get('d')) | |
771 | return ec.EllipticCurvePrivateNumbers(d, self._ec_pub_n(curve)) | |
772 | ||
773 | def _ec_pub(self, curve): | |
774 | k = self._cache_pub_k | |
775 | if k is None: | |
776 | k = self._ec_pub_n(curve).public_key(default_backend()) | |
777 | self._cache_pub_k = k | |
778 | return k | |
779 | ||
751 | 780 | def _ec_pri(self, curve): |
752 | d = self._decode_int(self.get('d')) | |
753 | return ec.EllipticCurvePrivateNumbers(d, self._ec_pub(curve)) | |
781 | k = self._cache_pri_k | |
782 | if k is None: | |
783 | k = self._ec_pri_n(curve).private_key(default_backend()) | |
784 | self._cache_pri_k = k | |
785 | return k | |
754 | 786 | |
755 | 787 | def _okp_pub(self): |
756 | crv = self.get('crv') | |
757 | try: | |
758 | pubkey = _OKP_CURVES_TABLE[crv].pubkey | |
759 | except KeyError as e: | |
760 | raise InvalidJWKValue('Unknown curve "%s"' % crv) from e | |
761 | ||
762 | x = base64url_decode(self.get('x')) | |
763 | return pubkey.from_public_bytes(x) | |
788 | k = self._cache_pub_k | |
789 | if k is None: | |
790 | crv = self.get('crv') | |
791 | try: | |
792 | pubkey = _OKP_CURVES_TABLE[crv].pubkey | |
793 | except KeyError as e: | |
794 | raise InvalidJWKValue('Unknown curve "%s"' % crv) from e | |
795 | ||
796 | x = base64url_decode(self.get('x')) | |
797 | k = pubkey.from_public_bytes(x) | |
798 | self._cache_pub_k = k | |
799 | return k | |
764 | 800 | |
765 | 801 | def _okp_pri(self): |
766 | crv = self.get('crv') | |
767 | try: | |
768 | privkey = _OKP_CURVES_TABLE[crv].privkey | |
769 | except KeyError as e: | |
770 | raise InvalidJWKValue('Unknown curve "%s"' % crv) from e | |
771 | ||
772 | d = base64url_decode(self.get('d')) | |
773 | return privkey.from_private_bytes(d) | |
802 | k = self._cache_pri_k | |
803 | if k is None: | |
804 | crv = self.get('crv') | |
805 | try: | |
806 | privkey = _OKP_CURVES_TABLE[crv].privkey | |
807 | except KeyError as e: | |
808 | raise InvalidJWKValue('Unknown curve "%s"' % crv) from e | |
809 | ||
810 | d = base64url_decode(self.get('d')) | |
811 | k = privkey.from_private_bytes(d) | |
812 | self._cache_pri_k = k | |
813 | return k | |
774 | 814 | |
775 | 815 | def _get_public_key(self, arg=None): |
776 | 816 | ktype = self.get('kty') |
777 | 817 | if ktype == 'oct': |
778 | 818 | return self.get('k') |
779 | 819 | elif ktype == 'RSA': |
780 | return self._rsa_pub().public_key(default_backend()) | |
820 | return self._rsa_pub() | |
781 | 821 | elif ktype == 'EC': |
782 | return self._ec_pub(arg).public_key(default_backend()) | |
822 | return self._ec_pub(arg) | |
783 | 823 | elif ktype == 'OKP': |
784 | 824 | return self._okp_pub() |
785 | 825 | else: |
790 | 830 | if ktype == 'oct': |
791 | 831 | return self.get('k') |
792 | 832 | elif ktype == 'RSA': |
793 | return self._rsa_pri().private_key(default_backend()) | |
833 | return self._rsa_pri() | |
794 | 834 | elif ktype == 'EC': |
795 | return self._ec_pri(arg).private_key(default_backend()) | |
835 | return self._ec_pri(arg) | |
796 | 836 | elif ktype == 'OKP': |
797 | 837 | return self._okp_pri() |
798 | 838 | else: |
968 | 1008 | |
969 | 1009 | # Check if item is a key value and verify its format |
970 | 1010 | if item in list(JWKValuesRegistry[kty].keys()): |
1011 | # Invalidate cached keys if any | |
1012 | self._cache_pub_k = None | |
1013 | self._cache_pri_k = None | |
971 | 1014 | if JWKValuesRegistry[kty][item].type == ParmType.b64: |
972 | 1015 | try: |
973 | 1016 | v = base64url_decode(value) |
1027 | 1070 | if self.get(name) is not None: |
1028 | 1071 | raise KeyError("Cannot remove 'kty', values present") |
1029 | 1072 | |
1073 | kty = self.get('kty') | |
1074 | if kty is not None and item in list(JWKValuesRegistry[kty].keys()): | |
1075 | # Invalidate cached keys if any | |
1076 | self._cache_pub_k = None | |
1077 | self._cache_pri_k = None | |
1078 | ||
1030 | 1079 | super(JWK, self).__delitem__(item) |
1031 | 1080 | |
1032 | 1081 | def __eq__(self, other): |