diff --git a/tests/test_ciphers.py b/tests/test_ciphers.py index b40ddfd..3151e1c 100644 --- a/tests/test_ciphers.py +++ b/tests/test_ciphers.py @@ -333,6 +333,14 @@ def test_chacha_enc_dec(chacha_obj): def rsa_private(vectors): return RsaPrivate(vectors[RsaPrivate].key) + @pytest.fixture + def rsa_private_oaep(vectors): + return RsaPrivate(vectors[RsaPrivate].key, hash_type=HASH_TYPE_SHA) + + @pytest.fixture + def rsa_private_pss(vectors): + return RsaPrivate(vectors[RsaPrivate].key, hash_type=HASH_TYPE_SHA256) + @pytest.fixture def rsa_private_pkcs8(vectors): return RsaPrivate(vectors[RsaPrivate].pkcs8_key) @@ -341,6 +349,14 @@ def rsa_private_pkcs8(vectors): def rsa_public(vectors): return RsaPublic(vectors[RsaPublic].key) + @pytest.fixture + def rsa_public_oaep(vectors): + return RsaPublic(vectors[RsaPublic].key, hash_type=HASH_TYPE_SHA) + + @pytest.fixture + def rsa_public_pss(vectors): + return RsaPublic(vectors[RsaPublic].key, hash_type=HASH_TYPE_SHA256) + @pytest.fixture def rsa_private_pem(vectors): with open(vectors[RsaPrivate].pem, "rb") as f: @@ -382,21 +398,21 @@ def test_rsa_encrypt_decrypt(rsa_private, rsa_public): assert 1024 / 8 == len(ciphertext) == rsa_private.output_size assert plaintext == rsa_private.decrypt(ciphertext) - def test_rsa_encrypt_decrypt_pad_oaep(rsa_private, rsa_public): + def test_rsa_encrypt_decrypt_pad_oaep(rsa_private_oaep, rsa_public_oaep): plaintext = t2b("Everyone gets Friday off.") # normal usage, encrypt with public, decrypt with private - ciphertext = rsa_public.encrypt_oaep(plaintext, HASH_TYPE_SHA, MGF1SHA1, "") + ciphertext = rsa_public_oaep.encrypt_oaep(plaintext) - assert 1024 / 8 == len(ciphertext) == rsa_public.output_size - assert plaintext == rsa_private.decrypt_oaep(ciphertext, HASH_TYPE_SHA, MGF1SHA1, "") + assert 1024 / 8 == len(ciphertext) == rsa_public_oaep.output_size + assert plaintext == rsa_private_oaep.decrypt_oaep(ciphertext) # private object holds both private and public info, so it can also encrypt # using the known public key. - ciphertext = rsa_private.encrypt_oaep(plaintext, HASH_TYPE_SHA, MGF1SHA1, "") + ciphertext = rsa_private_oaep.encrypt_oaep(plaintext) - assert 1024 / 8 == len(ciphertext) == rsa_private.output_size - assert plaintext == rsa_private.decrypt_oaep(ciphertext, HASH_TYPE_SHA, MGF1SHA1, "") + assert 1024 / 8 == len(ciphertext) == rsa_private_oaep.output_size + assert plaintext == rsa_private_oaep.decrypt_oaep(ciphertext) def test_rsa_pkcs8_encrypt_decrypt(rsa_private_pkcs8, rsa_public): @@ -433,21 +449,21 @@ def test_rsa_sign_verify(rsa_private, rsa_public): assert plaintext == rsa_private.verify(signature) if _lib.RSA_PSS_ENABLED: - def test_rsa_pss_sign_verify(rsa_private, rsa_public): + def test_rsa_pss_sign_verify(rsa_private_pss, rsa_public_pss): plaintext = t2b("Everyone gets Friday off yippee.") # normal usage, sign with private, verify with public - signature = rsa_private.sign_pss(plaintext, HASH_TYPE_SHA256, MGF1SHA256) + signature = rsa_private_pss.sign_pss(plaintext) - assert 1024 / 8 == len(signature) == rsa_private.output_size - assert 0 == rsa_public.verify_pss(plaintext, signature, HASH_TYPE_SHA256, MGF1SHA256) + assert 1024 / 8 == len(signature) == rsa_private_pss.output_size + assert 0 == rsa_public_pss.verify_pss(plaintext, signature) # private object holds both private and public info, so it can also verify # using the known public key. - signature = rsa_private.sign_pss(plaintext, HASH_TYPE_SHA256, MGF1SHA256) + signature = rsa_private_pss.sign_pss(plaintext) - assert 1024 / 8 == len(signature) == rsa_private.output_size - assert 0 == rsa_private.verify_pss(plaintext, signature, HASH_TYPE_SHA256, MGF1SHA256) + assert 1024 / 8 == len(signature) == rsa_private_pss.output_size + assert 0 == rsa_private_pss.verify_pss(plaintext, signature) def test_rsa_sign_verify_pem(rsa_private_pem, rsa_public_pem): plaintext = t2b("Everyone gets Friday off.") diff --git a/wolfcrypt/ciphers.py b/wolfcrypt/ciphers.py index 3e22bd9..49597ed 100644 --- a/wolfcrypt/ciphers.py +++ b/wolfcrypt/ciphers.py @@ -452,6 +452,8 @@ def _decrypt(self, destination, source): if _lib.RSA_ENABLED: class _Rsa(object): # pylint: disable=too-few-public-methods RSA_MIN_PAD_SIZE = 11 + _mgf = None + _hash_type = None def __init__(self): self.native_object = _ffi.new("RsaKey *") @@ -473,11 +475,30 @@ def __del__(self): if self.native_object: self._delete(self.native_object) + def set_mgf(self, mgf): + self._mgf = mgf + + def _get_mgf(self): + if self._hash_type == _lib.WC_HASH_TYPE_SHA: + self._mgf = _lib.WC_MGF1SHA1 + elif self._hash_type == _lib.WC_HASH_TYPE_SHA224: + self._mgf = _lib.WC_MGF1SHA224 + elif self._hash_type == _lib.WC_HASH_TYPE_SHA256: + self._mgf = _lib.WC_MGF1SHA256 + elif self._hash_type == _lib.WC_HASH_TYPE_SHA384: + self._mgf = _lib.WC_MGF1SHA384 + elif self._hash_type == _lib.WC_HASH_TYPE_SHA512: + self._mgf = _lib.WC_MGF1SHA512 + else: + self._mgf = _lib.WC_MGF1NONE + + class RsaPublic(_Rsa): - def __init__(self, key=None): + def __init__(self, key=None, hash_type=None): if key != None: key = t2b(key) + self._hash_type = hash_type _Rsa.__init__(self) @@ -524,17 +545,18 @@ def encrypt(self, plaintext): return _ffi.buffer(ciphertext)[:] - def encrypt_oaep(self, plaintext, hash_type, mgf, label): + def encrypt_oaep(self, plaintext, label=""): plaintext = t2b(plaintext) label = t2b(label) ciphertext = _ffi.new("byte[%d]" % self.output_size) - + if self._mgf is None: + self._get_mgf() ret = _lib.wc_RsaPublicEncrypt_ex(plaintext, len(plaintext), ciphertext, self.output_size, self.native_object, self._random.native_object, - _lib.WC_RSA_OAEP_PAD, hash_type, - mgf, label, len(label)) + _lib.WC_RSA_OAEP_PAD, self._hash_type, + self._mgf, label, len(label)) if ret != self.output_size: # pragma: no cover raise WolfCryptError("Encryption error (%d)" % ret) @@ -563,7 +585,7 @@ def verify(self, signature): return _ffi.buffer(plaintext, ret)[:] if _lib.RSA_PSS_ENABLED: - def verify_pss(self, plaintext, signature, hash_type, mgf): + def verify_pss(self, plaintext, signature): """ Verifies **signature**, using the public key data in the object. The signature's length must be equal to: @@ -574,17 +596,19 @@ def verify_pss(self, plaintext, signature, hash_type, mgf): """ plaintext = t2b(plaintext) signature = t2b(signature) + if self._mgf is None: + self._get_mgf() verify = _ffi.new("byte[%d]" % self.output_size) ret = _lib.wc_RsaPSS_Verify(signature, len(signature), verify, self.output_size, - hash_type, mgf, + self._hash_type, self._mgf, self.native_object) if ret < 0: # pragma: no cover raise WolfCryptError("Verify error (%d)" % ret) ret = _lib.wc_RsaPSS_CheckPadding(plaintext, len(plaintext), - verify, ret, hash_type) + verify, ret, self._hash_type) return ret @@ -613,10 +637,10 @@ def make_key(cls, size, rng=Random()): return rsa - def __init__(self, key = None): # pylint: disable=super-init-not-called + def __init__(self, key=None, hash_type=None): # pylint: disable=super-init-not-called _Rsa.__init__(self) # pylint: disable=non-parent-init-called - + self._hash_type = hash_type idx = _ffi.new("word32*") idx[0] = 0 @@ -692,7 +716,7 @@ def decrypt(self, ciphertext): return _ffi.buffer(plaintext, ret)[:] - def decrypt_oaep(self, ciphertext, hash_type, mgf, label): + def decrypt_oaep(self, ciphertext, label=""): """ Decrypts **ciphertext**, using the private key data in the object. The ciphertext's length must be equal to: @@ -704,11 +728,13 @@ def decrypt_oaep(self, ciphertext, hash_type, mgf, label): ciphertext = t2b(ciphertext) label = t2b(label) plaintext = _ffi.new("byte[%d]" % self.output_size) + if self._mgf is None: + self._get_mgf() ret = _lib.wc_RsaPrivateDecrypt_ex(ciphertext, len(ciphertext), plaintext, self.output_size, self.native_object, - _lib.WC_RSA_OAEP_PAD, hash_type, - mgf, label, len(label)) + _lib.WC_RSA_OAEP_PAD, self._hash_type, + self._mgf, label, len(label)) if ret < 0: # pragma: no cover raise WolfCryptError("Decryption error (%d)" % ret) @@ -738,7 +764,7 @@ def sign(self, plaintext): return _ffi.buffer(signature, self.output_size)[:] if _lib.RSA_PSS_ENABLED: - def sign_pss(self, plaintext, hash_type, mgf): + def sign_pss(self, plaintext): """ Signs **plaintext**, using the private key data in the object. The plaintext's length must not be greater than: @@ -749,10 +775,11 @@ def sign_pss(self, plaintext, hash_type, mgf): """ plaintext = t2b(plaintext) signature = _ffi.new("byte[%d]" % self.output_size) - + if self._mgf is None: + self._get_mgf() ret = _lib.wc_RsaPSS_Sign(plaintext, len(plaintext), signature, self.output_size, - hash_type, mgf, + self._hash_type, self._mgf, self.native_object, self._random.native_object)