From b3a9f9c1751f7f72b3b89ff1ffe9271109b98bd0 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Wed, 1 Nov 2023 14:35:47 -0700 Subject: [PATCH] Convert symmetric ciphers to Rust --- .../hazmat/backends/openssl/backend.py | 166 +---- .../hazmat/backends/openssl/ciphers.py | 282 --------- .../bindings/_rust/openssl/__init__.pyi | 2 + .../hazmat/bindings/_rust/openssl/ciphers.pyi | 38 ++ .../hazmat/bindings/openssl/binding.py | 8 +- .../hazmat/primitives/ciphers/base.py | 143 +---- src/rust/src/backend/cipher_registry.rs | 181 +++++- src/rust/src/backend/ciphers.rs | 567 ++++++++++++++++++ src/rust/src/backend/mod.rs | 2 + src/rust/src/buf.rs | 59 +- src/rust/src/exceptions.rs | 2 + src/rust/src/types.rs | 43 ++ src/rust/src/x509/common.rs | 2 +- tests/hazmat/backends/test_openssl.py | 31 - tests/hazmat/primitives/test_aes_gcm.py | 53 +- tests/hazmat/primitives/test_pkcs12.py | 5 +- 16 files changed, 899 insertions(+), 685 deletions(-) delete mode 100644 src/cryptography/hazmat/backends/openssl/ciphers.py create mode 100644 src/cryptography/hazmat/bindings/_rust/openssl/ciphers.pyi create mode 100644 src/rust/src/backend/ciphers.rs diff --git a/src/cryptography/hazmat/backends/openssl/backend.py b/src/cryptography/hazmat/backends/openssl/backend.py index 3cf01664685c6..4a47f8902d0e9 100644 --- a/src/cryptography/hazmat/backends/openssl/backend.py +++ b/src/cryptography/hazmat/backends/openssl/backend.py @@ -6,22 +6,12 @@ import collections import contextlib -import itertools import typing from cryptography import utils, x509 from cryptography.exceptions import UnsupportedAlgorithm -from cryptography.hazmat.backends.openssl.ciphers import _CipherContext from cryptography.hazmat.bindings._rust import openssl as rust_openssl from cryptography.hazmat.bindings.openssl import binding -from cryptography.hazmat.decrepit.ciphers.algorithms import ( - ARC4, - CAST5, - IDEA, - SEED, - Blowfish, - TripleDES, -) from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives._asymmetric import AsymmetricPadding from cryptography.hazmat.primitives.asymmetric import ec @@ -40,21 +30,9 @@ ) from cryptography.hazmat.primitives.ciphers.algorithms import ( AES, - AES128, - AES256, - SM4, - Camellia, - ChaCha20, ) from cryptography.hazmat.primitives.ciphers.modes import ( CBC, - CFB, - CFB8, - CTR, - ECB, - GCM, - OFB, - XTS, Mode, ) from cryptography.hazmat.primitives.serialization.pkcs12 import ( @@ -70,7 +48,7 @@ # Not actually supported, just used as a marker for some serialization tests. class _RC2: - pass + key_size = 128 class Backend: @@ -117,12 +95,6 @@ def __init__(self) -> None: self._lib = self._binding.lib self._fips_enabled = rust_openssl.is_fips_enabled() - self._cipher_registry: dict[ - tuple[type[CipherAlgorithm], type[Mode]], - typing.Callable, - ] = {} - self._register_default_ciphers() - def __repr__(self) -> str: return "".format( self.openssl_version_text(), @@ -130,12 +102,8 @@ def __repr__(self) -> str: self._binding._legacy_provider_loaded, ) - def openssl_assert( - self, - ok: bool, - errors: list[rust_openssl.OpenSSLError] | None = None, - ) -> None: - return binding._openssl_assert(ok, errors=errors) + def openssl_assert(self, ok: bool) -> None: + return binding._openssl_assert(ok) def _enable_fips(self) -> None: # This function enables FIPS mode for OpenSSL 3.0.0 on installs that @@ -210,103 +178,7 @@ def cipher_supported(self, cipher: CipherAlgorithm, mode: Mode) -> bool: if not isinstance(cipher, self._fips_ciphers): return False - try: - adapter = self._cipher_registry[type(cipher), type(mode)] - except KeyError: - return False - evp_cipher = adapter(self, cipher, mode) - return self._ffi.NULL != evp_cipher - - def register_cipher_adapter(self, cipher_cls, mode_cls, adapter) -> None: - if (cipher_cls, mode_cls) in self._cipher_registry: - raise ValueError( - f"Duplicate registration for: {cipher_cls} {mode_cls}." - ) - self._cipher_registry[cipher_cls, mode_cls] = adapter - - def _register_default_ciphers(self) -> None: - for cipher_cls in [AES, AES128, AES256]: - for mode_cls in [CBC, CTR, ECB, OFB, CFB, CFB8, GCM]: - self.register_cipher_adapter( - cipher_cls, - mode_cls, - GetCipherByName( - "{cipher.name}-{cipher.key_size}-{mode.name}" - ), - ) - for mode_cls in [CBC, CTR, ECB, OFB, CFB]: - self.register_cipher_adapter( - Camellia, - mode_cls, - GetCipherByName("{cipher.name}-{cipher.key_size}-{mode.name}"), - ) - for mode_cls in [CBC, CFB, CFB8, OFB]: - self.register_cipher_adapter( - TripleDES, mode_cls, GetCipherByName("des-ede3-{mode.name}") - ) - self.register_cipher_adapter( - TripleDES, ECB, GetCipherByName("des-ede3") - ) - # ChaCha20 uses the Long Name "chacha20" in OpenSSL, but in LibreSSL - # it uses "chacha" - self.register_cipher_adapter( - ChaCha20, - type(None), - GetCipherByName( - "chacha" if self._lib.CRYPTOGRAPHY_IS_LIBRESSL else "chacha20" - ), - ) - self.register_cipher_adapter(AES, XTS, _get_xts_cipher) - for mode_cls in [ECB, CBC, OFB, CFB, CTR, GCM]: - self.register_cipher_adapter( - SM4, mode_cls, GetCipherByName("sm4-{mode.name}") - ) - # Don't register legacy ciphers if they're unavailable. Hypothetically - # this wouldn't be necessary because we test availability by seeing if - # we get an EVP_CIPHER * in the _CipherContext __init__, but OpenSSL 3 - # will return a valid pointer even though the cipher is unavailable. - if ( - self._binding._legacy_provider_loaded - or not self._lib.CRYPTOGRAPHY_OPENSSL_300_OR_GREATER - ): - for mode_cls in [CBC, CFB, OFB, ECB]: - self.register_cipher_adapter( - Blowfish, - mode_cls, - GetCipherByName("bf-{mode.name}"), - ) - for mode_cls in [CBC, CFB, OFB, ECB]: - self.register_cipher_adapter( - SEED, - mode_cls, - GetCipherByName("seed-{mode.name}"), - ) - for cipher_cls, mode_cls in itertools.product( - [CAST5, IDEA], - [CBC, OFB, CFB, ECB], - ): - self.register_cipher_adapter( - cipher_cls, - mode_cls, - GetCipherByName("{cipher.name}-{mode.name}"), - ) - self.register_cipher_adapter( - ARC4, type(None), GetCipherByName("rc4") - ) - # We don't actually support RC2, this is just used by some tests. - self.register_cipher_adapter( - _RC2, type(None), GetCipherByName("rc2") - ) - - def create_symmetric_encryption_ctx( - self, cipher: CipherAlgorithm, mode: Mode - ) -> _CipherContext: - return _CipherContext(self, cipher, mode, _CipherContext._ENCRYPT) - - def create_symmetric_decryption_ctx( - self, cipher: CipherAlgorithm, mode: Mode - ) -> _CipherContext: - return _CipherContext(self, cipher, mode, _CipherContext._DECRYPT) + return rust_openssl.ciphers.cipher_supported(cipher, mode) def pbkdf2_hmac_supported(self, algorithm: hashes.HashAlgorithm) -> bool: return self.hmac_supported(algorithm) @@ -841,34 +713,4 @@ def pkcs7_supported(self) -> bool: return not self._lib.CRYPTOGRAPHY_IS_BORINGSSL -class GetCipherByName: - def __init__(self, fmt: str): - self._fmt = fmt - - def __call__(self, backend: Backend, cipher: CipherAlgorithm, mode: Mode): - cipher_name = self._fmt.format(cipher=cipher, mode=mode).lower() - evp_cipher = backend._lib.EVP_get_cipherbyname( - cipher_name.encode("ascii") - ) - - # try EVP_CIPHER_fetch if present - if ( - evp_cipher == backend._ffi.NULL - and backend._lib.Cryptography_HAS_300_EVP_CIPHER - ): - evp_cipher = backend._lib.EVP_CIPHER_fetch( - backend._ffi.NULL, - cipher_name.encode("ascii"), - backend._ffi.NULL, - ) - - backend._consume_errors() - return evp_cipher - - -def _get_xts_cipher(backend: Backend, cipher: AES, mode): - cipher_name = f"aes-{cipher.key_size // 2}-xts" - return backend._lib.EVP_get_cipherbyname(cipher_name.encode("ascii")) - - backend = Backend() diff --git a/src/cryptography/hazmat/backends/openssl/ciphers.py b/src/cryptography/hazmat/backends/openssl/ciphers.py deleted file mode 100644 index 3916b1a510ad6..0000000000000 --- a/src/cryptography/hazmat/backends/openssl/ciphers.py +++ /dev/null @@ -1,282 +0,0 @@ -# This file is dual licensed under the terms of the Apache License, Version -# 2.0, and the BSD License. See the LICENSE file in the root of this repository -# for complete details. - -from __future__ import annotations - -import typing - -from cryptography.exceptions import InvalidTag, UnsupportedAlgorithm, _Reasons -from cryptography.hazmat.primitives import ciphers -from cryptography.hazmat.primitives.ciphers import algorithms, modes - -if typing.TYPE_CHECKING: - from cryptography.hazmat.backends.openssl.backend import Backend - - -class _CipherContext: - _ENCRYPT = 1 - _DECRYPT = 0 - _MAX_CHUNK_SIZE = 2**29 - - def __init__(self, backend: Backend, cipher, mode, operation: int) -> None: - self._backend = backend - self._cipher = cipher - self._mode = mode - self._operation = operation - self._tag: bytes | None = None - - if isinstance(self._cipher, ciphers.BlockCipherAlgorithm): - self._block_size_bytes = self._cipher.block_size // 8 - else: - self._block_size_bytes = 1 - - ctx = self._backend._lib.EVP_CIPHER_CTX_new() - ctx = self._backend._ffi.gc( - ctx, self._backend._lib.EVP_CIPHER_CTX_free - ) - - registry = self._backend._cipher_registry - try: - adapter = registry[type(cipher), type(mode)] - except KeyError: - raise UnsupportedAlgorithm( - "cipher {} in {} mode is not supported " - "by this backend.".format( - cipher.name, mode.name if mode else mode - ), - _Reasons.UNSUPPORTED_CIPHER, - ) - - evp_cipher = adapter(self._backend, cipher, mode) - if evp_cipher == self._backend._ffi.NULL: - msg = f"cipher {cipher.name} " - if mode is not None: - msg += f"in {mode.name} mode " - msg += ( - "is not supported by this backend (Your version of OpenSSL " - "may be too old. Current version: {}.)" - ).format(self._backend.openssl_version_text()) - raise UnsupportedAlgorithm(msg, _Reasons.UNSUPPORTED_CIPHER) - - if isinstance(mode, modes.ModeWithInitializationVector): - iv_nonce = self._backend._ffi.from_buffer( - mode.initialization_vector - ) - elif isinstance(mode, modes.ModeWithTweak): - iv_nonce = self._backend._ffi.from_buffer(mode.tweak) - elif isinstance(mode, modes.ModeWithNonce): - iv_nonce = self._backend._ffi.from_buffer(mode.nonce) - elif isinstance(cipher, algorithms.ChaCha20): - iv_nonce = self._backend._ffi.from_buffer(cipher.nonce) - else: - iv_nonce = self._backend._ffi.NULL - # begin init with cipher and operation type - res = self._backend._lib.EVP_CipherInit_ex( - ctx, - evp_cipher, - self._backend._ffi.NULL, - self._backend._ffi.NULL, - self._backend._ffi.NULL, - operation, - ) - self._backend.openssl_assert(res != 0) - # set the key length to handle variable key ciphers - res = self._backend._lib.EVP_CIPHER_CTX_set_key_length( - ctx, len(cipher.key) - ) - self._backend.openssl_assert(res != 0) - if isinstance(mode, modes.GCM): - res = self._backend._lib.EVP_CIPHER_CTX_ctrl( - ctx, - self._backend._lib.EVP_CTRL_AEAD_SET_IVLEN, - len(iv_nonce), - self._backend._ffi.NULL, - ) - self._backend.openssl_assert(res != 0) - if mode.tag is not None: - res = self._backend._lib.EVP_CIPHER_CTX_ctrl( - ctx, - self._backend._lib.EVP_CTRL_AEAD_SET_TAG, - len(mode.tag), - mode.tag, - ) - self._backend.openssl_assert(res != 0) - self._tag = mode.tag - - # pass key/iv - res = self._backend._lib.EVP_CipherInit_ex( - ctx, - self._backend._ffi.NULL, - self._backend._ffi.NULL, - self._backend._ffi.from_buffer(cipher.key), - iv_nonce, - operation, - ) - - # Check for XTS mode duplicate keys error - errors = self._backend._consume_errors() - lib = self._backend._lib - if res == 0 and ( - ( - not lib.CRYPTOGRAPHY_IS_LIBRESSL - and errors[0]._lib_reason_match( - lib.ERR_LIB_EVP, lib.EVP_R_XTS_DUPLICATED_KEYS - ) - ) - or ( - lib.Cryptography_HAS_PROVIDERS - and errors[0]._lib_reason_match( - lib.ERR_LIB_PROV, lib.PROV_R_XTS_DUPLICATED_KEYS - ) - ) - ): - raise ValueError("In XTS mode duplicated keys are not allowed") - - self._backend.openssl_assert(res != 0, errors=errors) - - # We purposely disable padding here as it's handled higher up in the - # API. - self._backend._lib.EVP_CIPHER_CTX_set_padding(ctx, 0) - self._ctx = ctx - - def update(self, data: bytes) -> bytes: - buf = bytearray(len(data) + self._block_size_bytes - 1) - n = self.update_into(data, buf) - return bytes(buf[:n]) - - def update_into(self, data: bytes, buf: bytes) -> int: - total_data_len = len(data) - if len(buf) < (total_data_len + self._block_size_bytes - 1): - raise ValueError( - "buffer must be at least {} bytes for this payload".format( - len(data) + self._block_size_bytes - 1 - ) - ) - - data_processed = 0 - total_out = 0 - outlen = self._backend._ffi.new("int *") - baseoutbuf = self._backend._ffi.from_buffer(buf, require_writable=True) - baseinbuf = self._backend._ffi.from_buffer(data) - - while data_processed != total_data_len: - outbuf = baseoutbuf + total_out - inbuf = baseinbuf + data_processed - inlen = min(self._MAX_CHUNK_SIZE, total_data_len - data_processed) - - res = self._backend._lib.EVP_CipherUpdate( - self._ctx, outbuf, outlen, inbuf, inlen - ) - if res == 0 and isinstance(self._mode, modes.XTS): - self._backend._consume_errors() - raise ValueError( - "In XTS mode you must supply at least a full block in the " - "first update call. For AES this is 16 bytes." - ) - else: - self._backend.openssl_assert(res != 0) - data_processed += inlen - total_out += outlen[0] - - return total_out - - def finalize(self) -> bytes: - if ( - self._operation == self._DECRYPT - and isinstance(self._mode, modes.ModeWithAuthenticationTag) - and self.tag is None - ): - raise ValueError( - "Authentication tag must be provided when decrypting." - ) - - buf = self._backend._ffi.new("unsigned char[]", self._block_size_bytes) - outlen = self._backend._ffi.new("int *") - res = self._backend._lib.EVP_CipherFinal_ex(self._ctx, buf, outlen) - if res == 0: - errors = self._backend._consume_errors() - - if not errors and isinstance(self._mode, modes.GCM): - raise InvalidTag - - lib = self._backend._lib - self._backend.openssl_assert( - errors[0]._lib_reason_match( - lib.ERR_LIB_EVP, - lib.EVP_R_DATA_NOT_MULTIPLE_OF_BLOCK_LENGTH, - ) - or ( - lib.Cryptography_HAS_PROVIDERS - and errors[0]._lib_reason_match( - lib.ERR_LIB_PROV, - lib.PROV_R_WRONG_FINAL_BLOCK_LENGTH, - ) - ) - or ( - lib.CRYPTOGRAPHY_IS_BORINGSSL - and errors[0].reason - == lib.CIPHER_R_DATA_NOT_MULTIPLE_OF_BLOCK_LENGTH - ), - errors=errors, - ) - raise ValueError( - "The length of the provided data is not a multiple of " - "the block length." - ) - - if ( - isinstance(self._mode, modes.GCM) - and self._operation == self._ENCRYPT - ): - tag_buf = self._backend._ffi.new( - "unsigned char[]", self._block_size_bytes - ) - res = self._backend._lib.EVP_CIPHER_CTX_ctrl( - self._ctx, - self._backend._lib.EVP_CTRL_AEAD_GET_TAG, - self._block_size_bytes, - tag_buf, - ) - self._backend.openssl_assert(res != 0) - self._tag = self._backend._ffi.buffer(tag_buf)[:] - - res = self._backend._lib.EVP_CIPHER_CTX_reset(self._ctx) - self._backend.openssl_assert(res == 1) - return self._backend._ffi.buffer(buf)[: outlen[0]] - - def finalize_with_tag(self, tag: bytes) -> bytes: - tag_len = len(tag) - if tag_len < self._mode._min_tag_length: - raise ValueError( - "Authentication tag must be {} bytes or longer.".format( - self._mode._min_tag_length - ) - ) - elif tag_len > self._block_size_bytes: - raise ValueError( - "Authentication tag cannot be more than {} bytes.".format( - self._block_size_bytes - ) - ) - res = self._backend._lib.EVP_CIPHER_CTX_ctrl( - self._ctx, self._backend._lib.EVP_CTRL_AEAD_SET_TAG, len(tag), tag - ) - self._backend.openssl_assert(res != 0) - self._tag = tag - return self.finalize() - - def authenticate_additional_data(self, data: bytes) -> None: - outlen = self._backend._ffi.new("int *") - res = self._backend._lib.EVP_CipherUpdate( - self._ctx, - self._backend._ffi.NULL, - outlen, - self._backend._ffi.from_buffer(data), - len(data), - ) - self._backend.openssl_assert(res != 0) - - @property - def tag(self) -> bytes | None: - return self._tag diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi index 9cdb4d6a5c6e4..5180d3422979c 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi @@ -6,6 +6,7 @@ import typing from cryptography.hazmat.bindings._rust.openssl import ( aead, + ciphers, cmac, dh, dsa, @@ -26,6 +27,7 @@ __all__ = [ "openssl_version", "raise_openssl_error", "aead", + "ciphers", "cmac", "dh", "dsa", diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/ciphers.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/ciphers.pyi new file mode 100644 index 0000000000000..759f3b591cba6 --- /dev/null +++ b/src/cryptography/hazmat/bindings/_rust/openssl/ciphers.pyi @@ -0,0 +1,38 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import typing + +from cryptography.hazmat.primitives import ciphers +from cryptography.hazmat.primitives.ciphers import modes + +@typing.overload +def create_encryption_ctx( + algorithm: ciphers.CipherAlgorithm, mode: modes.ModeWithAuthenticationTag +) -> ciphers.AEADEncryptionContext: ... +@typing.overload +def create_encryption_ctx( + algorithm: ciphers.CipherAlgorithm, mode: modes.Mode +) -> ciphers.CipherContext: ... +@typing.overload +def create_decryption_ctx( + algorithm: ciphers.CipherAlgorithm, mode: modes.ModeWithAuthenticationTag +) -> ciphers.AEADDecryptionContext: ... +@typing.overload +def create_decryption_ctx( + algorithm: ciphers.CipherAlgorithm, mode: modes.Mode +) -> ciphers.CipherContext: ... +def cipher_supported( + algorithm: ciphers.CipherAlgorithm, mode: modes.Mode +) -> bool: ... +def _advance( + ctx: ciphers.AEADEncryptionContext | ciphers.AEADDecryptionContext, n: int +) -> None: ... +def _advance_aad( + ctx: ciphers.AEADEncryptionContext | ciphers.AEADDecryptionContext, n: int +) -> None: ... + +class CipherContext: ... +class AEADEncryptionContext: ... +class AEADDecryptionContext: ... diff --git a/src/cryptography/hazmat/bindings/openssl/binding.py b/src/cryptography/hazmat/bindings/openssl/binding.py index 40814f2a58a08..93c3acc833d18 100644 --- a/src/cryptography/hazmat/bindings/openssl/binding.py +++ b/src/cryptography/hazmat/bindings/openssl/binding.py @@ -17,13 +17,9 @@ from cryptography.hazmat.bindings.openssl._conditional import CONDITIONAL_NAMES -def _openssl_assert( - ok: bool, - errors: list[openssl.OpenSSLError] | None = None, -) -> None: +def _openssl_assert(ok: bool) -> None: if not ok: - if errors is None: - errors = openssl.capture_error_stack() + errors = openssl.capture_error_stack() raise InternalError( "Unknown OpenSSL error. This error is commonly encountered when " diff --git a/src/cryptography/hazmat/primitives/ciphers/base.py b/src/cryptography/hazmat/primitives/ciphers/base.py index 2082df669a23b..7c32cbec693e4 100644 --- a/src/cryptography/hazmat/primitives/ciphers/base.py +++ b/src/cryptography/hazmat/primitives/ciphers/base.py @@ -7,19 +7,10 @@ import abc import typing -from cryptography.exceptions import ( - AlreadyFinalized, - AlreadyUpdated, - NotYetFinalized, -) +from cryptography.hazmat.bindings._rust import openssl as rust_openssl from cryptography.hazmat.primitives._cipheralgorithm import CipherAlgorithm from cryptography.hazmat.primitives.ciphers import modes -if typing.TYPE_CHECKING: - from cryptography.hazmat.backends.openssl.ciphers import ( - _CipherContext as _BackendCipherContext, - ) - class CipherContext(metaclass=abc.ABCMeta): @abc.abstractmethod @@ -112,12 +103,10 @@ def encryptor(self): raise ValueError( "Authentication tag must be None when encrypting." ) - from cryptography.hazmat.backends.openssl.backend import backend - ctx = backend.create_symmetric_encryption_ctx( + return rust_openssl.ciphers.create_encryption_ctx( self.algorithm, self.mode ) - return self._wrap_ctx(ctx, encrypt=True) @typing.overload def decryptor( @@ -132,23 +121,9 @@ def decryptor( ... def decryptor(self): - from cryptography.hazmat.backends.openssl.backend import backend - - ctx = backend.create_symmetric_decryption_ctx( + return rust_openssl.ciphers.create_decryption_ctx( self.algorithm, self.mode ) - return self._wrap_ctx(ctx, encrypt=False) - - def _wrap_ctx( - self, ctx: _BackendCipherContext, encrypt: bool - ) -> AEADEncryptionContext | AEADDecryptionContext | CipherContext: - if isinstance(self.mode, modes.ModeWithAuthenticationTag): - if encrypt: - return _AEADEncryptionContext(ctx) - else: - return _AEADDecryptionContext(ctx) - else: - return _CipherContext(ctx) _CIPHER_TYPE = Cipher[ @@ -161,112 +136,6 @@ def _wrap_ctx( ] ] - -class _CipherContext(CipherContext): - _ctx: _BackendCipherContext | None - - def __init__(self, ctx: _BackendCipherContext) -> None: - self._ctx = ctx - - def update(self, data: bytes) -> bytes: - if self._ctx is None: - raise AlreadyFinalized("Context was already finalized.") - return self._ctx.update(data) - - def update_into(self, data: bytes, buf: bytes) -> int: - if self._ctx is None: - raise AlreadyFinalized("Context was already finalized.") - return self._ctx.update_into(data, buf) - - def finalize(self) -> bytes: - if self._ctx is None: - raise AlreadyFinalized("Context was already finalized.") - data = self._ctx.finalize() - self._ctx = None - return data - - -class _AEADCipherContext(AEADCipherContext): - _ctx: _BackendCipherContext | None - _tag: bytes | None - - def __init__(self, ctx: _BackendCipherContext) -> None: - self._ctx = ctx - self._bytes_processed = 0 - self._aad_bytes_processed = 0 - self._tag = None - self._updated = False - - def _check_limit(self, data_size: int) -> None: - if self._ctx is None: - raise AlreadyFinalized("Context was already finalized.") - self._updated = True - self._bytes_processed += data_size - if self._bytes_processed > self._ctx._mode._MAX_ENCRYPTED_BYTES: - raise ValueError( - "{} has a maximum encrypted byte limit of {}".format( - self._ctx._mode.name, self._ctx._mode._MAX_ENCRYPTED_BYTES - ) - ) - - def update(self, data: bytes) -> bytes: - self._check_limit(len(data)) - # mypy needs this assert even though _check_limit already checked - assert self._ctx is not None - return self._ctx.update(data) - - def update_into(self, data: bytes, buf: bytes) -> int: - self._check_limit(len(data)) - # mypy needs this assert even though _check_limit already checked - assert self._ctx is not None - return self._ctx.update_into(data, buf) - - def finalize(self) -> bytes: - if self._ctx is None: - raise AlreadyFinalized("Context was already finalized.") - data = self._ctx.finalize() - self._tag = self._ctx.tag - self._ctx = None - return data - - def authenticate_additional_data(self, data: bytes) -> None: - if self._ctx is None: - raise AlreadyFinalized("Context was already finalized.") - if self._updated: - raise AlreadyUpdated("Update has been called on this context.") - - self._aad_bytes_processed += len(data) - if self._aad_bytes_processed > self._ctx._mode._MAX_AAD_BYTES: - raise ValueError( - "{} has a maximum AAD byte limit of {}".format( - self._ctx._mode.name, self._ctx._mode._MAX_AAD_BYTES - ) - ) - - self._ctx.authenticate_additional_data(data) - - -class _AEADDecryptionContext(_AEADCipherContext, AEADDecryptionContext): - def finalize_with_tag(self, tag: bytes) -> bytes: - if self._ctx is None: - raise AlreadyFinalized("Context was already finalized.") - if self._ctx._tag is not None: - raise ValueError( - "tag provided both in mode and in call with finalize_with_tag:" - " tag should only be provided once" - ) - data = self._ctx.finalize_with_tag(tag) - self._tag = self._ctx.tag - self._ctx = None - return data - - -class _AEADEncryptionContext(_AEADCipherContext, AEADEncryptionContext): - @property - def tag(self) -> bytes: - if self._ctx is not None: - raise NotYetFinalized( - "You must finalize encryption before " "getting the tag." - ) - assert self._tag is not None - return self._tag +CipherContext.register(rust_openssl.ciphers.CipherContext) +AEADEncryptionContext.register(rust_openssl.ciphers.AEADEncryptionContext) +AEADDecryptionContext.register(rust_openssl.ciphers.AEADDecryptionContext) diff --git a/src/rust/src/backend/cipher_registry.rs b/src/rust/src/backend/cipher_registry.rs index 128f087ff4986..23e805fde2df8 100644 --- a/src/rust/src/backend/cipher_registry.rs +++ b/src/rust/src/backend/cipher_registry.rs @@ -56,6 +56,7 @@ impl std::hash::Hash for RegistryKey { enum RegistryCipher { Ref(&'static openssl::cipher::CipherRef), + Owned(Cipher), } impl From<&'static openssl::cipher::CipherRef> for RegistryCipher { @@ -64,6 +65,12 @@ impl From<&'static openssl::cipher::CipherRef> for RegistryCipher { } } +impl From for RegistryCipher { + fn from(c: Cipher) -> RegistryCipher { + RegistryCipher::Owned(c) + } +} + struct RegistryBuilder<'p> { py: pyo3::Python<'p>, m: HashMap, @@ -122,49 +129,182 @@ fn get_cipher_registry( let sm4 = types::SM4.get(py)?; #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_SEED"))] let seed = types::SEED.get(py)?; + let arc4 = types::ARC4.get(py)?; + let chacha20 = types::CHACHA20.get(py)?; + let rc2 = types::RC2.get(py)?; let cbc = types::CBC.get(py)?; + let cfb = types::CFB.get(py)?; + let cfb8 = types::CFB8.get(py)?; + let ofb = types::OFB.get(py)?; + let ecb = types::ECB.get(py)?; + let ctr = types::CTR.get(py)?; + let gcm = types::GCM.get(py)?; + let xts = types::XTS.get(py)?; + + let none = py.None(); + let none_type = none.as_ref(py).get_type(); m.add(aes, cbc, Some(128), Cipher::aes_128_cbc())?; m.add(aes, cbc, Some(192), Cipher::aes_192_cbc())?; m.add(aes, cbc, Some(256), Cipher::aes_256_cbc())?; + m.add(aes, ofb, Some(128), Cipher::aes_128_ofb())?; + m.add(aes, ofb, Some(192), Cipher::aes_192_ofb())?; + m.add(aes, ofb, Some(256), Cipher::aes_256_ofb())?; + + m.add(aes, gcm, Some(128), Cipher::aes_128_gcm())?; + m.add(aes, gcm, Some(192), Cipher::aes_192_gcm())?; + m.add(aes, gcm, Some(256), Cipher::aes_256_gcm())?; + + m.add(aes, ctr, Some(128), Cipher::aes_128_ctr())?; + m.add(aes, ctr, Some(192), Cipher::aes_192_ctr())?; + m.add(aes, ctr, Some(256), Cipher::aes_256_ctr())?; + + #[cfg(not(CRYPTOGRAPHY_IS_BORINGSSL))] + { + m.add(aes, cfb8, Some(128), Cipher::aes_128_cfb8())?; + m.add(aes, cfb8, Some(192), Cipher::aes_192_cfb8())?; + m.add(aes, cfb8, Some(256), Cipher::aes_256_cfb8())?; + + m.add(aes, cfb, Some(128), Cipher::aes_128_cfb128())?; + m.add(aes, cfb, Some(192), Cipher::aes_192_cfb128())?; + m.add(aes, cfb, Some(256), Cipher::aes_256_cfb128())?; + } + + m.add(aes, ecb, Some(128), Cipher::aes_128_ecb())?; + m.add(aes, ecb, Some(192), Cipher::aes_192_ecb())?; + m.add(aes, ecb, Some(256), Cipher::aes_256_ecb())?; + + #[cfg(not(CRYPTOGRAPHY_IS_BORINGSSL))] + { + m.add(aes, xts, Some(256), Cipher::aes_128_xts())?; + m.add(aes, xts, Some(512), Cipher::aes_256_xts())?; + } + m.add(aes128, cbc, Some(128), Cipher::aes_128_cbc())?; m.add(aes256, cbc, Some(256), Cipher::aes_256_cbc())?; - m.add(triple_des, cbc, Some(192), Cipher::des_ede3_cbc())?; + m.add(aes128, ofb, Some(128), Cipher::aes_128_ofb())?; + m.add(aes256, ofb, Some(256), Cipher::aes_256_ofb())?; - #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_CAMELLIA"))] - m.add(camellia, cbc, Some(128), Cipher::camellia128_cbc())?; - #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_CAMELLIA"))] - m.add(camellia, cbc, Some(192), Cipher::camellia192_cbc())?; - #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_CAMELLIA"))] - m.add(camellia, cbc, Some(256), Cipher::camellia256_cbc())?; + m.add(aes128, gcm, Some(128), Cipher::aes_128_gcm())?; + m.add(aes256, gcm, Some(256), Cipher::aes_256_gcm())?; - #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_SM4"))] - m.add(sm4, cbc, Some(128), Cipher::sm4_cbc())?; + m.add(aes128, ctr, Some(128), Cipher::aes_128_ctr())?; + m.add(aes256, ctr, Some(256), Cipher::aes_256_ctr())?; - #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_SEED"))] - m.add(seed, cbc, Some(128), Cipher::seed_cbc())?; + #[cfg(not(CRYPTOGRAPHY_IS_BORINGSSL))] + { + m.add(aes128, cfb8, Some(128), Cipher::aes_128_cfb8())?; + m.add(aes256, cfb8, Some(256), Cipher::aes_256_cfb8())?; - #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_BF"))] - m.add(blowfish, cbc, None, Cipher::bf_cbc())?; + m.add(aes128, cfb, Some(128), Cipher::aes_128_cfb128())?; + m.add(aes256, cfb, Some(256), Cipher::aes_256_cfb128())?; + } - #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_CAST"))] - m.add(cast5, cbc, None, Cipher::cast5_cbc())?; + m.add(aes128, ecb, Some(128), Cipher::aes_128_ecb())?; + m.add(aes256, ecb, Some(256), Cipher::aes_256_ecb())?; - #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_IDEA"))] - m.add(idea, cbc, Some(128), Cipher::idea_cbc())?; + m.add(triple_des, cbc, Some(192), Cipher::des_ede3_cbc())?; + m.add(triple_des, ecb, Some(192), Cipher::des_ede3_ecb())?; + #[cfg(not(CRYPTOGRAPHY_IS_BORINGSSL))] + { + m.add(triple_des, cfb8, Some(192), Cipher::des_ede3_cfb8())?; + m.add(triple_des, cfb, Some(192), Cipher::des_ede3_cfb64())?; + m.add(triple_des, ofb, Some(192), Cipher::des_ede3_ofb())?; + } + + #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_CAMELLIA"))] + { + m.add(camellia, cbc, Some(128), Cipher::camellia128_cbc())?; + m.add(camellia, cbc, Some(192), Cipher::camellia192_cbc())?; + m.add(camellia, cbc, Some(256), Cipher::camellia256_cbc())?; + + m.add(camellia, ecb, Some(128), Cipher::camellia128_ecb())?; + m.add(camellia, ecb, Some(192), Cipher::camellia192_ecb())?; + m.add(camellia, ecb, Some(256), Cipher::camellia256_ecb())?; + + m.add(camellia, ofb, Some(128), Cipher::camellia128_ofb())?; + m.add(camellia, ofb, Some(192), Cipher::camellia192_ofb())?; + m.add(camellia, ofb, Some(256), Cipher::camellia256_ofb())?; + + m.add(camellia, cfb, Some(128), Cipher::camellia128_cfb128())?; + m.add(camellia, cfb, Some(192), Cipher::camellia192_cfb128())?; + m.add(camellia, cfb, Some(256), Cipher::camellia256_cfb128())?; + } + + #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_SM4"))] + { + m.add(sm4, cbc, Some(128), Cipher::sm4_cbc())?; + m.add(sm4, ctr, Some(128), Cipher::sm4_ctr())?; + m.add(sm4, cfb, Some(128), Cipher::sm4_cfb128())?; + m.add(sm4, ofb, Some(128), Cipher::sm4_ofb())?; + m.add(sm4, ecb, Some(128), Cipher::sm4_ecb())?; + + #[cfg(CRYPTOGRAPHY_OPENSSL_300_OR_GREATER)] + if let Ok(c) = Cipher::fetch(None, "sm4-gcm", None) { + m.add(sm4, gcm, Some(128), c)?; + } + } + + #[cfg(not(CRYPTOGRAPHY_IS_BORINGSSL))] + m.add(chacha20, none_type, None, Cipher::chacha20())?; + + // Don't register legacy ciphers if they're unavailable. In theory + // this should't be necessary but OpenSSL 3 will return an EVP_CIPHER + // even when the cipher is unavailable. + if cfg!(not(CRYPTOGRAPHY_OPENSSL_300_OR_GREATER)) + || types::LEGACY_PROVIDER_LOADED.get(py)?.is_true()? + { + #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_BF"))] + { + m.add(blowfish, cbc, None, Cipher::bf_cbc())?; + m.add(blowfish, cfb, None, Cipher::bf_cfb64())?; + m.add(blowfish, ofb, None, Cipher::bf_ofb())?; + m.add(blowfish, ecb, None, Cipher::bf_ecb())?; + } + #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_SEED"))] + { + m.add(seed, cbc, Some(128), Cipher::seed_cbc())?; + m.add(seed, cfb, Some(128), Cipher::seed_cfb128())?; + m.add(seed, ofb, Some(128), Cipher::seed_ofb())?; + m.add(seed, ecb, Some(128), Cipher::seed_ecb())?; + } + + #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_CAST"))] + { + m.add(cast5, cbc, None, Cipher::cast5_cbc())?; + m.add(cast5, ecb, None, Cipher::cast5_ecb())?; + m.add(cast5, ofb, None, Cipher::cast5_ofb())?; + m.add(cast5, cfb, None, Cipher::cast5_cfb64())?; + } + + #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_IDEA"))] + { + m.add(idea, cbc, Some(128), Cipher::idea_cbc())?; + m.add(idea, ecb, Some(128), Cipher::idea_ecb())?; + m.add(idea, ofb, Some(128), Cipher::idea_ofb())?; + m.add(idea, cfb, Some(128), Cipher::idea_cfb64())?; + } + + m.add(arc4, none_type, None, Cipher::rc4())?; + + // We don't actually support RC2, this is just used by some tests. + if let Some(rc2_cbc) = Cipher::from_nid(openssl::nid::Nid::RC2_CBC) { + m.add(rc2, cbc, None, rc2_cbc)?; + } + } Ok(m.build()) }) } -pub(crate) fn get_cipher<'a>( - py: pyo3::Python<'_>, +pub(crate) fn get_cipher<'py>( + py: pyo3::Python<'py>, algorithm: &pyo3::PyAny, mode_cls: &pyo3::PyAny, -) -> CryptographyResult> { +) -> CryptographyResult> { let registry = get_cipher_registry(py)?; let key_size = algorithm @@ -174,6 +314,7 @@ pub(crate) fn get_cipher<'a>( match registry.get(&key) { Some(RegistryCipher::Ref(c)) => Ok(Some(c)), + Some(RegistryCipher::Owned(c)) => Ok(Some(c)), None => Ok(None), } } diff --git a/src/rust/src/backend/ciphers.rs b/src/rust/src/backend/ciphers.rs new file mode 100644 index 0000000000000..3695ca1d89df9 --- /dev/null +++ b/src/rust/src/backend/ciphers.rs @@ -0,0 +1,567 @@ +// This file is dual licensed under the terms of the Apache License, Version +// 2.0, and the BSD License. See the LICENSE file in the root of this repository +// for complete details. + +use crate::backend::cipher_registry; +use crate::buf::{CffiBuf, CffiMutBuf}; +use crate::error::{CryptographyError, CryptographyResult}; +use crate::exceptions; +use crate::types; +use pyo3::IntoPy; + +struct CipherContext { + ctx: openssl::cipher_ctx::CipherCtx, + py_mode: pyo3::PyObject, +} + +impl CipherContext { + fn new( + py: pyo3::Python<'_>, + algorithm: &pyo3::PyAny, + mode: &pyo3::PyAny, + side: openssl::symm::Mode, + ) -> CryptographyResult { + let cipher = match cipher_registry::get_cipher(py, algorithm, mode.get_type())? { + Some(c) => c, + None => { + return Err(CryptographyError::from( + exceptions::UnsupportedAlgorithm::new_err(( + format!( + "cipher {} in {} mode is not supported ", + algorithm.getattr(pyo3::intern!(py, "name"))?, + if mode.is_true()? { + mode.getattr(pyo3::intern!(py, "name"))? + } else { + mode + } + ), + exceptions::Reasons::UNSUPPORTED_CIPHER, + )), + )) + } + }; + + let iv_nonce = if mode.is_instance(types::MODE_WITH_INITIALIZATION_VECTOR.get(py)?)? { + Some( + mode.getattr(pyo3::intern!(py, "initialization_vector"))? + .extract::>()?, + ) + } else if mode.is_instance(types::MODE_WITH_TWEAK.get(py)?)? { + Some( + mode.getattr(pyo3::intern!(py, "tweak"))? + .extract::>()?, + ) + } else if mode.is_instance(types::MODE_WITH_NONCE.get(py)?)? { + Some( + mode.getattr(pyo3::intern!(py, "nonce"))? + .extract::>()?, + ) + } else if algorithm.is_instance(types::CHACHA20.get(py)?)? { + Some( + algorithm + .getattr(pyo3::intern!(py, "nonce"))? + .extract::>()?, + ) + } else { + None + }; + + let key = algorithm + .getattr(pyo3::intern!(py, "key"))? + .extract::>()?; + + let init_op = match side { + openssl::symm::Mode::Encrypt => openssl::cipher_ctx::CipherCtxRef::encrypt_init, + openssl::symm::Mode::Decrypt => openssl::cipher_ctx::CipherCtxRef::decrypt_init, + }; + + let mut ctx = openssl::cipher_ctx::CipherCtx::new()?; + init_op(&mut ctx, Some(cipher), None, None)?; + ctx.set_key_length(key.as_bytes().len())?; + + if let Some(iv) = iv_nonce.as_ref() { + if cipher.iv_length() != 0 && cipher.iv_length() != iv.as_bytes().len() { + ctx.set_iv_length(iv.as_bytes().len())?; + } + } + + if mode.is_instance(types::XTS.get(py)?)? { + init_op( + &mut ctx, + None, + Some(key.as_bytes()), + iv_nonce.as_ref().map(|b| b.as_bytes()), + ) + .map_err(|_| { + pyo3::exceptions::PyValueError::new_err( + "In XTS mode duplicated keys are not allowed", + ) + })?; + } else { + init_op( + &mut ctx, + None, + Some(key.as_bytes()), + iv_nonce.as_ref().map(|b| b.as_bytes()), + )?; + }; + + ctx.set_padding(false); + + Ok(CipherContext { + ctx, + py_mode: mode.into(), + }) + } + + fn update<'p>( + &mut self, + py: pyo3::Python<'p>, + buf: &[u8], + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let mut out_buf = vec![0; buf.len() + self.ctx.block_size()]; + let n = self.update_into(py, buf, &mut out_buf)?; + Ok(pyo3::types::PyBytes::new(py, &out_buf[..n])) + } + + fn update_into( + &mut self, + py: pyo3::Python<'_>, + buf: &[u8], + out_buf: &mut [u8], + ) -> CryptographyResult { + if out_buf.len() < (buf.len() + self.ctx.block_size() - 1) { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err(format!( + "buffer must be at least {} bytes for this payload", + buf.len() + self.ctx.block_size() - 1 + )), + )); + } + + let mut total_written = 0; + for chunk in buf.chunks(1 << 29) { + // SAFETY: We ensure that outbuf is sufficiently large above. + unsafe { + let n = if self.py_mode.as_ref(py).is_instance(types::XTS.get(py)?)? { + self.ctx.cipher_update_unchecked(chunk, Some(&mut out_buf[total_written..])).map_err(|_| { + pyo3::exceptions::PyValueError::new_err( + "In XTS mode you must supply at least a full block in the first update call. For AES this is 16 bytes." + ) + })? + } else { + self.ctx + .cipher_update_unchecked(chunk, Some(&mut out_buf[total_written..]))? + }; + total_written += n; + } + } + + Ok(total_written) + } + + fn authenticate_additional_data(&mut self, buf: &[u8]) -> CryptographyResult<()> { + self.ctx.cipher_update(buf, None)?; + Ok(()) + } + + fn finalize<'p>( + &mut self, + py: pyo3::Python<'p>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let mut out_buf = vec![0; self.ctx.block_size()]; + let n = self.ctx.cipher_final(&mut out_buf).or_else(|e| { + if e.errors().is_empty() + && self + .py_mode + .as_ref(py) + .is_instance(types::MODE_WITH_AUTHENTICATION_TAG.get(py)?)? + { + return Err(CryptographyError::from(exceptions::InvalidTag::new_err(()))); + } + Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "The length of the provided data is not a multiple of the block length.", + ), + )) + })?; + Ok(pyo3::types::PyBytes::new(py, &out_buf[..n])) + } +} + +#[pyo3::prelude::pyclass( + module = "cryptography.hazmat.bindings._rust.openssl.ciphers", + name = "CipherContext" +)] +struct PyCipherContext { + ctx: Option, +} + +#[pyo3::prelude::pyclass( + module = "cryptography.hazmat.bindings._rust.openssl.ciphers", + name = "AEADEncryptionContext" +)] +struct PyAEADEncryptionContext { + ctx: Option, + tag: Option>, + updated: bool, + bytes_remaining: u64, + aad_bytes_remaining: u64, +} + +#[pyo3::prelude::pyclass( + module = "cryptography.hazmat.bindings._rust.openssl.ciphers", + name = "AEADDecryptionContext" +)] +struct PyAEADDecryptionContext { + ctx: Option, + updated: bool, + bytes_remaining: u64, + aad_bytes_remaining: u64, +} + +fn get_mut_ctx(ctx: Option<&mut CipherContext>) -> pyo3::PyResult<&mut CipherContext> { + ctx.ok_or_else(|| exceptions::AlreadyFinalized::new_err("Context was already finalized.")) +} + +#[pyo3::prelude::pymethods] +impl PyCipherContext { + fn update<'p>( + &mut self, + py: pyo3::Python<'p>, + buf: CffiBuf<'_>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + get_mut_ctx(self.ctx.as_mut())?.update(py, buf.as_bytes()) + } + + fn update_into( + &mut self, + py: pyo3::Python<'_>, + buf: CffiBuf<'_>, + mut out_buf: CffiMutBuf<'_>, + ) -> CryptographyResult { + get_mut_ctx(self.ctx.as_mut())?.update_into(py, buf.as_bytes(), out_buf.as_mut_bytes()) + } + + fn finalize<'p>( + &mut self, + py: pyo3::Python<'p>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let result = get_mut_ctx(self.ctx.as_mut())?.finalize(py)?; + self.ctx = None; + Ok(result) + } +} + +#[pyo3::prelude::pymethods] +impl PyAEADEncryptionContext { + fn update<'p>( + &mut self, + py: pyo3::Python<'p>, + buf: CffiBuf<'_>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let data = buf.as_bytes(); + + self.updated = true; + self.bytes_remaining = self + .bytes_remaining + .checked_sub(data.len().try_into().unwrap()) + .ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err("Exceeded maximum encrypted byte limit") + })?; + get_mut_ctx(self.ctx.as_mut())?.update(py, data) + } + + fn update_into( + &mut self, + py: pyo3::Python<'_>, + buf: CffiBuf<'_>, + mut out_buf: CffiMutBuf<'_>, + ) -> CryptographyResult { + let data = buf.as_bytes(); + + self.updated = true; + self.bytes_remaining = self + .bytes_remaining + .checked_sub(data.len().try_into().unwrap()) + .ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err("Exceeded maximum encrypted byte limit") + })?; + get_mut_ctx(self.ctx.as_mut())?.update_into(py, data, out_buf.as_mut_bytes()) + } + + fn authenticate_additional_data(&mut self, buf: CffiBuf<'_>) -> CryptographyResult<()> { + let ctx = get_mut_ctx(self.ctx.as_mut())?; + if self.updated { + return Err(CryptographyError::from( + exceptions::AlreadyUpdated::new_err("Update has been called on this context."), + )); + } + + let data = buf.as_bytes(); + self.aad_bytes_remaining = self + .aad_bytes_remaining + .checked_sub(data.len().try_into().unwrap()) + .ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err("Exceeded maximum AAD byte limit") + })?; + ctx.authenticate_additional_data(data) + } + + fn finalize<'p>( + &mut self, + py: pyo3::Python<'p>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let ctx = get_mut_ctx(self.ctx.as_mut())?; + let result = ctx.finalize(py)?; + + // XXX: do not hard code 16 + let tag = pyo3::types::PyBytes::new_with(py, 16, |t| { + ctx.ctx.tag(t).map_err(CryptographyError::from)?; + Ok(()) + })?; + self.tag = Some(tag.into_py(py)); + self.ctx = None; + + Ok(result) + } + + #[getter] + fn tag(&self, py: pyo3::Python<'_>) -> CryptographyResult> { + Ok(self + .tag + .as_ref() + .ok_or_else(|| { + exceptions::NotYetFinalized::new_err( + "You must finalize encryption before getting the tag.", + ) + })? + .clone_ref(py)) + } +} + +#[pyo3::prelude::pymethods] +impl PyAEADDecryptionContext { + fn update<'p>( + &mut self, + py: pyo3::Python<'p>, + buf: CffiBuf<'_>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let data = buf.as_bytes(); + + self.updated = true; + self.bytes_remaining = self + .bytes_remaining + .checked_sub(data.len().try_into().unwrap()) + .ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err("Exceeded maximum encrypted byte limit") + })?; + get_mut_ctx(self.ctx.as_mut())?.update(py, data) + } + + fn update_into( + &mut self, + py: pyo3::Python<'_>, + buf: CffiBuf<'_>, + mut out_buf: CffiMutBuf<'_>, + ) -> CryptographyResult { + let data = buf.as_bytes(); + + self.updated = true; + self.bytes_remaining = self + .bytes_remaining + .checked_sub(data.len().try_into().unwrap()) + .ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err("Exceeded maximum encrypted byte limit") + })?; + get_mut_ctx(self.ctx.as_mut())?.update_into(py, data, out_buf.as_mut_bytes()) + } + + fn authenticate_additional_data(&mut self, buf: CffiBuf<'_>) -> CryptographyResult<()> { + let ctx = get_mut_ctx(self.ctx.as_mut())?; + if self.updated { + return Err(CryptographyError::from( + exceptions::AlreadyUpdated::new_err("Update has been called on this context."), + )); + } + + let data = buf.as_bytes(); + self.aad_bytes_remaining = self + .aad_bytes_remaining + .checked_sub(data.len().try_into().unwrap()) + .ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err("Exceeded maximum AAD byte limit") + })?; + ctx.authenticate_additional_data(data) + } + + fn finalize<'p>( + &mut self, + py: pyo3::Python<'p>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let ctx = get_mut_ctx(self.ctx.as_mut())?; + + if ctx + .py_mode + .as_ref(py) + .getattr(pyo3::intern!(py, "tag"))? + .is_none() + { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "Authentication tag must be provided when decrypting.", + ), + )); + } + + let result = ctx.finalize(py)?; + self.ctx = None; + Ok(result) + } + + fn finalize_with_tag<'p>( + &mut self, + py: pyo3::Python<'p>, + tag: &[u8], + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let ctx = get_mut_ctx(self.ctx.as_mut())?; + + if !ctx + .py_mode + .as_ref(py) + .getattr(pyo3::intern!(py, "tag"))? + .is_none() + { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "Authentication tag must be provided only once.", + ), + )); + } + + let min_tag_length = ctx + .py_mode + .as_ref(py) + .getattr(pyo3::intern!(py, "_min_tag_length"))? + .extract()?; + // XXX: Do not hard code 16 + if tag.len() < min_tag_length { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err(format!( + "Authentication tag must be {} bytes or longer.", + min_tag_length + )), + )); + } else if tag.len() > 16 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err(format!( + "Authentication tag cannot be more than {} bytes.", + 16 + )), + )); + } + + ctx.ctx.set_tag(tag)?; + let result = ctx.finalize(py)?; + self.ctx = None; + Ok(result) + } +} + +#[pyo3::prelude::pyfunction] +fn create_encryption_ctx( + py: pyo3::Python<'_>, + algorithm: &pyo3::PyAny, + mode: &pyo3::PyAny, +) -> CryptographyResult { + let ctx = CipherContext::new(py, algorithm, mode, openssl::symm::Mode::Encrypt)?; + + if mode.is_instance(types::MODE_WITH_AUTHENTICATION_TAG.get(py)?)? { + Ok(PyAEADEncryptionContext { + ctx: Some(ctx), + tag: None, + updated: false, + bytes_remaining: mode + .getattr(pyo3::intern!(py, "_MAX_ENCRYPTED_BYTES"))? + .extract()?, + aad_bytes_remaining: mode + .getattr(pyo3::intern!(py, "_MAX_AAD_BYTES"))? + .extract()?, + } + .into_py(py)) + } else { + Ok(PyCipherContext { ctx: Some(ctx) }.into_py(py)) + } +} + +#[pyo3::prelude::pyfunction] +fn create_decryption_ctx( + py: pyo3::Python<'_>, + algorithm: &pyo3::PyAny, + mode: &pyo3::PyAny, +) -> CryptographyResult { + let mut ctx = CipherContext::new(py, algorithm, mode, openssl::symm::Mode::Decrypt)?; + + if mode.is_instance(types::MODE_WITH_AUTHENTICATION_TAG.get(py)?)? { + if let Some(tag) = mode.getattr(pyo3::intern!(py, "tag"))?.extract()? { + ctx.ctx.set_tag(tag)?; + } + + Ok(PyAEADDecryptionContext { + ctx: Some(ctx), + updated: false, + bytes_remaining: mode + .getattr(pyo3::intern!(py, "_MAX_ENCRYPTED_BYTES"))? + .extract()?, + aad_bytes_remaining: mode + .getattr(pyo3::intern!(py, "_MAX_AAD_BYTES"))? + .extract()?, + } + .into_py(py)) + } else { + Ok(PyCipherContext { ctx: Some(ctx) }.into_py(py)) + } +} + +#[pyo3::prelude::pyfunction] +fn cipher_supported( + py: pyo3::Python<'_>, + algorithm: &pyo3::PyAny, + mode: &pyo3::PyAny, +) -> CryptographyResult { + Ok(cipher_registry::get_cipher(py, algorithm, mode.get_type())?.is_some()) +} + +#[pyo3::prelude::pyfunction] +fn _advance(ctx: &pyo3::PyAny, n: u64) { + if let Ok(c) = ctx.downcast::>() { + c.borrow_mut().bytes_remaining -= n; + } else if let Ok(c) = ctx.downcast::>() { + c.borrow_mut().bytes_remaining -= n; + } +} + +#[pyo3::prelude::pyfunction] +fn _advance_aad(ctx: &pyo3::PyAny, n: u64) { + if let Ok(c) = ctx.downcast::>() { + c.borrow_mut().aad_bytes_remaining -= n; + } else if let Ok(c) = ctx.downcast::>() { + c.borrow_mut().aad_bytes_remaining -= n; + } +} + +pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> { + let m = pyo3::prelude::PyModule::new(py, "ciphers")?; + m.add_function(pyo3::wrap_pyfunction!(create_encryption_ctx, m)?)?; + m.add_function(pyo3::wrap_pyfunction!(create_decryption_ctx, m)?)?; + m.add_function(pyo3::wrap_pyfunction!(cipher_supported, m)?)?; + + m.add_function(pyo3::wrap_pyfunction!(_advance, m)?)?; + m.add_function(pyo3::wrap_pyfunction!(_advance_aad, m)?)?; + + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + Ok(m) +} diff --git a/src/rust/src/backend/mod.rs b/src/rust/src/backend/mod.rs index 7e085d623b40f..be7b2d0ac2808 100644 --- a/src/rust/src/backend/mod.rs +++ b/src/rust/src/backend/mod.rs @@ -4,6 +4,7 @@ pub(crate) mod aead; pub(crate) mod cipher_registry; +pub(crate) mod ciphers; pub(crate) mod cmac; pub(crate) mod dh; pub(crate) mod dsa; @@ -24,6 +25,7 @@ pub(crate) mod x448; pub(crate) fn add_to_module(module: &pyo3::prelude::PyModule) -> pyo3::PyResult<()> { module.add_submodule(aead::create_module(module.py())?)?; + module.add_submodule(ciphers::create_module(module.py())?)?; module.add_submodule(cmac::create_module(module.py())?)?; module.add_submodule(dh::create_module(module.py())?)?; module.add_submodule(dsa::create_module(module.py())?)?; diff --git a/src/rust/src/buf.rs b/src/rust/src/buf.rs index edc3860c10500..028322dfe0da2 100644 --- a/src/rust/src/buf.rs +++ b/src/rust/src/buf.rs @@ -2,9 +2,9 @@ // 2.0, and the BSD License. See the LICENSE file in the root of this repository // for complete details. -use std::slice; - use crate::types; +use pyo3::types::IntoPyDict; +use std::slice; pub(crate) struct CffiBuf<'p> { _pyobj: &'p pyo3::PyAny, @@ -12,9 +12,19 @@ pub(crate) struct CffiBuf<'p> { buf: &'p [u8], } -fn _extract_buffer_length(pyobj: &pyo3::PyAny) -> pyo3::PyResult<(&pyo3::PyAny, usize)> { +fn _extract_buffer_length( + pyobj: &pyo3::PyAny, + mutable: bool, +) -> pyo3::PyResult<(&pyo3::PyAny, usize)> { let py = pyobj.py(); - let bufobj = types::FFI_FROM_BUFFER.get(py)?.call1((pyobj,))?; + let bufobj = if mutable { + let kwargs = [(pyo3::intern!(py, "require_writable"), true)].into_py_dict(py); + types::FFI_FROM_BUFFER + .get(py)? + .call((pyobj,), Some(kwargs))? + } else { + types::FFI_FROM_BUFFER.get(py)?.call1((pyobj,))? + }; let ptrval = types::FFI_CAST .get(py)? .call1((pyo3::intern!(py, "uintptr_t"), bufobj))? @@ -31,7 +41,7 @@ impl CffiBuf<'_> { impl<'a> pyo3::conversion::FromPyObject<'a> for CffiBuf<'a> { fn extract(pyobj: &'a pyo3::PyAny) -> pyo3::PyResult { - let (bufobj, ptrval) = _extract_buffer_length(pyobj)?; + let (bufobj, ptrval) = _extract_buffer_length(pyobj, false)?; let len = bufobj.len()?; let buf = if len == 0 { &[] @@ -54,3 +64,42 @@ impl<'a> pyo3::conversion::FromPyObject<'a> for CffiBuf<'a> { }) } } + +pub(crate) struct CffiMutBuf<'p> { + _pyobj: &'p pyo3::PyAny, + _bufobj: &'p pyo3::PyAny, + buf: &'p mut [u8], +} + +impl CffiMutBuf<'_> { + pub(crate) fn as_mut_bytes(&mut self) -> &mut [u8] { + self.buf + } +} + +impl<'a> pyo3::conversion::FromPyObject<'a> for CffiMutBuf<'a> { + fn extract(pyobj: &'a pyo3::PyAny) -> pyo3::PyResult { + let (bufobj, ptrval) = _extract_buffer_length(pyobj, true)?; + + let len = bufobj.len()?; + let buf = if len == 0 { + &mut [] + } else { + // SAFETY: _extract_buffer_length ensures that we have a valid ptr + // and length (and we ensure we meet slice's requirements for + // 0-length slices above), we're keeping pyobj alive which ensures + // the buffer is valid. But! There is no actually guarantee + // against concurrent mutation. See + // https://alexgaynor.net/2022/oct/23/buffers-on-the-edge/ + // for details. This is the same as our cffi status quo ante, so + // we're doing an unsound thing and living with it. + unsafe { slice::from_raw_parts_mut(ptrval as *mut u8, len) } + }; + + Ok(CffiMutBuf { + _pyobj: pyobj, + _bufobj: bufobj, + buf, + }) + } +} diff --git a/src/rust/src/exceptions.rs b/src/rust/src/exceptions.rs index c9456513993d5..67f57b9adcb5b 100644 --- a/src/rust/src/exceptions.rs +++ b/src/rust/src/exceptions.rs @@ -23,10 +23,12 @@ pub(crate) enum Reasons { UNSUPPORTED_MAC, } +pyo3::import_exception!(cryptography.exceptions, AlreadyUpdated); pyo3::import_exception!(cryptography.exceptions, AlreadyFinalized); pyo3::import_exception!(cryptography.exceptions, InternalError); pyo3::import_exception!(cryptography.exceptions, InvalidSignature); pyo3::import_exception!(cryptography.exceptions, InvalidTag); +pyo3::import_exception!(cryptography.exceptions, NotYetFinalized); pyo3::import_exception!(cryptography.exceptions, UnsupportedAlgorithm); pyo3::import_exception!(cryptography.x509, AttributeNotFound); pyo3::import_exception!(cryptography.x509, DuplicateExtension); diff --git a/src/rust/src/types.rs b/src/rust/src/types.rs index ddd5d8f452ff1..10b9e0cbaa052 100644 --- a/src/rust/src/types.rs +++ b/src/rust/src/types.rs @@ -473,6 +473,10 @@ pub static AES256: LazyPyImport = LazyPyImport::new( "cryptography.hazmat.primitives.ciphers.algorithms", &["AES256"], ); +pub static CHACHA20: LazyPyImport = LazyPyImport::new( + "cryptography.hazmat.primitives.ciphers.algorithms", + &["ChaCha20"], +); pub static SM4: LazyPyImport = LazyPyImport::new( "cryptography.hazmat.primitives.ciphers.algorithms", &["SM4"], @@ -495,9 +499,48 @@ pub static CAST5: LazyPyImport = LazyPyImport::new( #[cfg(not(CRYPTOGRAPHY_OSSLCONF = "OPENSSL_NO_IDEA"))] pub static IDEA: LazyPyImport = LazyPyImport::new("cryptography.hazmat.decrepit.ciphers.algorithms", &["IDEA"]); +pub static ARC4: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.decrepit.ciphers.algorithms", &["ARC4"]); +pub static RC2: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.backends.openssl.backend", &["_RC2"]); +pub static MODE_WITH_INITIALIZATION_VECTOR: LazyPyImport = LazyPyImport::new( + "cryptography.hazmat.primitives.ciphers.modes", + &["ModeWithInitializationVector"], +); +pub static MODE_WITH_TWEAK: LazyPyImport = LazyPyImport::new( + "cryptography.hazmat.primitives.ciphers.modes", + &["ModeWithTweak"], +); +pub static MODE_WITH_NONCE: LazyPyImport = LazyPyImport::new( + "cryptography.hazmat.primitives.ciphers.modes", + &["ModeWithNonce"], +); +pub static MODE_WITH_AUTHENTICATION_TAG: LazyPyImport = LazyPyImport::new( + "cryptography.hazmat.primitives.ciphers.modes", + &["ModeWithAuthenticationTag"], +); pub static CBC: LazyPyImport = LazyPyImport::new("cryptography.hazmat.primitives.ciphers.modes", &["CBC"]); +pub static CFB: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.primitives.ciphers.modes", &["CFB"]); +pub static CFB8: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.primitives.ciphers.modes", &["CFB8"]); +pub static OFB: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.primitives.ciphers.modes", &["OFB"]); +pub static ECB: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.primitives.ciphers.modes", &["ECB"]); +pub static CTR: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.primitives.ciphers.modes", &["CTR"]); +pub static GCM: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.primitives.ciphers.modes", &["GCM"]); +pub static XTS: LazyPyImport = + LazyPyImport::new("cryptography.hazmat.primitives.ciphers.modes", &["XTS"]); + +pub static LEGACY_PROVIDER_LOADED: LazyPyImport = LazyPyImport::new( + "cryptography.hazmat.bindings.openssl.binding", + &["Binding", "_legacy_provider_loaded"], +); #[cfg(test)] mod tests { diff --git a/src/rust/src/x509/common.rs b/src/rust/src/x509/common.rs index a941f50b928c0..d838c2f8dfe1a 100644 --- a/src/rust/src/x509/common.rs +++ b/src/rust/src/x509/common.rs @@ -216,7 +216,7 @@ fn parse_name_attribute( pyo3::types::PyString::new(py, parsed) } }; - let kwargs = [("_validate", false)].into_py_dict(py); + let kwargs = [(pyo3::intern!(py, "_validate"), false)].into_py_dict(py); Ok(types::NAME_ATTRIBUTE .get(py)? .call((oid, py_data, py_tag), Some(kwargs))? diff --git a/tests/hazmat/backends/test_openssl.py b/tests/hazmat/backends/test_openssl.py index e9cdcc432a504..6115e48f9cc34 100644 --- a/tests/hazmat/backends/test_openssl.py +++ b/tests/hazmat/backends/test_openssl.py @@ -14,9 +14,6 @@ from cryptography.hazmat.bindings._rust import openssl as rust_openssl from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding -from cryptography.hazmat.primitives.ciphers import Cipher -from cryptography.hazmat.primitives.ciphers.algorithms import AES -from cryptography.hazmat.primitives.ciphers.modes import CBC from ...doubles import ( DummyAsymmetricPadding, @@ -80,26 +77,6 @@ def test_supports_cipher(self): is False ) - def test_register_duplicate_cipher_adapter(self): - with pytest.raises(ValueError): - backend.register_cipher_adapter(AES, CBC, None) - - @pytest.mark.parametrize("mode", [DummyMode(), None]) - def test_nonexistent_cipher(self, mode, backend, monkeypatch): - # We can't use register_cipher_adapter because backend is a - # global singleton and we want to revert the change after the test - monkeypatch.setitem( - backend._cipher_registry, - (DummyCipherAlgorithm, type(mode)), - lambda backend, cipher, mode: backend._ffi.NULL, - ) - cipher = Cipher( - DummyCipherAlgorithm(), - mode, - ) - with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_CIPHER): - cipher.encryptor() - def test_openssl_assert(self): backend.openssl_assert(True) with pytest.raises(InternalError): @@ -128,14 +105,6 @@ def test_evp_ciphers_registered(self): cipher = backend._lib.EVP_get_cipherbyname(b"aes-256-cbc") assert cipher != backend._ffi.NULL - def test_unknown_error_in_cipher_finalize(self): - cipher = Cipher(AES(b"\0" * 16), CBC(b"\0" * 16), backend=backend) - enc = cipher.encryptor() - enc.update(b"\0") - backend._lib.ERR_put_error(0, 0, 1, b"test_openssl.py", -1) - with pytest.raises(InternalError): - enc.finalize() - class TestOpenSSLRSA: def test_rsa_padding_unsupported_pss_mgf1_hash(self): diff --git a/tests/hazmat/primitives/test_aes_gcm.py b/tests/hazmat/primitives/test_aes_gcm.py index d82e37470cae5..0543270413589 100644 --- a/tests/hazmat/primitives/test_aes_gcm.py +++ b/tests/hazmat/primitives/test_aes_gcm.py @@ -8,20 +8,13 @@ import pytest +from cryptography.hazmat.bindings._rust import openssl as rust_openssl from cryptography.hazmat.primitives.ciphers import algorithms, base, modes from ...utils import load_nist_vectors from .utils import generate_aead_test -def _advance(ctx, n): - ctx._bytes_processed += n - - -def _advance_aad(ctx, n): - ctx._aad_bytes_processed += n - - @pytest.mark.supported( only_if=lambda backend: backend.cipher_supported( algorithms.AES(b"\x00" * 16), modes.GCM(b"\x00" * 12) @@ -80,7 +73,9 @@ def test_gcm_ciphertext_limit(self, backend): backend=backend, ) encryptor = cipher.encryptor() - _advance(encryptor, modes.GCM._MAX_ENCRYPTED_BYTES - 16) + rust_openssl.ciphers._advance( + encryptor, modes.GCM._MAX_ENCRYPTED_BYTES - 16 + ) encryptor.update(b"0" * 16) with pytest.raises(ValueError): encryptor.update(b"0") @@ -88,7 +83,9 @@ def test_gcm_ciphertext_limit(self, backend): encryptor.update_into(b"0", bytearray(1)) decryptor = cipher.decryptor() - _advance(decryptor, modes.GCM._MAX_ENCRYPTED_BYTES - 16) + rust_openssl.ciphers._advance( + decryptor, modes.GCM._MAX_ENCRYPTED_BYTES - 16 + ) decryptor.update(b"0" * 16) with pytest.raises(ValueError): decryptor.update(b"0") @@ -102,45 +99,21 @@ def test_gcm_aad_limit(self, backend): backend=backend, ) encryptor = cipher.encryptor() - _advance_aad(encryptor, modes.GCM._MAX_AAD_BYTES - 16) + rust_openssl.ciphers._advance_aad( + encryptor, modes.GCM._MAX_AAD_BYTES - 16 + ) encryptor.authenticate_additional_data(b"0" * 16) with pytest.raises(ValueError): encryptor.authenticate_additional_data(b"0") decryptor = cipher.decryptor() - _advance_aad(decryptor, modes.GCM._MAX_AAD_BYTES - 16) + rust_openssl.ciphers._advance_aad( + decryptor, modes.GCM._MAX_AAD_BYTES - 16 + ) decryptor.authenticate_additional_data(b"0" * 16) with pytest.raises(ValueError): decryptor.authenticate_additional_data(b"0") - def test_gcm_ciphertext_increments(self, backend): - encryptor = base.Cipher( - algorithms.AES(b"\x00" * 16), - modes.GCM(b"\x01" * 16), - backend=backend, - ).encryptor() - encryptor.update(b"0" * 8) - assert encryptor._bytes_processed == 8 # type: ignore[attr-defined] - encryptor.update(b"0" * 7) - assert encryptor._bytes_processed == 15 # type: ignore[attr-defined] - encryptor.update(b"0" * 18) - assert encryptor._bytes_processed == 33 # type: ignore[attr-defined] - - def test_gcm_aad_increments(self, backend): - encryptor = base.Cipher( - algorithms.AES(b"\x00" * 16), - modes.GCM(b"\x01" * 16), - backend=backend, - ).encryptor() - encryptor.authenticate_additional_data(b"0" * 8) - assert ( - encryptor._aad_bytes_processed == 8 # type: ignore[attr-defined] - ) - encryptor.authenticate_additional_data(b"0" * 18) - assert ( - encryptor._aad_bytes_processed == 26 # type: ignore[attr-defined] - ) - def test_gcm_tag_decrypt_none(self, backend): key = binascii.unhexlify(b"5211242698bed4774a090620a6ca56f3") iv = binascii.unhexlify(b"b1e1349120b6e832ef976f5d") diff --git a/tests/hazmat/primitives/test_pkcs12.py b/tests/hazmat/primitives/test_pkcs12.py index cd9c279ac4b02..f5a092ff54165 100644 --- a/tests/hazmat/primitives/test_pkcs12.py +++ b/tests/hazmat/primitives/test_pkcs12.py @@ -19,6 +19,7 @@ ed25519, rsa, ) +from cryptography.hazmat.primitives.ciphers import modes from cryptography.hazmat.primitives.serialization import ( Encoding, PublicFormat, @@ -81,7 +82,9 @@ def test_load_pkcs12_ec_keys(self, filename, password, backend): ], ) @pytest.mark.supported( - only_if=lambda backend: backend.cipher_supported(_RC2(), None), + only_if=lambda backend: backend.cipher_supported( + _RC2(), modes.CBC(initialization_vector=b"\x00" * 16) + ), skip_message="Does not support RC2", ) def test_load_pkcs12_ec_keys_rc2(self, filename, password, backend):