diff --git a/.github/workflows/ci_linting.yml b/.github/workflows/ci_linting.yml index 1211cca3494..21af472ab13 100644 --- a/.github/workflows/ci_linting.yml +++ b/.github/workflows/ci_linting.yml @@ -74,20 +74,25 @@ jobs: run: | ./codebuild/bin/run_kwstyle.sh ./codebuild/bin/cpp_style_comment_linter.sh - pepeight: + ruff: runs-on: ubuntu-latest steps: - name: checkout uses: actions/checkout@v4 - - name: Run autopep8 - id: autopep8 - uses: peter-evans/autopep8@v2 - with: - args: --diff --exit-code . + + - name: Set up uv + uses: astral-sh/setup-uv@v5 + + - name: Run Ruff formatting check + working-directory: tests/integrationv2 + id: ruff_format + run: uv run ruff format --check . + continue-on-error: true + - name: Check exit code - if: steps.autopep8.outputs.exit-code != 0 + if: steps.ruff_format.outcome == 'failure' run: | - echo "Run 'autopep8 --in-place .' to fix" + echo "Run 'ruff format .' to fix formatting issues" exit 1 clang-format: runs-on: ubuntu-latest diff --git a/.pep8 b/.pep8 deleted file mode 100644 index ae8b67bcec6..00000000000 --- a/.pep8 +++ /dev/null @@ -1,3 +0,0 @@ -[pep8] -max_line_length = 120 -recursive = true diff --git a/CMakeLists.txt b/CMakeLists.txt index 6ad07d88bed..1882c3b50a8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -165,7 +165,7 @@ endif () if(BUILD_TESTING AND BUILD_SHARED_LIBS OR S2N_FUZZ_TEST) target_compile_options(${PROJECT_NAME} PRIVATE -fvisibility=default) else() - target_compile_options(${PROJECT_NAME} PRIVATE -fvisibility=hidden -DS2N_EXPORTS) + target_compile_options(${PROJECT_NAME} PRIVATE -fvisibility=hidden -DS2N_EXPORTS=1) endif() if(S2N_LTO) @@ -197,7 +197,7 @@ target_compile_options(${PROJECT_NAME} PRIVATE -include "${S2N_PRELUDE}") # Match on Release, RelWithDebInfo and MinSizeRel # See: https://cmake.org/cmake/help/latest/variable/CMAKE_BUILD_TYPE.html#variable:CMAKE_BUILD_TYPE if(CMAKE_BUILD_TYPE MATCHES Rel) - add_definitions(-DS2N_BUILD_RELEASE) + add_definitions(-DS2N_BUILD_RELEASE=1) endif() if(NO_STACK_PROTECTOR) @@ -251,7 +251,7 @@ endif() if (NOT S2N_OVERRIDE_LIBCRYPTO_RAND_ENGINE) message(STATUS "Disabling libcrypto RAND engine override") - add_definitions(-DS2N_DISABLE_RAND_ENGINE_OVERRIDE) + add_definitions(-DS2N_DISABLE_RAND_ENGINE_OVERRIDE=1) endif() # For interning, we need to find the static libcrypto library. Cmake configs @@ -316,7 +316,7 @@ function(feature_probe_result PROBE_NAME IS_AVAILABLE) # define the probe if available if(NORMALIZED) - add_definitions(-D${PROBE_NAME}) + add_definitions(-D${PROBE_NAME}=1) endif() endfunction() @@ -426,7 +426,7 @@ if (S2N_INTERN_LIBCRYPTO) DEPENDS libcrypto.symbols ) add_dependencies(${PROJECT_NAME} s2n_libcrypto) - add_definitions(-DS2N_INTERN_LIBCRYPTO) + add_definitions(-DS2N_INTERN_LIBCRYPTO=1) if ((BUILD_SHARED_LIBS AND BUILD_TESTING) OR NOT BUILD_SHARED_LIBS) # if libcrypto needs to be interned, rewrite libcrypto references so use of internal functions will link correctly diff --git a/api/unstable/crl.h b/api/unstable/crl.h index 0e0388c0c92..22856b8af9c 100644 --- a/api/unstable/crl.h +++ b/api/unstable/crl.h @@ -187,12 +187,16 @@ struct s2n_cert_validation_info; * * If the validation performed in the callback is successful, `s2n_cert_validation_accept()` MUST be called to allow * `s2n_negotiate()` to continue the handshake. If the validation is unsuccessful, `s2n_cert_validation_reject()` - * MUST be called, which will cause `s2n_negotiate()` to error. The behavior of `s2n_negotiate()` is undefined if - * neither `s2n_cert_validation_accept()` or `s2n_cert_validation_reject()` are called. + * MUST be called, which will cause `s2n_negotiate()` to error. + * + * To use the validation callback asynchronously, return `S2N_SUCCESS` without calling `s2n_cert_validation_accept()` + * or `s2n_cert_validation_reject()`. This will pause the handshake, and `s2n_negotiate()` will throw an `S2N_ERR_T_BLOCKED` + * error and `s2n_blocked_status` will be set to `S2N_BLOCKED_ON_APPLICATION_INPUT`. Applications should call + * `s2n_cert_validation_accept()` or `s2n_cert_validation_reject()` to unpause the handshake before retrying `s2n_negotiate()`. * * The `info` parameter is passed to the callback in order to call APIs specific to the cert validation callback, like - * `s2n_cert_validation_accept()` and `s2n_cert_validation_reject()`. The `info` argument is only valid for the - * lifetime of the callback, and must not be used after the callback has finished. + * `s2n_cert_validation_accept()` and `s2n_cert_validation_reject()`. The `info` argument shares the same lifetime as + * `s2n_connection`. * * After calling `s2n_cert_validation_reject()`, `s2n_negotiate()` will fail with a protocol error indicating that * the cert has been rejected from the callback. If more information regarding an application's custom validation diff --git a/bindings/rust/extended/s2n-tls/src/cert_chain.rs b/bindings/rust/extended/s2n-tls/src/cert_chain.rs index 4c5790cf654..9b15e022477 100644 --- a/bindings/rust/extended/s2n-tls/src/cert_chain.rs +++ b/bindings/rust/extended/s2n-tls/src/cert_chain.rs @@ -1,9 +1,11 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use crate::error::{Error, Fallible}; +use crate::error::{Error, ErrorType, Fallible}; use s2n_tls_sys::*; use std::{ + any::Any, + ffi::c_void, marker::PhantomData, ptr::{self, NonNull}, sync::Arc, @@ -13,6 +15,7 @@ use std::{ /// /// [CertificateChain] is internally reference counted. The reference counted `T` /// must have a drop implementation. +#[derive(Debug)] pub(crate) struct CertificateChainHandle<'a> { pub(crate) cert: NonNull, is_owned: bool, @@ -45,20 +48,57 @@ impl CertificateChainHandle<'_> { _lifetime: PhantomData, } } + + /// Corresponds to [s2n_cert_chain_and_key_get_ctx]. + fn context_mut(&mut self) -> Option<&mut Context> { + let context = unsafe { s2n_cert_chain_and_key_get_ctx(self.cert.as_ptr()) }; + if context.is_null() { + None + } else { + Some(unsafe { &mut *(context as *mut Context) }) + } + } + + /// Corresponds to [s2n_cert_chain_and_key_get_ctx]. + fn context(&self) -> Option<&Context> { + let context = unsafe { s2n_cert_chain_and_key_get_ctx(self.cert.as_ptr()) }; + if context.is_null() { + None + } else { + Some(unsafe { &*(context as *const Context) }) + } + } } impl Drop for CertificateChainHandle<'_> { /// Corresponds to [s2n_cert_chain_and_key_free]. fn drop(&mut self) { - // ignore failures since there's not much we can do about it if self.is_owned { + if let Some(internal_context) = self.context_mut() { + drop(unsafe { Box::from_raw(internal_context) }); + } + // ignore failures since there's not much we can do about it unsafe { + // null the cert chain context out of an abundance of caution + let _ = s2n_cert_chain_and_key_set_ctx(self.cert.as_ptr(), std::ptr::null_mut()) + .into_result(); + let _ = s2n_cert_chain_and_key_free(self.cert.as_ptr()).into_result(); } } } } +/// An internal container to hold the customer supplied application context. +/// +/// We can't directly store the application context on the `s2n_cert_chain_and_key`, +/// because `*mut dyn Any` is a fat pointer (16 bytes) and can not be stored as +/// a c_void (8 bytes). +struct Context { + application_context: Box, +} + +#[derive(Debug)] pub struct Builder { cert_handle: CertificateChainHandle<'static>, } @@ -125,6 +165,39 @@ impl Builder { Ok(self) } + /// Associates an arbitrary application context with the CertificateChain to + /// be later retrieved via [`CertificateChain::application_context()`]. + /// + /// This API will override an existing application context set on the Builder. + /// + /// Corresponds to [s2n_cert_chain_and_key_set_ctx]. + pub fn set_application_context( + &mut self, + app_context: T, + ) -> Result<&mut Self, Error> { + match self.cert_handle.context_mut() { + Some(_) => Err(Error::bindings( + ErrorType::UsageError, + "cert builder error", + "set_application_context can only be called once", + )), + None => { + let app_context = Box::new(app_context); + let internal_context = Box::new(Context { + application_context: app_context, + }); + unsafe { + s2n_cert_chain_and_key_set_ctx( + self.cert_handle.cert.as_ptr(), + Box::into_raw(internal_context) as *mut c_void, + ) + .into_result() + }?; + Ok(self) + } + } + } + /// Return an immutable, internally-reference counted CertificateChain. pub fn build(self) -> Result, Error> { // This method is currently infallible, but returning a result allows @@ -177,6 +250,23 @@ impl CertificateChain<'_> { } } + /// Retrieves a reference to the application context associated with the + /// CertificateChain. + /// + /// If an application context hasn't been set on the CertificateChain or if + /// the set application context isn't of type `T`, `None` will be returned. + /// + /// To set a context on the connection, use [`Builder::set_application_context()`]. + /// + /// Corresponds to [s2n_cert_chain_and_key_get_ctx]. + pub fn application_context(&self) -> Option<&T> { + if let Some(internal_context) = self.cert_handle.context() { + internal_context.application_context.downcast_ref() + } else { + None + } + } + /// Return the length of this certificate chain. /// /// Note that the underlying API currently traverses a linked list, so this is a relatively @@ -273,9 +363,12 @@ unsafe impl Send for Certificate<'_> {} mod tests { use crate::{ config, - error::{ErrorSource, ErrorType}, + error::{Error as S2NError, ErrorSource, ErrorType}, security::DEFAULT_TLS13, - testing::{InsecureAcceptAllCertificatesHandler, SniTestCerts, TestPair}, + testing::{ + config_builder, CertKeyPair, InsecureAcceptAllCertificatesHandler, SniTestCerts, + TestPair, + }, }; use super::*; @@ -495,4 +588,67 @@ mod tests { fn assert_send_sync() {} assert_send_sync::>(); } + + /// sanity check for basic cert chain context interactions + #[test] + fn application_context_workflow() -> Result<(), S2NError> { + let context: Arc = Arc::new(0xC0FFEE); + let handle = Arc::clone(&context); + assert_eq!(Arc::strong_count(&handle), 2); + + let default = CertKeyPair::default(); + let mut chain = Builder::new()?; + chain.load_pem(default.cert(), default.key())?; + chain.set_application_context(context)?; + let chain = chain.build()?; + + let invalid_type_get = chain.application_context::(); + assert!(invalid_type_get.is_none()); + + let retrieved_context = chain.application_context::>().unwrap(); + assert_eq!(*retrieved_context.as_ref(), 0xC0FFEE); + assert_eq!(Arc::strong_count(&handle), 2); + drop(chain); + assert_eq!(Arc::strong_count(&handle), 1); + Ok(()) + } + + /// When an application context is overridden, it should be error. + #[test] + fn application_context_override() -> Result<(), S2NError> { + let initial: Arc = Arc::new(0xC0FFEE); + let overridden: Arc<[u8; 6]> = Arc::new(*b"coffee"); + + let mut builder = Builder::new()?; + builder.set_application_context(initial)?; + let err = builder.set_application_context(overridden).unwrap_err(); + assert_eq!(err.kind(), ErrorType::UsageError); + + Ok(()) + } + + /// An application context should be retrievable from a selected cert after + /// the handshake. + #[test] + fn application_context_from_selected_cert() -> Result<(), S2NError> { + let default = CertKeyPair::default(); + let mut chain = Builder::new()?; + chain.load_pem(default.cert(), default.key())?; + chain.set_application_context(0xC0FFEE_u64)?; + + let mut server_config = config::Builder::new(); + server_config.load_chain(chain.build()?)?; + + let client_config = config_builder(&crate::security::DEFAULT).unwrap(); + + let mut test_pair = + TestPair::from_configs(&client_config.build()?, &server_config.build()?); + test_pair.handshake()?; + + let selected_cert = test_pair.server.selected_cert().unwrap(); + let context = selected_cert.application_context::(); + assert_eq!(context, Some(&0xC0FFEE_u64)); + + Ok(()) + } } diff --git a/codebuild/bin/grep_simple_mistakes.sh b/codebuild/bin/grep_simple_mistakes.sh index ad9806ec2d7..36af4b974b5 100755 --- a/codebuild/bin/grep_simple_mistakes.sh +++ b/codebuild/bin/grep_simple_mistakes.sh @@ -14,6 +14,20 @@ FAILED=0 +############################################# +# Grep for command line defines without values +############################################# +EMPTY_DEFINES=$(grep -Eon "\-D[^=]+=?" CMakeLists.txt | grep -v =) +if [ ! -z "${EMPTY_DEFINES}" ]; then + FAILED=1 + printf "\e[1;34mCommand line define is missing value:\e[0m " + printf "Compilers SHOULD set a default value of 1 when no default is given, " + printf "but that behavior is not required by any official spec. Set a value just in case. " + printf "For example: -DS2N_FOO=1 instead of -DS2N_FOO.\n" + printf "Found: \n" + echo "$EMPTY_DEFINES" +fi + ############################################# # Grep for bindings methods without C documentation links. ############################################# @@ -74,7 +88,6 @@ KNOWN_MEMCMP_USAGE["$PWD/stuffer/s2n_stuffer_text.c"]=1 KNOWN_MEMCMP_USAGE["$PWD/tls/s2n_psk.c"]=1 KNOWN_MEMCMP_USAGE["$PWD/tls/s2n_protocol_preferences.c"]=1 KNOWN_MEMCMP_USAGE["$PWD/tls/s2n_cipher_suites.c"]=1 -KNOWN_MEMCMP_USAGE["$PWD/tls/s2n_config.c"]=1 KNOWN_MEMCMP_USAGE["$PWD/utils/s2n_map.c"]=3 for file in $S2N_FILES_ASSERT_NOT_USING_MEMCMP; do diff --git a/codebuild/bin/install_awslc_fips.sh b/codebuild/bin/install_awslc_fips.sh new file mode 100755 index 00000000000..6e47a3c206c --- /dev/null +++ b/codebuild/bin/install_awslc_fips.sh @@ -0,0 +1,73 @@ +#!/usr/bin/env bash +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +set -eu + +usage() { + echo "install_awslc_fips.sh build_dir install_dir version" + exit 1 +} + +check_dep(){ + if [[ ! -f "$(which $1)" ]]; then + echo "Could not find $1" + exit 1 + fi +} + +clone(){ + git clone https://github.com/awslabs/aws-lc.git --branch "$AWSLC_BRANCH" --depth 1 $BUILD_DIR + cd "$BUILD_DIR" +} + +build() { + echo "Building with shared library=$1" + cmake $BUILD_DIR \ + -Bbuild \ + -GNinja \ + -DBUILD_SHARED_LIBS=$1 \ + -DCMAKE_BUILD_TYPE=relwithdebinfo \ + -DCMAKE_INSTALL_PREFIX="${INSTALL_DIR}" \ + -DCMAKE_C_COMPILER=$(which clang) \ + -DCMAKE_CXX_COMPILER=$(which clang++) \ + -DFIPS="true" + ninja -j "$(nproc)" -C build install + ninja -C build clean +} + +# main +if [ "$#" -ne "3" ]; then + usage +fi + +# Ensure tooling is available +check_dep clang +check_dep ninja +check_dep go + +BUILD_DIR=$1 +INSTALL_DIR=$2 +VERSION=$3 + +# Map version to a specific feature branch/tag. +case $VERSION in + "2022") + AWSLC_BRANCH=AWS-LC-FIPS-2.0.17 + ;; + "2024") + AWSLC_BRANCH=AWS-LC-FIPS-3.0.0 + ;; + *) + echo "Unknown version: $VERSION" + usage + ;; +esac + +clone +# Static lib +build false +# Shared lib +build true + +rm -rf $BUILD_DIR + diff --git a/codebuild/bin/install_awslc_fips_2022.sh b/codebuild/bin/install_awslc_fips_2022.sh index 4d8ae96517c..34cdefa2311 100755 --- a/codebuild/bin/install_awslc_fips_2022.sh +++ b/codebuild/bin/install_awslc_fips_2022.sh @@ -1,19 +1,7 @@ -#!/bin/bash +#!/usr/bin/env bash # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://aws.amazon.com/apache2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - +# SPDX-License-Identifier: Apache-2.0 set -eu -pushd "$(pwd)" usage() { echo "install_awslc_fips_2022.sh build_dir install_dir" @@ -24,38 +12,9 @@ if [ "$#" -ne "2" ]; then usage fi +CBPATH=$(dirname $0) BUILD_DIR=$1 INSTALL_DIR=$2 -if [[ ! -f "$(which clang)" ]]; then - echo "Could not find clang" - exit 1 -fi - -AWSLC_VERSION=AWS-LC-FIPS-2.0.17 - -mkdir -p "$BUILD_DIR" || true -cd "$BUILD_DIR" -# --branch can also take tags and detaches the HEAD at that commit in the resulting repository -# --depth 1 Create a shallow clone with a history truncated to 1 commit -git clone https://github.com/awslabs/aws-lc.git --branch "$AWSLC_VERSION" --depth 1 - -build() { - shared=$1 - cmake . \ - -Bbuild \ - -GNinja \ - -DBUILD_SHARED_LIBS="${shared}" \ - -DCMAKE_BUILD_TYPE=relwithdebinfo \ - -DCMAKE_INSTALL_PREFIX="${INSTALL_DIR}" \ - -DCMAKE_C_COMPILER=$(which clang) \ - -DCMAKE_CXX_COMPILER=$(which clang++) \ - -DFIPS=1 - ninja -j "$(nproc)" -C build install - ninja -C build clean -} - -build 0 -build 1 +$CBPATH/install_awslc_fips.sh $@ 2022 -exit 0 diff --git a/codebuild/bin/install_awslc_fips_2024.sh b/codebuild/bin/install_awslc_fips_2024.sh new file mode 100755 index 00000000000..cf4a951bf0d --- /dev/null +++ b/codebuild/bin/install_awslc_fips_2024.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +set -eu + +usage() { + echo "install_awslc_fips_2024.sh build_dir install_dir" + exit 1 +} + +if [ "$#" -ne "2" ]; then + usage +fi + +CBPATH=$(dirname $0) + +$CBPATH/install_awslc_fips.sh $@ 2024 + diff --git a/codebuild/bin/s2n_setup_env.sh b/codebuild/bin/s2n_setup_env.sh index 4ef3cdfa183..9933158bdaa 100755 --- a/codebuild/bin/s2n_setup_env.sh +++ b/codebuild/bin/s2n_setup_env.sh @@ -47,6 +47,7 @@ source codebuild/bin/s2n_set_build_preset.sh : "${AWSLC_INSTALL_DIR:=$TEST_DEPS_DIR/awslc}" : "${AWSLC_FIPS_INSTALL_DIR:=$TEST_DEPS_DIR/awslc-fips}" : "${AWSLC_FIPS_2022_INSTALL_DIR:=$TEST_DEPS_DIR/awslc-fips-2022}" +: "${AWSLC_FIPS_2024_INSTALL_DIR:=$TEST_DEPS_DIR/awslc-fips-2024}" : "${LIBRESSL_INSTALL_DIR:=$TEST_DEPS_DIR/libressl}" : "${CPPCHECK_INSTALL_DIR:=$TEST_DEPS_DIR/cppcheck}" : "${CTVERIF_INSTALL_DIR:=$TEST_DEPS_DIR/ctverif}" diff --git a/codebuild/spec/buildspec_disable_rand_override.yml b/codebuild/spec/buildspec_disable_rand_override.yml index 65802da6414..f1f0aefbb71 100644 --- a/codebuild/spec/buildspec_disable_rand_override.yml +++ b/codebuild/spec/buildspec_disable_rand_override.yml @@ -21,6 +21,12 @@ env: CTEST_OUTPUT_ON_FAILURE: 1 phases: + pre_build: + commands: + - | + if [ -d "third-party-src" ]; then + cd third-party-src; + fi build: on-failure: ABORT commands: diff --git a/codebuild/spec/buildspec_generalbatch.yml b/codebuild/spec/buildspec_generalbatch.yml index bb9249063e2..77087a75539 100644 --- a/codebuild/spec/buildspec_generalbatch.yml +++ b/codebuild/spec/buildspec_generalbatch.yml @@ -16,16 +16,6 @@ version: 0.2 batch: build-list: - - buildspec: codebuild/spec/buildspec_ubuntu.yml - env: - compute-type: BUILD_GENERAL1_LARGE - image: 024603541914.dkr.ecr.us-west-2.amazonaws.com/docker:ubuntu18codebuild - privileged-mode: true - variables: - GCC_VERSION: NONE - SAW: true - TESTS: sawHMACPlus - identifier: sawHMACPlus - buildspec: codebuild/spec/buildspec_ubuntu.yml env: compute-type: BUILD_GENERAL1_LARGE diff --git a/codebuild/spec/buildspec_openssl3fips.yml b/codebuild/spec/buildspec_openssl3fips.yml index 4caef166990..42713db2f5c 100644 --- a/codebuild/spec/buildspec_openssl3fips.yml +++ b/codebuild/spec/buildspec_openssl3fips.yml @@ -37,3 +37,4 @@ phases: # openssl3fips is still a work-in-progress. Not all tests pass. - make -C build test -- ARGS="-R 's2n_build_test|s2n_fips_test'" - make -C build test -- ARGS="-R 's2n_hash_test|s2n_hash_all_algs_test|s2n_openssl_test|s2n_init_test'" + - make -C build test -- ARGS="-R 's2n_evp_signing_test'" diff --git a/codebuild/spec/buildspec_ubuntu_integrationv2.yml b/codebuild/spec/buildspec_ubuntu_integrationv2.yml index 13ee3ab609e..8b4870f1f38 100644 --- a/codebuild/spec/buildspec_ubuntu_integrationv2.yml +++ b/codebuild/spec/buildspec_ubuntu_integrationv2.yml @@ -39,10 +39,10 @@ batch: - "test_client_authentication test_dynamic_record_sizes test_sslyze test_sslv2_client_hello" - "test_happy_path" - "test_cross_compatibility" - - "test_early_data test_well_known_endpoints test_hello_retry_requests test_sni_match test_pq_handshake test_fragmentation test_key_update" + - "test_early_data test_hello_retry_requests test_sni_match test_pq_handshake test_fragmentation test_key_update" - "test_session_resumption test_renegotiate_apache test_buffered_send" - "test_npn test_signature_algorithms" - - "test_version_negotiation test_external_psk test_ocsp test_renegotiate test_serialization" + - "test_version_negotiation test_external_psk test_ocsp test_renegotiate test_serialization test_record_padding" env: variables: diff --git a/crypto/s2n_evp_signing.c b/crypto/s2n_evp_signing.c index bd0b9c5972d..d4762b22767 100644 --- a/crypto/s2n_evp_signing.c +++ b/crypto/s2n_evp_signing.c @@ -16,6 +16,7 @@ #include "crypto/s2n_evp_signing.h" #include "crypto/s2n_evp.h" +#include "crypto/s2n_libcrypto.h" #include "crypto/s2n_pkey.h" #include "crypto/s2n_rsa_pss.h" #include "error/s2n_errno.h" @@ -24,12 +25,6 @@ DEFINE_POINTER_CLEANUP_FUNC(EVP_PKEY_CTX *, EVP_PKEY_CTX_free); -/* - * FIPS 140-3 requires that we don't pass raw digest bytes to the libcrypto signing methods. - * In order to do that, we need to use signing methods that both calculate the digest and - * perform the signature. - */ - static S2N_RESULT s2n_evp_md_ctx_set_pkey_ctx(EVP_MD_CTX *ctx, EVP_PKEY_CTX *pctx) { #ifdef S2N_LIBCRYPTO_SUPPORTS_EVP_MD_CTX_SET_PKEY_CTX @@ -50,51 +45,16 @@ static S2N_RESULT s2n_evp_pkey_set_rsa_pss_saltlen(EVP_PKEY_CTX *pctx) #endif } -bool s2n_evp_signing_supported() -{ -#ifdef S2N_LIBCRYPTO_SUPPORTS_EVP_MD_CTX_SET_PKEY_CTX - /* We can only use EVP signing if the hash state has an EVP_MD_CTX - * that we can pass to the EVP signing methods. - */ - return s2n_hash_evp_fully_supported(); -#else - return false; -#endif -} - -/* If using EVP signing, override the sign and verify pkey methods. - * The EVP methods can handle all pkey types / signature algorithms. +/* Always use EVP signing. + * + * TODO: Migrate the rest of the s2n_pkey methods to EVP and delete the legacy + * pkey logic and this method. */ S2N_RESULT s2n_evp_signing_set_pkey_overrides(struct s2n_pkey *pkey) { - if (s2n_evp_signing_supported()) { - RESULT_ENSURE_REF(pkey); - pkey->sign = &s2n_evp_sign; - pkey->verify = &s2n_evp_verify; - } - return S2N_RESULT_OK; -} - -static S2N_RESULT s2n_evp_signing_validate_hash_alg(s2n_signature_algorithm sig_alg, s2n_hash_algorithm hash_alg) -{ - switch (hash_alg) { - case S2N_HASH_NONE: - case S2N_HASH_MD5: - /* MD5 alone is never supported */ - RESULT_BAIL(S2N_ERR_HASH_INVALID_ALGORITHM); - break; - case S2N_HASH_MD5_SHA1: - /* Only RSA supports MD5+SHA1. - * This should not be a problem, as we only allow MD5+SHA1 when - * falling back to TLS1.0 or 1.1, which only support RSA. - */ - RESULT_ENSURE(sig_alg == S2N_SIGNATURE_RSA, S2N_ERR_HASH_INVALID_ALGORITHM); - break; - default: - break; - } - /* Hash algorithm must be recognized and supported by EVP_MD */ - RESULT_ENSURE(s2n_hash_alg_to_evp_md(hash_alg) != NULL, S2N_ERR_HASH_INVALID_ALGORITHM); + RESULT_ENSURE_REF(pkey); + pkey->sign = &s2n_evp_sign; + pkey->verify = &s2n_evp_verify; return S2N_RESULT_OK; } @@ -112,16 +72,123 @@ static S2N_RESULT s2n_evp_signing_validate_sig_alg(const struct s2n_pkey *key, s return S2N_RESULT_OK; } +static EVP_PKEY_CTX *s2n_evp_pkey_ctx_new(EVP_PKEY *pkey, s2n_hash_algorithm hash_alg) +{ + PTR_ENSURE_REF(pkey); + switch (hash_alg) { +#if S2N_LIBCRYPTO_SUPPORTS_PROVIDERS + /* For openssl-3.0, pkey methods will do an implicit fetch for the signing + * algorithm, which includes the hash algorithm. If using a legacy hash + * algorithm, specify the non-fips version. + */ + case S2N_HASH_MD5: + case S2N_HASH_MD5_SHA1: + case S2N_HASH_SHA1: + return EVP_PKEY_CTX_new_from_pkey(NULL, pkey, "-fips"); +#endif + default: + return EVP_PKEY_CTX_new(pkey, NULL); + } +} + +/* Our "digest-and-sign" EVP signing logic is intended to support FIPS 140-3. + * FIPS 140-3 does not allow signing or verifying externally calculated digests + * (except for signing, but not verifying, with ECDSA). + * See https://csrc.nist.gov/Projects/Cryptographic-Algorithm-Validation-Program/Digital-Signatures, + * and note that "component" tests only exist for ECDSA sign. + * + * In order to avoid signing externally calculated digests, we naively would + * need access to the full message to be signed at the time of signing. That's + * a problem for TLS1.2, where the client cert verify message requires signing + * every handshake message sent or received before the client cert verify message. + * To avoid storing every single handshake message in its entirety, we instead + * keep a running hash of the messages in an EVP hash state. Then, instead of + * digesting that hash state, we pass it unmodified to EVP_DigestSignFinal. + * That would normally not be allowed, since the hash state was initialized without + * a key using EVP_DigestInit instead of with a key using EVP_DigestSignInit. + * We make it work by using the EVP_MD_CTX_set_pkey_ctx method to attach a key + * to an existing hash state. + * + * All that means that "digest-and-sign" requires two things: + * - A single EVP hash state to sign. So we must not use a custom MD5_SHA1 hash, + * which doesn't produce a single hash state. + * - EVP_MD_CTX_set_pkey_ctx to exist and to behave as expected. Existence + * alone is not sufficient: the method exists in openssl-3.0-fips, but + * it cannot be used to setup a hash state for EVP_DigestSignFinal. + * + * Currently only awslc-fips meets both these requirements. New libcryptos + * should be assumed not to meet these requirements until proven otherwise. + */ +int s2n_evp_digest_and_sign(EVP_PKEY_CTX *pctx, s2n_signature_algorithm sig_alg, + struct s2n_hash_state *hash_state, struct s2n_blob *signature) +{ + POSIX_ENSURE_REF(pctx); + POSIX_ENSURE_REF(hash_state); + POSIX_ENSURE_REF(signature); + + /* Custom MD5_SHA1 involves combining separate MD5 and SHA1 hashes. + * That involves two hash states instead of the single hash state this + * method requires. + */ + POSIX_ENSURE(!s2n_hash_use_custom_md5_sha1(), S2N_ERR_SAFETY); + + /* Not all implementations of EVP_MD_CTX_set_pkey_ctx behave as required + * by this method. Using EVP_MD_CTX_set_pkey_ctx to convert a hash initialized + * with EVP_DigestInit to one that can be finalized with EVP_DigestSignFinal + * is not entirely standard. + * + * However, this behavior is known to work with awslc-fips. + */ + POSIX_ENSURE(s2n_libcrypto_is_awslc_fips(), S2N_ERR_SAFETY); + + EVP_MD_CTX *ctx = hash_state->digest.high_level.evp.ctx; + POSIX_ENSURE_REF(ctx); + POSIX_GUARD_RESULT(s2n_evp_md_ctx_set_pkey_ctx(ctx, pctx)); + + size_t signature_size = signature->size; + POSIX_GUARD_OSSL(EVP_DigestSignFinal(ctx, signature->data, &signature_size), S2N_ERR_SIGN); + POSIX_ENSURE(signature_size <= signature->size, S2N_ERR_SIZE_MISMATCH); + signature->size = signature_size; + POSIX_GUARD_RESULT(s2n_evp_md_ctx_set_pkey_ctx(ctx, NULL)); + + return S2N_SUCCESS; +} + +/* "digest-then-sign" means that we calculate the digest for a hash state, + * then sign the digest bytes. That is not allowed by FIPS 140-3, but is allowed + * in all other cases. + */ +int s2n_evp_digest_then_sign(EVP_PKEY_CTX *pctx, + struct s2n_hash_state *hash_state, struct s2n_blob *signature) +{ + POSIX_ENSURE_REF(pctx); + POSIX_ENSURE_REF(hash_state); + POSIX_ENSURE_REF(signature); + + uint8_t digest_length = 0; + POSIX_GUARD(s2n_hash_digest_size(hash_state->alg, &digest_length)); + POSIX_ENSURE_LTE(digest_length, S2N_MAX_DIGEST_LEN); + + uint8_t digest_out[S2N_MAX_DIGEST_LEN] = { 0 }; + POSIX_GUARD(s2n_hash_digest(hash_state, digest_out, digest_length)); + + size_t signature_size = signature->size; + POSIX_GUARD_OSSL(EVP_PKEY_sign(pctx, signature->data, &signature_size, + digest_out, digest_length), + S2N_ERR_SIGN); + POSIX_ENSURE(signature_size <= signature->size, S2N_ERR_SIZE_MISMATCH); + signature->size = signature_size; + + return S2N_SUCCESS; +} + int s2n_evp_sign(const struct s2n_pkey *priv, s2n_signature_algorithm sig_alg, struct s2n_hash_state *hash_state, struct s2n_blob *signature) { POSIX_ENSURE_REF(priv); POSIX_ENSURE_REF(hash_state); - POSIX_ENSURE_REF(signature); - POSIX_ENSURE(s2n_evp_signing_supported(), S2N_ERR_HASH_NOT_READY); - POSIX_GUARD_RESULT(s2n_evp_signing_validate_hash_alg(sig_alg, hash_state->alg)); - DEFER_CLEANUP(EVP_PKEY_CTX *pctx = EVP_PKEY_CTX_new(priv->pkey, NULL), EVP_PKEY_CTX_free_pointer); + DEFER_CLEANUP(EVP_PKEY_CTX *pctx = s2n_evp_pkey_ctx_new(priv->pkey, hash_state->alg), EVP_PKEY_CTX_free_pointer); POSIX_ENSURE_REF(pctx); POSIX_GUARD_OSSL(EVP_PKEY_sign_init(pctx), S2N_ERR_PKEY_CTX_INIT); POSIX_GUARD_OSSL(S2N_EVP_PKEY_CTX_set_signature_md(pctx, s2n_hash_alg_to_evp_md(hash_state->alg)), S2N_ERR_PKEY_CTX_INIT); @@ -131,15 +198,55 @@ int s2n_evp_sign(const struct s2n_pkey *priv, s2n_signature_algorithm sig_alg, POSIX_GUARD_RESULT(s2n_evp_pkey_set_rsa_pss_saltlen(pctx)); } + if (s2n_libcrypto_is_awslc_fips()) { + POSIX_GUARD(s2n_evp_digest_and_sign(pctx, sig_alg, hash_state, signature)); + } else { + POSIX_GUARD(s2n_evp_digest_then_sign(pctx, hash_state, signature)); + } + + return S2N_SUCCESS; +} + +/* See s2n_evp_digest_and_sign for more information */ +int s2n_evp_digest_and_verify(EVP_PKEY_CTX *pctx, s2n_signature_algorithm sig_alg, + struct s2n_hash_state *hash_state, struct s2n_blob *signature) +{ + POSIX_ENSURE_REF(pctx); + POSIX_ENSURE_REF(hash_state); + POSIX_ENSURE_REF(signature); + + /* See digest-and-sign requirements */ + POSIX_ENSURE(!s2n_hash_use_custom_md5_sha1(), S2N_ERR_SAFETY); + POSIX_ENSURE(s2n_libcrypto_is_awslc_fips(), S2N_ERR_SAFETY); + EVP_MD_CTX *ctx = hash_state->digest.high_level.evp.ctx; POSIX_ENSURE_REF(ctx); POSIX_GUARD_RESULT(s2n_evp_md_ctx_set_pkey_ctx(ctx, pctx)); - size_t signature_size = signature->size; - POSIX_GUARD_OSSL(EVP_DigestSignFinal(ctx, signature->data, &signature_size), S2N_ERR_SIGN); - POSIX_ENSURE(signature_size <= signature->size, S2N_ERR_SIZE_MISMATCH); - signature->size = signature_size; + POSIX_GUARD_OSSL(EVP_DigestVerifyFinal(ctx, signature->data, signature->size), S2N_ERR_VERIFY_SIGNATURE); POSIX_GUARD_RESULT(s2n_evp_md_ctx_set_pkey_ctx(ctx, NULL)); + + return S2N_SUCCESS; +} + +/* See s2n_evp_digest_then_sign for more information */ +int s2n_evp_digest_then_verify(EVP_PKEY_CTX *pctx, + struct s2n_hash_state *hash_state, struct s2n_blob *signature) +{ + POSIX_ENSURE_REF(pctx); + POSIX_ENSURE_REF(hash_state); + POSIX_ENSURE_REF(signature); + + uint8_t digest_length = 0; + POSIX_GUARD(s2n_hash_digest_size(hash_state->alg, &digest_length)); + POSIX_ENSURE_LTE(digest_length, S2N_MAX_DIGEST_LEN); + + uint8_t digest_out[S2N_MAX_DIGEST_LEN] = { 0 }; + POSIX_GUARD(s2n_hash_digest(hash_state, digest_out, digest_length)); + + POSIX_GUARD_OSSL(EVP_PKEY_verify(pctx, signature->data, signature->size, + digest_out, digest_length), + S2N_ERR_VERIFY_SIGNATURE); return S2N_SUCCESS; } @@ -149,11 +256,9 @@ int s2n_evp_verify(const struct s2n_pkey *pub, s2n_signature_algorithm sig_alg, POSIX_ENSURE_REF(pub); POSIX_ENSURE_REF(hash_state); POSIX_ENSURE_REF(signature); - POSIX_ENSURE(s2n_evp_signing_supported(), S2N_ERR_HASH_NOT_READY); - POSIX_GUARD_RESULT(s2n_evp_signing_validate_hash_alg(sig_alg, hash_state->alg)); POSIX_GUARD_RESULT(s2n_evp_signing_validate_sig_alg(pub, sig_alg)); - DEFER_CLEANUP(EVP_PKEY_CTX *pctx = EVP_PKEY_CTX_new(pub->pkey, NULL), EVP_PKEY_CTX_free_pointer); + DEFER_CLEANUP(EVP_PKEY_CTX *pctx = s2n_evp_pkey_ctx_new(pub->pkey, hash_state->alg), EVP_PKEY_CTX_free_pointer); POSIX_ENSURE_REF(pctx); POSIX_GUARD_OSSL(EVP_PKEY_verify_init(pctx), S2N_ERR_PKEY_CTX_INIT); POSIX_GUARD_OSSL(S2N_EVP_PKEY_CTX_set_signature_md(pctx, s2n_hash_alg_to_evp_md(hash_state->alg)), S2N_ERR_PKEY_CTX_INIT); @@ -163,11 +268,11 @@ int s2n_evp_verify(const struct s2n_pkey *pub, s2n_signature_algorithm sig_alg, POSIX_GUARD_RESULT(s2n_evp_pkey_set_rsa_pss_saltlen(pctx)); } - EVP_MD_CTX *ctx = hash_state->digest.high_level.evp.ctx; - POSIX_ENSURE_REF(ctx); - POSIX_GUARD_RESULT(s2n_evp_md_ctx_set_pkey_ctx(ctx, pctx)); + if (s2n_libcrypto_is_awslc_fips()) { + POSIX_GUARD(s2n_evp_digest_and_verify(pctx, sig_alg, hash_state, signature)); + } else { + POSIX_GUARD(s2n_evp_digest_then_verify(pctx, hash_state, signature)); + } - POSIX_GUARD_OSSL(EVP_DigestVerifyFinal(ctx, signature->data, signature->size), S2N_ERR_VERIFY_SIGNATURE); - POSIX_GUARD_RESULT(s2n_evp_md_ctx_set_pkey_ctx(ctx, NULL)); return S2N_SUCCESS; } diff --git a/crypto/s2n_evp_signing.h b/crypto/s2n_evp_signing.h index 720ce1b9f09..b0ee6a31981 100644 --- a/crypto/s2n_evp_signing.h +++ b/crypto/s2n_evp_signing.h @@ -20,7 +20,6 @@ #include "crypto/s2n_signature.h" #include "utils/s2n_blob.h" -bool s2n_evp_signing_supported(); S2N_RESULT s2n_evp_signing_set_pkey_overrides(struct s2n_pkey *pkey); int s2n_evp_sign(const struct s2n_pkey *priv, s2n_signature_algorithm sig_alg, struct s2n_hash_state *digest, struct s2n_blob *signature); diff --git a/crypto/s2n_hash.c b/crypto/s2n_hash.c index 090deb6f6c1..632264541cb 100644 --- a/crypto/s2n_hash.c +++ b/crypto/s2n_hash.c @@ -26,7 +26,7 @@ static EVP_MD *s2n_evp_mds[S2N_HASH_ALGS_COUNT] = { 0 }; static const EVP_MD *s2n_evp_mds[S2N_HASH_ALGS_COUNT] = { 0 }; #endif -static bool s2n_use_custom_md5_sha1() +bool s2n_hash_use_custom_md5_sha1() { #if defined(S2N_LIBCRYPTO_SUPPORTS_EVP_MD5_SHA1_HASH) return false; @@ -35,16 +35,6 @@ static bool s2n_use_custom_md5_sha1() #endif } -static bool s2n_use_evp_impl() -{ - return s2n_is_in_fips_mode(); -} - -bool s2n_hash_evp_fully_supported() -{ - return s2n_use_evp_impl() && !s2n_use_custom_md5_sha1(); -} - S2N_RESULT s2n_hash_algorithms_init() { #if S2N_LIBCRYPTO_SUPPORTS_PROVIDERS @@ -174,164 +164,10 @@ int s2n_hash_is_ready_for_input(struct s2n_hash_state *state) return state->is_ready_for_input; } -static int s2n_low_level_hash_new(struct s2n_hash_state *state) -{ - /* s2n_hash_new will always call the corresponding implementation of the s2n_hash - * being used. For the s2n_low_level_hash implementation, new is a no-op. - */ - - *state = (struct s2n_hash_state){ 0 }; - return S2N_SUCCESS; -} - -static int s2n_low_level_hash_init(struct s2n_hash_state *state, s2n_hash_algorithm alg) -{ - switch (alg) { - case S2N_HASH_NONE: - break; - case S2N_HASH_MD5: - POSIX_GUARD_OSSL(MD5_Init(&state->digest.low_level.md5), S2N_ERR_HASH_INIT_FAILED); - break; - case S2N_HASH_SHA1: - POSIX_GUARD_OSSL(SHA1_Init(&state->digest.low_level.sha1), S2N_ERR_HASH_INIT_FAILED); - break; - case S2N_HASH_SHA224: - POSIX_GUARD_OSSL(SHA224_Init(&state->digest.low_level.sha224), S2N_ERR_HASH_INIT_FAILED); - break; - case S2N_HASH_SHA256: - POSIX_GUARD_OSSL(SHA256_Init(&state->digest.low_level.sha256), S2N_ERR_HASH_INIT_FAILED); - break; - case S2N_HASH_SHA384: - POSIX_GUARD_OSSL(SHA384_Init(&state->digest.low_level.sha384), S2N_ERR_HASH_INIT_FAILED); - break; - case S2N_HASH_SHA512: - POSIX_GUARD_OSSL(SHA512_Init(&state->digest.low_level.sha512), S2N_ERR_HASH_INIT_FAILED); - break; - case S2N_HASH_MD5_SHA1: - POSIX_GUARD_OSSL(SHA1_Init(&state->digest.low_level.md5_sha1.sha1), S2N_ERR_HASH_INIT_FAILED); - POSIX_GUARD_OSSL(MD5_Init(&state->digest.low_level.md5_sha1.md5), S2N_ERR_HASH_INIT_FAILED); - break; - - default: - POSIX_BAIL(S2N_ERR_HASH_INVALID_ALGORITHM); - } - - state->alg = alg; - state->is_ready_for_input = 1; - state->currently_in_hash = 0; - - return 0; -} - -static int s2n_low_level_hash_update(struct s2n_hash_state *state, const void *data, uint32_t size) -{ - POSIX_ENSURE(state->is_ready_for_input, S2N_ERR_HASH_NOT_READY); - - switch (state->alg) { - case S2N_HASH_NONE: - break; - case S2N_HASH_MD5: - POSIX_GUARD_OSSL(MD5_Update(&state->digest.low_level.md5, data, size), S2N_ERR_HASH_UPDATE_FAILED); - break; - case S2N_HASH_SHA1: - POSIX_GUARD_OSSL(SHA1_Update(&state->digest.low_level.sha1, data, size), S2N_ERR_HASH_UPDATE_FAILED); - break; - case S2N_HASH_SHA224: - POSIX_GUARD_OSSL(SHA224_Update(&state->digest.low_level.sha224, data, size), S2N_ERR_HASH_UPDATE_FAILED); - break; - case S2N_HASH_SHA256: - POSIX_GUARD_OSSL(SHA256_Update(&state->digest.low_level.sha256, data, size), S2N_ERR_HASH_UPDATE_FAILED); - break; - case S2N_HASH_SHA384: - POSIX_GUARD_OSSL(SHA384_Update(&state->digest.low_level.sha384, data, size), S2N_ERR_HASH_UPDATE_FAILED); - break; - case S2N_HASH_SHA512: - POSIX_GUARD_OSSL(SHA512_Update(&state->digest.low_level.sha512, data, size), S2N_ERR_HASH_UPDATE_FAILED); - break; - case S2N_HASH_MD5_SHA1: - POSIX_GUARD_OSSL(SHA1_Update(&state->digest.low_level.md5_sha1.sha1, data, size), S2N_ERR_HASH_UPDATE_FAILED); - POSIX_GUARD_OSSL(MD5_Update(&state->digest.low_level.md5_sha1.md5, data, size), S2N_ERR_HASH_UPDATE_FAILED); - break; - default: - POSIX_BAIL(S2N_ERR_HASH_INVALID_ALGORITHM); - } - - POSIX_ENSURE(size <= (UINT64_MAX - state->currently_in_hash), S2N_ERR_INTEGER_OVERFLOW); - state->currently_in_hash += size; - - return S2N_SUCCESS; -} - -static int s2n_low_level_hash_digest(struct s2n_hash_state *state, void *out, uint32_t size) -{ - POSIX_ENSURE(state->is_ready_for_input, S2N_ERR_HASH_NOT_READY); - - switch (state->alg) { - case S2N_HASH_NONE: - break; - case S2N_HASH_MD5: - POSIX_ENSURE_EQ(size, MD5_DIGEST_LENGTH); - POSIX_GUARD_OSSL(MD5_Final(out, &state->digest.low_level.md5), S2N_ERR_HASH_DIGEST_FAILED); - break; - case S2N_HASH_SHA1: - POSIX_ENSURE_EQ(size, SHA_DIGEST_LENGTH); - POSIX_GUARD_OSSL(SHA1_Final(out, &state->digest.low_level.sha1), S2N_ERR_HASH_DIGEST_FAILED); - break; - case S2N_HASH_SHA224: - POSIX_ENSURE_EQ(size, SHA224_DIGEST_LENGTH); - POSIX_GUARD_OSSL(SHA224_Final(out, &state->digest.low_level.sha224), S2N_ERR_HASH_DIGEST_FAILED); - break; - case S2N_HASH_SHA256: - POSIX_ENSURE_EQ(size, SHA256_DIGEST_LENGTH); - POSIX_GUARD_OSSL(SHA256_Final(out, &state->digest.low_level.sha256), S2N_ERR_HASH_DIGEST_FAILED); - break; - case S2N_HASH_SHA384: - POSIX_ENSURE_EQ(size, SHA384_DIGEST_LENGTH); - POSIX_GUARD_OSSL(SHA384_Final(out, &state->digest.low_level.sha384), S2N_ERR_HASH_DIGEST_FAILED); - break; - case S2N_HASH_SHA512: - POSIX_ENSURE_EQ(size, SHA512_DIGEST_LENGTH); - POSIX_GUARD_OSSL(SHA512_Final(out, &state->digest.low_level.sha512), S2N_ERR_HASH_DIGEST_FAILED); - break; - case S2N_HASH_MD5_SHA1: - POSIX_ENSURE_EQ(size, MD5_DIGEST_LENGTH + SHA_DIGEST_LENGTH); - POSIX_GUARD_OSSL(SHA1_Final(((uint8_t *) out) + MD5_DIGEST_LENGTH, &state->digest.low_level.md5_sha1.sha1), S2N_ERR_HASH_DIGEST_FAILED); - POSIX_GUARD_OSSL(MD5_Final(out, &state->digest.low_level.md5_sha1.md5), S2N_ERR_HASH_DIGEST_FAILED); - break; - default: - POSIX_BAIL(S2N_ERR_HASH_INVALID_ALGORITHM); - } - - state->currently_in_hash = 0; - state->is_ready_for_input = 0; - return 0; -} - -static int s2n_low_level_hash_copy(struct s2n_hash_state *to, struct s2n_hash_state *from) -{ - POSIX_CHECKED_MEMCPY(to, from, sizeof(struct s2n_hash_state)); - return 0; -} - -static int s2n_low_level_hash_reset(struct s2n_hash_state *state) -{ - /* hash_init resets the ready_for_input and currently_in_hash fields. */ - return s2n_low_level_hash_init(state, state->alg); -} - -static int s2n_low_level_hash_free(struct s2n_hash_state *state) -{ - /* s2n_hash_free will always call the corresponding implementation of the s2n_hash - * being used. For the s2n_low_level_hash implementation, free is a no-op. - */ - state->is_ready_for_input = 0; - return S2N_SUCCESS; -} - static int s2n_evp_hash_new(struct s2n_hash_state *state) { POSIX_ENSURE_REF(state->digest.high_level.evp.ctx = S2N_EVP_MD_CTX_NEW()); - if (s2n_use_custom_md5_sha1()) { + if (s2n_hash_use_custom_md5_sha1()) { POSIX_ENSURE_REF(state->digest.high_level.evp_md5_secondary.ctx = S2N_EVP_MD_CTX_NEW()); } @@ -353,7 +189,7 @@ static int s2n_evp_hash_init(struct s2n_hash_state *state, s2n_hash_algorithm al return S2N_SUCCESS; } - if (alg == S2N_HASH_MD5_SHA1 && s2n_use_custom_md5_sha1()) { + if (alg == S2N_HASH_MD5_SHA1 && s2n_hash_use_custom_md5_sha1()) { POSIX_ENSURE_REF(state->digest.high_level.evp_md5_secondary.ctx); POSIX_GUARD_OSSL(EVP_DigestInit_ex(state->digest.high_level.evp.ctx, s2n_hash_alg_to_evp_md(S2N_HASH_SHA1), NULL), @@ -385,7 +221,7 @@ static int s2n_evp_hash_update(struct s2n_hash_state *state, const void *data, u POSIX_ENSURE_REF(EVP_MD_CTX_md(state->digest.high_level.evp.ctx)); POSIX_GUARD_OSSL(EVP_DigestUpdate(state->digest.high_level.evp.ctx, data, size), S2N_ERR_HASH_UPDATE_FAILED); - if (state->alg == S2N_HASH_MD5_SHA1 && s2n_use_custom_md5_sha1()) { + if (state->alg == S2N_HASH_MD5_SHA1 && s2n_hash_use_custom_md5_sha1()) { POSIX_ENSURE_REF(EVP_MD_CTX_md(state->digest.high_level.evp_md5_secondary.ctx)); POSIX_GUARD_OSSL(EVP_DigestUpdate(state->digest.high_level.evp_md5_secondary.ctx, data, size), S2N_ERR_HASH_UPDATE_FAILED); } @@ -411,7 +247,7 @@ static int s2n_evp_hash_digest(struct s2n_hash_state *state, void *out, uint32_t POSIX_ENSURE_REF(EVP_MD_CTX_md(state->digest.high_level.evp.ctx)); - if (state->alg == S2N_HASH_MD5_SHA1 && s2n_use_custom_md5_sha1()) { + if (state->alg == S2N_HASH_MD5_SHA1 && s2n_hash_use_custom_md5_sha1()) { POSIX_ENSURE_REF(EVP_MD_CTX_md(state->digest.high_level.evp_md5_secondary.ctx)); uint8_t sha1_digest_size = 0; @@ -447,7 +283,7 @@ static int s2n_evp_hash_copy(struct s2n_hash_state *to, struct s2n_hash_state *f POSIX_ENSURE_REF(to->digest.high_level.evp.ctx); POSIX_GUARD_OSSL(EVP_MD_CTX_copy_ex(to->digest.high_level.evp.ctx, from->digest.high_level.evp.ctx), S2N_ERR_HASH_COPY_FAILED); - if (from->alg == S2N_HASH_MD5_SHA1 && s2n_use_custom_md5_sha1()) { + if (from->alg == S2N_HASH_MD5_SHA1 && s2n_hash_use_custom_md5_sha1()) { POSIX_ENSURE_REF(to->digest.high_level.evp_md5_secondary.ctx); POSIX_GUARD_OSSL(EVP_MD_CTX_copy_ex(to->digest.high_level.evp_md5_secondary.ctx, from->digest.high_level.evp_md5_secondary.ctx), S2N_ERR_HASH_COPY_FAILED); } @@ -458,7 +294,7 @@ static int s2n_evp_hash_copy(struct s2n_hash_state *to, struct s2n_hash_state *f static int s2n_evp_hash_reset(struct s2n_hash_state *state) { POSIX_GUARD_OSSL(S2N_EVP_MD_CTX_RESET(state->digest.high_level.evp.ctx), S2N_ERR_HASH_WIPE_FAILED); - if (state->alg == S2N_HASH_MD5_SHA1 && s2n_use_custom_md5_sha1()) { + if (state->alg == S2N_HASH_MD5_SHA1 && s2n_hash_use_custom_md5_sha1()) { POSIX_GUARD_OSSL(S2N_EVP_MD_CTX_RESET(state->digest.high_level.evp_md5_secondary.ctx), S2N_ERR_HASH_WIPE_FAILED); } @@ -471,7 +307,7 @@ static int s2n_evp_hash_free(struct s2n_hash_state *state) S2N_EVP_MD_CTX_FREE(state->digest.high_level.evp.ctx); state->digest.high_level.evp.ctx = NULL; - if (s2n_use_custom_md5_sha1()) { + if (s2n_hash_use_custom_md5_sha1()) { S2N_EVP_MD_CTX_FREE(state->digest.high_level.evp_md5_secondary.ctx); state->digest.high_level.evp_md5_secondary.ctx = NULL; } @@ -480,16 +316,6 @@ static int s2n_evp_hash_free(struct s2n_hash_state *state) return S2N_SUCCESS; } -static const struct s2n_hash s2n_low_level_hash = { - .alloc = &s2n_low_level_hash_new, - .init = &s2n_low_level_hash_init, - .update = &s2n_low_level_hash_update, - .digest = &s2n_low_level_hash_digest, - .copy = &s2n_low_level_hash_copy, - .reset = &s2n_low_level_hash_reset, - .free = &s2n_low_level_hash_free, -}; - static const struct s2n_hash s2n_evp_hash = { .alloc = &s2n_evp_hash_new, .init = &s2n_evp_hash_init, @@ -502,10 +328,7 @@ static const struct s2n_hash s2n_evp_hash = { static int s2n_hash_set_impl(struct s2n_hash_state *state) { - state->hash_impl = &s2n_low_level_hash; - if (s2n_use_evp_impl()) { - state->hash_impl = &s2n_evp_hash; - } + state->hash_impl = &s2n_evp_hash; return S2N_SUCCESS; } diff --git a/crypto/s2n_hash.h b/crypto/s2n_hash.h index 6e18f47be69..f7489423641 100644 --- a/crypto/s2n_hash.h +++ b/crypto/s2n_hash.h @@ -37,20 +37,6 @@ typedef enum { S2N_HASH_ALGS_COUNT } s2n_hash_algorithm; -/* The low_level_digest stores all OpenSSL structs that are alg-specific to be used with OpenSSL's low-level hash API's. */ -union s2n_hash_low_level_digest { - MD5_CTX md5; - SHA_CTX sha1; - SHA256_CTX sha224; - SHA256_CTX sha256; - SHA512_CTX sha384; - SHA512_CTX sha512; - struct { - MD5_CTX md5; - SHA_CTX sha1; - } md5_sha1; -}; - /* The evp_digest stores all OpenSSL structs to be used with OpenSSL's EVP hash API's. */ struct s2n_hash_evp_digest { struct s2n_evp_digest evp; @@ -58,8 +44,8 @@ struct s2n_hash_evp_digest { struct s2n_evp_digest evp_md5_secondary; }; -/* s2n_hash_state stores the s2n_hash implementation being used (low-level or EVP), - * the hash algorithm being used at the time, and either low_level or high_level (EVP) OpenSSL digest structs. +/* s2n_hash_state stores the state and a reference to the implementation being used. + * Currently only EVP hashing is supported, so the only state are EVP_MD contexts. */ struct s2n_hash_state { const struct s2n_hash *hash_impl; @@ -67,13 +53,12 @@ struct s2n_hash_state { uint8_t is_ready_for_input; uint64_t currently_in_hash; union { - union s2n_hash_low_level_digest low_level; struct s2n_hash_evp_digest high_level; } digest; }; -/* The s2n hash implementation is abstracted to allow for separate implementations, using - * either OpenSSL's low-level algorithm-specific API's or OpenSSL's EVP API's. +/* The s2n hash implementation is abstracted to allow for separate implementations. + * Currently the only implementation uses the EVP APIs. */ struct s2n_hash { int (*alloc)(struct s2n_hash_state *state); @@ -87,7 +72,7 @@ struct s2n_hash { S2N_RESULT s2n_hash_algorithms_init(); S2N_RESULT s2n_hash_algorithms_cleanup(); -bool s2n_hash_evp_fully_supported(); +bool s2n_hash_use_custom_md5_sha1(); const EVP_MD *s2n_hash_alg_to_evp_md(s2n_hash_algorithm alg); int s2n_hash_digest_size(s2n_hash_algorithm alg, uint8_t *out); int s2n_hash_block_size(s2n_hash_algorithm alg, uint64_t *block_size); diff --git a/crypto/s2n_rsa_pss.c b/crypto/s2n_rsa_pss.c index 528656d234c..19c7a621612 100644 --- a/crypto/s2n_rsa_pss.c +++ b/crypto/s2n_rsa_pss.c @@ -106,8 +106,8 @@ static int s2n_rsa_pss_validate_sign_verify_match(const struct s2n_pkey *pub, co /* Sign and Verify the Hash of the Random Blob */ s2n_stack_blob(signature_data, RSA_PSS_SIGN_VERIFY_SIGNATURE_SIZE, RSA_PSS_SIGN_VERIFY_SIGNATURE_SIZE); - POSIX_GUARD(s2n_rsa_pss_key_sign(priv, S2N_SIGNATURE_RSA_PSS_PSS, &sign_hash, &signature_data)); - POSIX_GUARD(s2n_rsa_pss_key_verify(pub, S2N_SIGNATURE_RSA_PSS_PSS, &verify_hash, &signature_data)); + POSIX_GUARD(s2n_pkey_sign(priv, S2N_SIGNATURE_RSA_PSS_PSS, &sign_hash, &signature_data)); + POSIX_GUARD(s2n_pkey_verify(pub, S2N_SIGNATURE_RSA_PSS_PSS, &verify_hash, &signature_data)); return 0; } diff --git a/flake.lock b/flake.lock index e215dd49511..50db92bc51d 100644 --- a/flake.lock +++ b/flake.lock @@ -21,13 +21,209 @@ "type": "github" } }, + "awslcfips2022": { + "inputs": { + "flake-utils": "flake-utils_3", + "nix": "nix_2", + "nixpkgs": "nixpkgs_4" + }, + "locked": { + "lastModified": 1739234042, + "narHash": "sha256-d+ZytJ93CSKW6MiZZes6+Aa5N6XT/mAs/MaQvKjKM5s=", + "owner": "dougch", + "repo": "aws-lc", + "rev": "cf43ca76c26a67eac5ca26b0395a51e9159defa5", + "type": "github" + }, + "original": { + "owner": "dougch", + "ref": "nixAWS-LC-FIPS-2.0.17", + "repo": "aws-lc", + "type": "github" + } + }, + "awslcfips2024": { + "inputs": { + "flake-utils": "flake-utils_5", + "nix": "nix_3", + "nixpkgs": "nixpkgs_6" + }, + "locked": { + "lastModified": 1739214489, + "narHash": "sha256-OJ5Nk3H6GS2x/ZjlVRkELWE1dY2+MiUsMfAqvbiSiYY=", + "owner": "dougch", + "repo": "aws-lc", + "rev": "41bb647bdfbe268e349e09c14eee917424d46492", + "type": "github" + }, + "original": { + "owner": "dougch", + "ref": "nixfips-2024-09-27", + "repo": "aws-lc", + "type": "github" + } + }, + "flake-compat": { + "flake": false, + "locked": { + "lastModified": 1733328505, + "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_2": { + "flake": false, + "locked": { + "lastModified": 1733328505, + "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_3": { + "flake": false, + "locked": { + "lastModified": 1733328505, + "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_4": { + "flake": false, + "locked": { + "lastModified": 1733328505, + "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-parts": { + "inputs": { + "nixpkgs-lib": [ + "awslc", + "nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1733312601, + "narHash": "sha256-4pDvzqnegAfRkPwO3wmwBhVi/Sye1mzps0zHWYnP88c=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "205b12d8b7cd4802fbcb8e8ef6a0f1408781a4f9", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "flake-parts_2": { + "inputs": { + "nixpkgs-lib": [ + "awslcfips2022", + "nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1733312601, + "narHash": "sha256-4pDvzqnegAfRkPwO3wmwBhVi/Sye1mzps0zHWYnP88c=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "205b12d8b7cd4802fbcb8e8ef6a0f1408781a4f9", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "flake-parts_3": { + "inputs": { + "nixpkgs-lib": [ + "awslcfips2024", + "nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1733312601, + "narHash": "sha256-4pDvzqnegAfRkPwO3wmwBhVi/Sye1mzps0zHWYnP88c=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "205b12d8b7cd4802fbcb8e8ef6a0f1408781a4f9", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, + "flake-parts_4": { + "inputs": { + "nixpkgs-lib": [ + "nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1733312601, + "narHash": "sha256-4pDvzqnegAfRkPwO3wmwBhVi/Sye1mzps0zHWYnP88c=", + "owner": "hercules-ci", + "repo": "flake-parts", + "rev": "205b12d8b7cd4802fbcb8e8ef6a0f1408781a4f9", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "flake-parts", + "type": "github" + } + }, "flake-utils": { + "inputs": { + "systems": "systems" + }, "locked": { - "lastModified": 1667395993, - "narHash": "sha256-nuEHfE/LcWyuSWnS8t12N1wc105Qtau+/OdUAjtQ0rA=", + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", "owner": "numtide", "repo": "flake-utils", - "rev": "5aed5285a952e0b949eb3ba02c12fa4fcfef535f", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", "type": "github" }, "original": { @@ -36,12 +232,33 @@ } }, "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, + "locked": { + "lastModified": 1710146030, + "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_3": { + "inputs": { + "systems": "systems_3" + }, "locked": { - "lastModified": 1667395993, - "narHash": "sha256-nuEHfE/LcWyuSWnS8t12N1wc105Qtau+/OdUAjtQ0rA=", + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", "owner": "numtide", "repo": "flake-utils", - "rev": "5aed5285a952e0b949eb3ba02c12fa4fcfef535f", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", "type": "github" }, "original": { @@ -49,50 +266,228 @@ "type": "indirect" } }, - "lowdown-src": { - "flake": false, + "flake-utils_4": { + "inputs": { + "systems": "systems_4" + }, "locked": { - "lastModified": 1633514407, - "narHash": "sha256-Dw32tiMjdK9t3ETl5fzGrutQTzh2rufgZV4A/BbxuD4=", - "owner": "kristapsdz", - "repo": "lowdown", - "rev": "d2c2b44ff6c27b936ec27358a2653caaef8f73b8", + "lastModified": 1710146030, + "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", "type": "github" }, "original": { - "owner": "kristapsdz", - "repo": "lowdown", + "owner": "numtide", + "repo": "flake-utils", "type": "github" } }, - "lowdown-src_2": { - "flake": false, + "flake-utils_5": { + "inputs": { + "systems": "systems_5" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "id": "flake-utils", + "type": "indirect" + } + }, + "flake-utils_6": { + "inputs": { + "systems": "systems_6" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "id": "flake-utils", + "type": "indirect" + } + }, + "flake-utils_7": { + "inputs": { + "systems": "systems_7" + }, + "locked": { + "lastModified": 1710146030, + "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "git-hooks-nix": { + "inputs": { + "flake-compat": [ + "awslc", + "nix" + ], + "gitignore": [ + "awslc", + "nix" + ], + "nixpkgs": [ + "awslc", + "nix", + "nixpkgs" + ], + "nixpkgs-stable": [ + "awslc", + "nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1734279981, + "narHash": "sha256-NdaCraHPp8iYMWzdXAt5Nv6sA3MUzlCiGiR586TCwo0=", + "owner": "cachix", + "repo": "git-hooks.nix", + "rev": "aa9f40c906904ebd83da78e7f328cd8aeaeae785", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "git-hooks.nix", + "type": "github" + } + }, + "git-hooks-nix_2": { + "inputs": { + "flake-compat": [ + "awslcfips2022", + "nix" + ], + "gitignore": [ + "awslcfips2022", + "nix" + ], + "nixpkgs": [ + "awslcfips2022", + "nix", + "nixpkgs" + ], + "nixpkgs-stable": [ + "awslcfips2022", + "nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1734279981, + "narHash": "sha256-NdaCraHPp8iYMWzdXAt5Nv6sA3MUzlCiGiR586TCwo0=", + "owner": "cachix", + "repo": "git-hooks.nix", + "rev": "aa9f40c906904ebd83da78e7f328cd8aeaeae785", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "git-hooks.nix", + "type": "github" + } + }, + "git-hooks-nix_3": { + "inputs": { + "flake-compat": [ + "awslcfips2024", + "nix" + ], + "gitignore": [ + "awslcfips2024", + "nix" + ], + "nixpkgs": [ + "awslcfips2024", + "nix", + "nixpkgs" + ], + "nixpkgs-stable": [ + "awslcfips2024", + "nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1734279981, + "narHash": "sha256-NdaCraHPp8iYMWzdXAt5Nv6sA3MUzlCiGiR586TCwo0=", + "owner": "cachix", + "repo": "git-hooks.nix", + "rev": "aa9f40c906904ebd83da78e7f328cd8aeaeae785", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "git-hooks.nix", + "type": "github" + } + }, + "git-hooks-nix_4": { + "inputs": { + "flake-compat": [ + "nix" + ], + "gitignore": [ + "nix" + ], + "nixpkgs": [ + "nix", + "nixpkgs" + ], + "nixpkgs-stable": [ + "nix", + "nixpkgs" + ] + }, "locked": { - "lastModified": 1633514407, - "narHash": "sha256-Dw32tiMjdK9t3ETl5fzGrutQTzh2rufgZV4A/BbxuD4=", - "owner": "kristapsdz", - "repo": "lowdown", - "rev": "d2c2b44ff6c27b936ec27358a2653caaef8f73b8", + "lastModified": 1734279981, + "narHash": "sha256-NdaCraHPp8iYMWzdXAt5Nv6sA3MUzlCiGiR586TCwo0=", + "owner": "cachix", + "repo": "git-hooks.nix", + "rev": "aa9f40c906904ebd83da78e7f328cd8aeaeae785", "type": "github" }, "original": { - "owner": "kristapsdz", - "repo": "lowdown", + "owner": "cachix", + "repo": "git-hooks.nix", "type": "github" } }, "nix": { "inputs": { - "lowdown-src": "lowdown-src", + "flake-compat": "flake-compat", + "flake-parts": "flake-parts", + "git-hooks-nix": "git-hooks-nix", + "nixfmt": "nixfmt", "nixpkgs": "nixpkgs", + "nixpkgs-23-11": "nixpkgs-23-11", "nixpkgs-regression": "nixpkgs-regression" }, "locked": { - "lastModified": 1676058957, - "narHash": "sha256-qIyDaFtro2GqUejMG/0liegc6NqIhh5te+RlsU2mQ/I=", + "lastModified": 1739205346, + "narHash": "sha256-DnCGc1t8eRMnoRl/+OeVTfS/Tqxg511/3BxTEDLXCYc=", "owner": "NixOS", "repo": "nix", - "rev": "c18456604601dd233be4ad2462474488ef8f87e3", + "rev": "92bf150b1ce8ca15df3424a170ccb695adcbfe05", "type": "github" }, "original": { @@ -102,16 +497,42 @@ }, "nix_2": { "inputs": { - "lowdown-src": "lowdown-src_2", + "flake-compat": "flake-compat_2", + "flake-parts": "flake-parts_2", + "git-hooks-nix": "git-hooks-nix_2", + "nixfmt": "nixfmt_2", "nixpkgs": "nixpkgs_3", + "nixpkgs-23-11": "nixpkgs-23-11_2", "nixpkgs-regression": "nixpkgs-regression_2" }, "locked": { - "lastModified": 1674061467, - "narHash": "sha256-yvLbQusfeOizDwHFfTRtVwrUU15q2oaeDzImRGxoTs4=", + "lastModified": 1739205346, + "narHash": "sha256-DnCGc1t8eRMnoRl/+OeVTfS/Tqxg511/3BxTEDLXCYc=", + "owner": "NixOS", + "repo": "nix", + "rev": "92bf150b1ce8ca15df3424a170ccb695adcbfe05", + "type": "github" + }, + "original": { + "id": "nix", + "type": "indirect" + } + }, + "nix_3": { + "inputs": { + "flake-compat": "flake-compat_3", + "flake-parts": "flake-parts_3", + "git-hooks-nix": "git-hooks-nix_3", + "nixpkgs": "nixpkgs_5", + "nixpkgs-23-11": "nixpkgs-23-11_3", + "nixpkgs-regression": "nixpkgs-regression_3" + }, + "locked": { + "lastModified": 1736859128, + "narHash": "sha256-TbnLQ3Z2Voj0mMHhw30dJPEjQYmj6bfLMVGr8RU20v4=", "owner": "NixOS", "repo": "nix", - "rev": "2513eba46a20578f54fd3ac3cb0d25aeb0d0b310", + "rev": "8aafc0588594033fc6f1c3e2a36fe6f04559981f", "type": "github" }, "original": { @@ -119,19 +540,160 @@ "type": "indirect" } }, + "nix_4": { + "inputs": { + "flake-compat": "flake-compat_4", + "flake-parts": "flake-parts_4", + "git-hooks-nix": "git-hooks-nix_4", + "nixfmt": "nixfmt_3", + "nixpkgs": "nixpkgs_7", + "nixpkgs-23-11": "nixpkgs-23-11_4", + "nixpkgs-regression": "nixpkgs-regression_4" + }, + "locked": { + "lastModified": 1739205346, + "narHash": "sha256-DnCGc1t8eRMnoRl/+OeVTfS/Tqxg511/3BxTEDLXCYc=", + "owner": "NixOS", + "repo": "nix", + "rev": "92bf150b1ce8ca15df3424a170ccb695adcbfe05", + "type": "github" + }, + "original": { + "id": "nix", + "type": "indirect" + } + }, + "nixfmt": { + "inputs": { + "flake-utils": "flake-utils_2" + }, + "locked": { + "lastModified": 1736283758, + "narHash": "sha256-hrKhUp2V2fk/dvzTTHFqvtOg000G1e+jyIam+D4XqhA=", + "owner": "NixOS", + "repo": "nixfmt", + "rev": "8d4bd690c247004d90d8554f0b746b1231fe2436", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "nixfmt", + "type": "github" + } + }, + "nixfmt_2": { + "inputs": { + "flake-utils": "flake-utils_4" + }, + "locked": { + "lastModified": 1736283758, + "narHash": "sha256-hrKhUp2V2fk/dvzTTHFqvtOg000G1e+jyIam+D4XqhA=", + "owner": "NixOS", + "repo": "nixfmt", + "rev": "8d4bd690c247004d90d8554f0b746b1231fe2436", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "nixfmt", + "type": "github" + } + }, + "nixfmt_3": { + "inputs": { + "flake-utils": "flake-utils_7" + }, + "locked": { + "lastModified": 1736283758, + "narHash": "sha256-hrKhUp2V2fk/dvzTTHFqvtOg000G1e+jyIam+D4XqhA=", + "owner": "NixOS", + "repo": "nixfmt", + "rev": "8d4bd690c247004d90d8554f0b746b1231fe2436", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "nixfmt", + "type": "github" + } + }, "nixpkgs": { "locked": { - "lastModified": 1670461440, - "narHash": "sha256-jy1LB8HOMKGJEGXgzFRLDU1CBGL0/LlkolgnqIsF0D8=", + "lastModified": 1734359947, + "narHash": "sha256-1Noao/H+N8nFB4Beoy8fgwrcOQLVm9o4zKW1ODaqK9E=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "48d12d5e70ee91fe8481378e540433a7303dbf6a", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "release-24.11", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs-23-11": { + "locked": { + "lastModified": 1717159533, + "narHash": "sha256-oamiKNfr2MS6yH64rUn99mIZjc45nGJlj9eGth/3Xuw=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "a62e6edd6d5e1fa0329b8653c801147986f8d446", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "a62e6edd6d5e1fa0329b8653c801147986f8d446", + "type": "github" + } + }, + "nixpkgs-23-11_2": { + "locked": { + "lastModified": 1717159533, + "narHash": "sha256-oamiKNfr2MS6yH64rUn99mIZjc45nGJlj9eGth/3Xuw=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "a62e6edd6d5e1fa0329b8653c801147986f8d446", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "a62e6edd6d5e1fa0329b8653c801147986f8d446", + "type": "github" + } + }, + "nixpkgs-23-11_3": { + "locked": { + "lastModified": 1717159533, + "narHash": "sha256-oamiKNfr2MS6yH64rUn99mIZjc45nGJlj9eGth/3Xuw=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "a62e6edd6d5e1fa0329b8653c801147986f8d446", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "a62e6edd6d5e1fa0329b8653c801147986f8d446", + "type": "github" + } + }, + "nixpkgs-23-11_4": { + "locked": { + "lastModified": 1717159533, + "narHash": "sha256-oamiKNfr2MS6yH64rUn99mIZjc45nGJlj9eGth/3Xuw=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "04a75b2eecc0acf6239acf9dd04485ff8d14f425", + "rev": "a62e6edd6d5e1fa0329b8653c801147986f8d446", "type": "github" }, "original": { "owner": "NixOS", - "ref": "nixos-22.11-small", "repo": "nixpkgs", + "rev": "a62e6edd6d5e1fa0329b8653c801147986f8d446", "type": "github" } }, @@ -167,13 +729,45 @@ "type": "github" } }, + "nixpkgs-regression_3": { + "locked": { + "lastModified": 1643052045, + "narHash": "sha256-uGJ0VXIhWKGXxkeNnq4TvV3CIOkUJ3PAoLZ3HMzNVMw=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "215d4d0fd80ca5163643b03a33fde804a29cc1e2", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "215d4d0fd80ca5163643b03a33fde804a29cc1e2", + "type": "github" + } + }, + "nixpkgs-regression_4": { + "locked": { + "lastModified": 1643052045, + "narHash": "sha256-uGJ0VXIhWKGXxkeNnq4TvV3CIOkUJ3PAoLZ3HMzNVMw=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "215d4d0fd80ca5163643b03a33fde804a29cc1e2", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "215d4d0fd80ca5163643b03a33fde804a29cc1e2", + "type": "github" + } + }, "nixpkgs_2": { "locked": { - "lastModified": 1675918889, - "narHash": "sha256-hy7re4F9AEQqwZxubct7jBRos6md26bmxnCjxf5utJA=", + "lastModified": 1688392541, + "narHash": "sha256-lHrKvEkCPTUO+7tPfjIcb7Trk6k31rz18vkyqmkeJfY=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "49efda9011e8cdcd6c1aad30384cb1dc230c82fe", + "rev": "ea4c80b39be4c09702b0cb3b42eab59e2ba4f24b", "type": "github" }, "original": { @@ -185,27 +779,91 @@ }, "nixpkgs_3": { "locked": { - "lastModified": 1670461440, - "narHash": "sha256-jy1LB8HOMKGJEGXgzFRLDU1CBGL0/LlkolgnqIsF0D8=", + "lastModified": 1734359947, + "narHash": "sha256-1Noao/H+N8nFB4Beoy8fgwrcOQLVm9o4zKW1ODaqK9E=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "04a75b2eecc0acf6239acf9dd04485ff8d14f425", + "rev": "48d12d5e70ee91fe8481378e540433a7303dbf6a", "type": "github" }, "original": { "owner": "NixOS", - "ref": "nixos-22.11-small", + "ref": "release-24.11", "repo": "nixpkgs", "type": "github" } }, "nixpkgs_4": { "locked": { - "lastModified": 1674781052, - "narHash": "sha256-nseKFXRvmZ+BDAeWQtsiad+5MnvI/M2Ak9iAWzooWBw=", + "lastModified": 1688392541, + "narHash": "sha256-lHrKvEkCPTUO+7tPfjIcb7Trk6k31rz18vkyqmkeJfY=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "ea4c80b39be4c09702b0cb3b42eab59e2ba4f24b", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-22.11", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_5": { + "locked": { + "lastModified": 1734359947, + "narHash": "sha256-1Noao/H+N8nFB4Beoy8fgwrcOQLVm9o4zKW1ODaqK9E=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "48d12d5e70ee91fe8481378e540433a7303dbf6a", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "release-24.11", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_6": { + "locked": { + "lastModified": 1688392541, + "narHash": "sha256-lHrKvEkCPTUO+7tPfjIcb7Trk6k31rz18vkyqmkeJfY=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "ea4c80b39be4c09702b0cb3b42eab59e2ba4f24b", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-22.11", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_7": { + "locked": { + "lastModified": 1734359947, + "narHash": "sha256-1Noao/H+N8nFB4Beoy8fgwrcOQLVm9o4zKW1ODaqK9E=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "48d12d5e70ee91fe8481378e540433a7303dbf6a", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "release-24.11", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_8": { + "locked": { + "lastModified": 1688392541, + "narHash": "sha256-lHrKvEkCPTUO+7tPfjIcb7Trk6k31rz18vkyqmkeJfY=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "cc4bb87f5457ba06af9ae57ee4328a49ce674b1b", + "rev": "ea4c80b39be4c09702b0cb3b42eab59e2ba4f24b", "type": "github" }, "original": { @@ -218,9 +876,116 @@ "root": { "inputs": { "awslc": "awslc", - "flake-utils": "flake-utils_2", - "nix": "nix_2", - "nixpkgs": "nixpkgs_4" + "awslcfips2022": "awslcfips2022", + "awslcfips2024": "awslcfips2024", + "flake-utils": "flake-utils_6", + "nix": "nix_4", + "nixpkgs": "nixpkgs_8" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_3": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_4": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_5": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_6": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_7": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" } } }, diff --git a/flake.nix b/flake.nix index b38d08909fb..e9cf2328922 100644 --- a/flake.nix +++ b/flake.nix @@ -1,15 +1,22 @@ { description = "A flake for s2n-tls"; - inputs.nixpkgs.url = "github:NixOS/nixpkgs/nixos-22.11"; - # TODO: https://github.com/aws/aws-lc/pull/830 - inputs.awslc.url = "github:dougch/aws-lc?ref=nixv1.36.0"; - - outputs = { self, nix, nixpkgs, awslc, flake-utils }: + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-22.11"; + awslc.url = "github:dougch/aws-lc?ref=nixv1.36.0"; + awslcfips2022.url = "github:dougch/aws-lc?ref=nixAWS-LC-FIPS-2.0.17"; + awslcfips2024.url = "github:dougch/aws-lc?ref=nixfips-2024-09-27"; + }; + + outputs = + { self, nix, nixpkgs, awslc, awslcfips2022, awslcfips2024, flake-utils }: flake-utils.lib.eachDefaultSystem (system: let pkgs = nixpkgs.legacyPackages.${system}; + # Internal variable = input.awslc ... aws-lc = awslc.packages.${system}.aws-lc; + aws-lc-fips-2022 = awslcfips2022.packages.${system}.aws-lc-fips; + aws-lc-fips-2024 = awslcfips2024.packages.${system}.aws-lc-fips-2024; # TODO: submit a flake PR corretto = import nix/amazon-corretto-17.nix { pkgs = pkgs; }; # TODO: We have parts of our CI that rely on clang-format-15, but that is only available on github:nixos/nixpkgs/nixos-unstable @@ -101,6 +108,8 @@ OPENSSL_1_1_1_INSTALL_DIR = "${openssl_1_1_1}"; OPENSSL_3_0_INSTALL_DIR = "${openssl_3_0}"; AWSLC_INSTALL_DIR = "${aws-lc}"; + AWSLC_FIPS_2022_INSTALL_DIR = "${aws-lc-fips-2022}"; + AWSLC_FIPS_2024_INSTALL_DIR = "${aws-lc-fips-2024}"; GNUTLS_INSTALL_DIR = "${pkgs.gnutls}"; LIBRESSL_INSTALL_DIR = "${libressl}"; # Integ s_client/server tests expect openssl 1.1.1. @@ -171,8 +180,30 @@ source ${writeScript ./nix/shell.sh} ''; }); - - # Used to backup the devShell to s3 for caching. + devShells.awslcfips2022 = devShells.default.overrideAttrs + (finalAttrs: previousAttrs: { + # Re-include cmake to update the environment with a new libcrypto. + buildInputs = [ pkgs.cmake aws-lc-fips-2022 ]; + S2N_LIBCRYPTO = "awslc-fips-2022"; + shellHook = '' + echo Setting up $S2N_LIBCRYPTO environment from flake.nix... + export PATH=${openssl_1_1_1}/bin:$PATH + export PS1="[nix $S2N_LIBCRYPTO] $PS1" + source ${writeScript ./nix/shell.sh} + ''; + }); # Used to backup the devShell to s3 for caching. + devShells.awslcfips2024 = devShells.default.overrideAttrs + (finalAttrs: previousAttrs: { + # Re-include cmake to update the environment with a new libcrypto. + buildInputs = [ pkgs.cmake aws-lc-fips-2024 ]; + S2N_LIBCRYPTO = "awslc-fips-2024"; + shellHook = '' + echo Setting up $S2N_LIBCRYPTO environment from flake.nix... + export PATH=${openssl_1_1_1}/bin:$PATH + export PS1="[nix $S2N_LIBCRYPTO] $PS1" + source ${writeScript ./nix/shell.sh} + ''; + }); # Used to backup the devShell to s3 for caching. packages.devShell = devShells.default.inputDerivation; packages.default = packages.s2n-tls; packages.s2n-tls-openssl3 = packages.s2n-tls.overrideAttrs diff --git a/nix/shell.sh b/nix/shell.sh index 880422f1a6b..de1d03c1256 100644 --- a/nix/shell.sh +++ b/nix/shell.sh @@ -22,7 +22,9 @@ function libcrypto_alias { libcrypto_alias openssl102 "${OPENSSL_1_0_2_INSTALL_DIR}/bin/openssl" libcrypto_alias openssl111 "${OPENSSL_1_1_1_INSTALL_DIR}/bin/openssl" libcrypto_alias openssl30 "${OPENSSL_3_0_INSTALL_DIR}/bin/openssl" -libcrypto_alias bssl "${AWSLC_INSTALL_DIR}/bin/bssl" +libcrypto_alias awslc "${AWSLC_INSTALL_DIR}/bin/bssl" +libcrypto_alias awslcfips2022 "${AWSLC_FIPS_2022_INSTALL_DIR}/bin/bssl" +libcrypto_alias awslcfips2024 "${AWSLC_FIPS_2024_INSTALL_DIR}/bin/bssl" libcrypto_alias libressl "${LIBRESSL_INSTALL_DIR}/bin/openssl" #No need to alias gnutls because it is included in common_packages (see flake.nix). diff --git a/s2n.mk b/s2n.mk index efc42f49920..b48a876eff8 100644 --- a/s2n.mk +++ b/s2n.mk @@ -134,7 +134,7 @@ bindir ?= $(exec_prefix)/bin libdir ?= $(exec_prefix)/lib64 includedir ?= $(exec_prefix)/include -feature_probe = $(shell $(CC) $(CFLAGS) $(shell cat $(S2N_ROOT)/tests/features/GLOBAL.flags) $(shell cat $(S2N_ROOT)/tests/features/$(1).flags) -c -o tmp.o $(S2N_ROOT)/tests/features/$(1).c > /dev/null 2>&1 && echo "-D$(1)"; rm tmp.o > /dev/null 2>&1) +feature_probe = $(shell $(CC) $(CFLAGS) $(shell cat $(S2N_ROOT)/tests/features/GLOBAL.flags) $(shell cat $(S2N_ROOT)/tests/features/$(1).flags) -c -o tmp.o $(S2N_ROOT)/tests/features/$(1).c > /dev/null 2>&1 && echo "-D$(1)=1"; rm tmp.o > /dev/null 2>&1) FEATURES := $(notdir $(patsubst %.c,%,$(wildcard $(S2N_ROOT)/tests/features/*.c))) SUPPORTED_FEATURES := $(foreach feature,$(FEATURES),$(call feature_probe,$(feature))) diff --git a/tests/cbmc/proofs/s2n_hash_copy/Makefile b/tests/cbmc/proofs/s2n_hash_copy/Makefile index 02e2c6c99be..6a69e3a2da2 100644 --- a/tests/cbmc/proofs/s2n_hash_copy/Makefile +++ b/tests/cbmc/proofs/s2n_hash_copy/Makefile @@ -36,10 +36,6 @@ REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_free REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_init REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_new REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_reset -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_free -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_init -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_new -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_reset UNWINDSET += diff --git a/tests/cbmc/proofs/s2n_hash_digest/Makefile b/tests/cbmc/proofs/s2n_hash_digest/Makefile index f4e8cd51fe0..88ed030eb46 100644 --- a/tests/cbmc/proofs/s2n_hash_digest/Makefile +++ b/tests/cbmc/proofs/s2n_hash_digest/Makefile @@ -34,7 +34,6 @@ PROJECT_SOURCES += $(SRCDIR)/utils/s2n_safety.c # We abstract these functions because manual inspection demonstrates they are unreachable. REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_update -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_update UNWINDSET += diff --git a/tests/cbmc/proofs/s2n_hash_free/Makefile b/tests/cbmc/proofs/s2n_hash_free/Makefile index 676317f1dc2..031e04007a4 100644 --- a/tests/cbmc/proofs/s2n_hash_free/Makefile +++ b/tests/cbmc/proofs/s2n_hash_free/Makefile @@ -33,9 +33,6 @@ PROJECT_SOURCES += $(SRCDIR)/crypto/s2n_hash.c REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_init REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_new REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_reset -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_init -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_new -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_reset UNWINDSET += diff --git a/tests/cbmc/proofs/s2n_hash_free/s2n_hash_free_harness.c b/tests/cbmc/proofs/s2n_hash_free/s2n_hash_free_harness.c index 9c646a0b64f..98701e78f36 100644 --- a/tests/cbmc/proofs/s2n_hash_free/s2n_hash_free_harness.c +++ b/tests/cbmc/proofs/s2n_hash_free/s2n_hash_free_harness.c @@ -35,24 +35,17 @@ void s2n_hash_free_harness() assert(s2n_hash_free(state) == S2N_SUCCESS); if (state != NULL) { assert(state->hash_impl->free != NULL); - if (s2n_is_in_fips_mode()) { - assert(state->digest.high_level.evp.ctx == NULL); - assert(state->digest.high_level.evp_md5_secondary.ctx == NULL); - assert_rc_decrement_on_hash_state(&saved_hash_state); - } else { - assert_rc_unchanged_on_hash_state(&saved_hash_state); - } + assert(state->digest.high_level.evp.ctx == NULL); + assert(state->digest.high_level.evp_md5_secondary.ctx == NULL); + assert_rc_decrement_on_hash_state(&saved_hash_state); assert(state->is_ready_for_input == 0); } /* Cleanup after expected error cases, for memory leak check. */ if (state != NULL) { - /* 1. `free` leftover EVP_MD_CTX objects if `s2n_is_in_fips_mode`, - since `s2n_hash_free` is a NO-OP in that case. */ - if (!s2n_is_in_fips_mode()) { - S2N_EVP_MD_CTX_FREE(state->digest.high_level.evp.ctx); - S2N_EVP_MD_CTX_FREE(state->digest.high_level.evp_md5_secondary.ctx); - } + /* 1. `free` leftover EVP_MD_CTX objects */ + S2N_EVP_MD_CTX_FREE(state->digest.high_level.evp.ctx); + S2N_EVP_MD_CTX_FREE(state->digest.high_level.evp_md5_secondary.ctx); /* 2. `free` leftover reference-counted keys (i.e. those with non-zero ref-count), since they are not automatically `free`d until their ref count reaches 0. */ diff --git a/tests/cbmc/proofs/s2n_hash_init/Makefile b/tests/cbmc/proofs/s2n_hash_init/Makefile index 5d905e7c2ef..b41e55b3762 100644 --- a/tests/cbmc/proofs/s2n_hash_init/Makefile +++ b/tests/cbmc/proofs/s2n_hash_init/Makefile @@ -30,10 +30,6 @@ PROOF_SOURCES += $(PROOFDIR)/$(HARNESS_FILE) PROJECT_SOURCES += $(SRCDIR)/crypto/s2n_hash.c # We abstract these functions because manual inspection demonstrates they are unreachable. -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_copy -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_new -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_reset -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_free REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_copy REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_free REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_new diff --git a/tests/cbmc/proofs/s2n_hash_new/Makefile b/tests/cbmc/proofs/s2n_hash_new/Makefile index da1f324b484..fdf3ec2d3bb 100644 --- a/tests/cbmc/proofs/s2n_hash_new/Makefile +++ b/tests/cbmc/proofs/s2n_hash_new/Makefile @@ -27,9 +27,6 @@ PROOF_SOURCES += $(PROOF_STUB)/s2n_is_in_fips_mode.c PROJECT_SOURCES += $(SRCDIR)/crypto/s2n_hash.c # We abstract these functions because manual inspection demonstrates they are unreachable. -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_init -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_reset -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_free REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_init REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_reset REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_free diff --git a/tests/cbmc/proofs/s2n_hash_reset/Makefile b/tests/cbmc/proofs/s2n_hash_reset/Makefile index 55e08b1269c..f41e099d5d9 100644 --- a/tests/cbmc/proofs/s2n_hash_reset/Makefile +++ b/tests/cbmc/proofs/s2n_hash_reset/Makefile @@ -35,8 +35,6 @@ PROJECT_SOURCES += $(SRCDIR)/utils/s2n_safety.c # We abstract these functions because manual inspection demonstrates they are unreachable. REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_free REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_new -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_free -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_new UNWINDSET += diff --git a/tests/cbmc/proofs/s2n_hash_update/Makefile b/tests/cbmc/proofs/s2n_hash_update/Makefile index f9239e1946e..fe7e7ce1e43 100644 --- a/tests/cbmc/proofs/s2n_hash_update/Makefile +++ b/tests/cbmc/proofs/s2n_hash_update/Makefile @@ -34,7 +34,6 @@ PROJECT_SOURCES += $(SRCDIR)/utils/s2n_safety.c # We abstract these functions because manual inspection demonstrates they are unreachable. REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_digest -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_digest UNWINDSET += diff --git a/tests/cbmc/proofs/s2n_hmac_copy/Makefile b/tests/cbmc/proofs/s2n_hmac_copy/Makefile index 87231753868..57fe0f23084 100644 --- a/tests/cbmc/proofs/s2n_hmac_copy/Makefile +++ b/tests/cbmc/proofs/s2n_hmac_copy/Makefile @@ -37,10 +37,6 @@ REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_free REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_init REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_new REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_reset -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_free -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_init -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_new -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_reset UNWINDSET += diff --git a/tests/cbmc/proofs/s2n_hmac_free/Makefile b/tests/cbmc/proofs/s2n_hmac_free/Makefile index fd14d427e94..cacca0a4705 100644 --- a/tests/cbmc/proofs/s2n_hmac_free/Makefile +++ b/tests/cbmc/proofs/s2n_hmac_free/Makefile @@ -34,9 +34,6 @@ PROJECT_SOURCES += $(SRCDIR)/crypto/s2n_hmac.c REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_init REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_new REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_reset -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_init -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_new -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_reset UNWINDSET += diff --git a/tests/cbmc/proofs/s2n_hmac_free/s2n_hmac_free_harness.c b/tests/cbmc/proofs/s2n_hmac_free/s2n_hmac_free_harness.c index aeff29282ad..1e63d5a82d7 100644 --- a/tests/cbmc/proofs/s2n_hmac_free/s2n_hmac_free_harness.c +++ b/tests/cbmc/proofs/s2n_hmac_free/s2n_hmac_free_harness.c @@ -44,28 +44,21 @@ void s2n_hmac_free_harness() assert(state->outer.hash_impl->free != NULL); assert(state->outer_just_key.hash_impl->free != NULL); - if (s2n_is_in_fips_mode()) { - assert(state->inner.digest.high_level.evp.ctx == NULL); - assert(state->inner.digest.high_level.evp_md5_secondary.ctx == NULL); - assert_rc_decrement_on_hash_state(&saved_inner_hash_state); + assert(state->inner.digest.high_level.evp.ctx == NULL); + assert(state->inner.digest.high_level.evp_md5_secondary.ctx == NULL); + assert_rc_decrement_on_hash_state(&saved_inner_hash_state); - assert(state->inner_just_key.digest.high_level.evp.ctx == NULL); - assert(state->inner_just_key.digest.high_level.evp_md5_secondary.ctx == NULL); - assert_rc_decrement_on_hash_state(&saved_inner_just_key_hash_state); + assert(state->inner_just_key.digest.high_level.evp.ctx == NULL); + assert(state->inner_just_key.digest.high_level.evp_md5_secondary.ctx == NULL); + assert_rc_decrement_on_hash_state(&saved_inner_just_key_hash_state); - assert(state->outer.digest.high_level.evp.ctx == NULL); - assert(state->outer.digest.high_level.evp_md5_secondary.ctx == NULL); - assert_rc_decrement_on_hash_state(&saved_outer_hash_state); + assert(state->outer.digest.high_level.evp.ctx == NULL); + assert(state->outer.digest.high_level.evp_md5_secondary.ctx == NULL); + assert_rc_decrement_on_hash_state(&saved_outer_hash_state); - assert(state->outer_just_key.digest.high_level.evp.ctx == NULL); - assert(state->outer_just_key.digest.high_level.evp_md5_secondary.ctx == NULL); - assert_rc_decrement_on_hash_state(&saved_outer_just_key_hash_state); - } else { - assert_rc_unchanged_on_hash_state(&saved_inner_hash_state); - assert_rc_unchanged_on_hash_state(&saved_outer_just_key_hash_state); - assert_rc_unchanged_on_hash_state(&saved_outer_hash_state); - assert_rc_unchanged_on_hash_state(&saved_outer_just_key_hash_state); - } + assert(state->outer_just_key.digest.high_level.evp.ctx == NULL); + assert(state->outer_just_key.digest.high_level.evp_md5_secondary.ctx == NULL); + assert_rc_decrement_on_hash_state(&saved_outer_just_key_hash_state); assert(state->inner.is_ready_for_input == 0); assert(state->inner_just_key.is_ready_for_input == 0); @@ -75,18 +68,15 @@ void s2n_hmac_free_harness() /* Cleanup after expected error cases, for memory leak check. */ if (state != NULL) { - /* 1. `free` leftover EVP_MD_CTX objects if `s2n_is_in_fips_mode`, - since `s2n_hash_free` is a NO-OP in that case. */ - if (!s2n_is_in_fips_mode()) { - S2N_EVP_MD_CTX_FREE(state->inner.digest.high_level.evp.ctx); - S2N_EVP_MD_CTX_FREE(state->inner.digest.high_level.evp_md5_secondary.ctx); - S2N_EVP_MD_CTX_FREE(state->inner_just_key.digest.high_level.evp.ctx); - S2N_EVP_MD_CTX_FREE(state->inner_just_key.digest.high_level.evp_md5_secondary.ctx); - S2N_EVP_MD_CTX_FREE(state->outer.digest.high_level.evp.ctx); - S2N_EVP_MD_CTX_FREE(state->outer.digest.high_level.evp_md5_secondary.ctx); - S2N_EVP_MD_CTX_FREE(state->outer_just_key.digest.high_level.evp.ctx); - S2N_EVP_MD_CTX_FREE(state->outer_just_key.digest.high_level.evp_md5_secondary.ctx); - } + /* 1. `free` leftover EVP_MD_CTX objects */ + S2N_EVP_MD_CTX_FREE(state->inner.digest.high_level.evp.ctx); + S2N_EVP_MD_CTX_FREE(state->inner.digest.high_level.evp_md5_secondary.ctx); + S2N_EVP_MD_CTX_FREE(state->inner_just_key.digest.high_level.evp.ctx); + S2N_EVP_MD_CTX_FREE(state->inner_just_key.digest.high_level.evp_md5_secondary.ctx); + S2N_EVP_MD_CTX_FREE(state->outer.digest.high_level.evp.ctx); + S2N_EVP_MD_CTX_FREE(state->outer.digest.high_level.evp_md5_secondary.ctx); + S2N_EVP_MD_CTX_FREE(state->outer_just_key.digest.high_level.evp.ctx); + S2N_EVP_MD_CTX_FREE(state->outer_just_key.digest.high_level.evp_md5_secondary.ctx); /* 2. `free` leftover reference-counted keys (i.e. those with non-zero ref-count), * since they are not automatically `free`d until their ref count reaches 0. diff --git a/tests/cbmc/proofs/s2n_hmac_init/Makefile b/tests/cbmc/proofs/s2n_hmac_init/Makefile index 4931abbccda..f41c446b013 100644 --- a/tests/cbmc/proofs/s2n_hmac_init/Makefile +++ b/tests/cbmc/proofs/s2n_hmac_init/Makefile @@ -34,9 +34,6 @@ PROJECT_SOURCES += $(SRCDIR)/crypto/s2n_hmac.c PROJECT_SOURCES += $(SRCDIR)/utils/s2n_ensure.c # We abstract these functions because manual inspection demonstrates they are unreachable. -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_new -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_reset -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_free REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_free REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_new REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_reset diff --git a/tests/cbmc/proofs/s2n_hmac_new/Makefile b/tests/cbmc/proofs/s2n_hmac_new/Makefile index b553599d793..13f95e4edb9 100644 --- a/tests/cbmc/proofs/s2n_hmac_new/Makefile +++ b/tests/cbmc/proofs/s2n_hmac_new/Makefile @@ -28,9 +28,6 @@ PROJECT_SOURCES += $(SRCDIR)/crypto/s2n_hash.c PROJECT_SOURCES += $(SRCDIR)/crypto/s2n_hmac.c # We abstract these functions because manual inspection demonstrates they are unreachable. -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_init -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_reset -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_free REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_init REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_reset REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_free diff --git a/tests/cbmc/proofs/s2n_hmac_reset/Makefile b/tests/cbmc/proofs/s2n_hmac_reset/Makefile index c74c27914fd..6b4368f9db5 100644 --- a/tests/cbmc/proofs/s2n_hmac_reset/Makefile +++ b/tests/cbmc/proofs/s2n_hmac_reset/Makefile @@ -37,10 +37,6 @@ REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_free REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_init REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_new REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_reset -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_free -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_init -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_new -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_reset UNWINDSET += diff --git a/tests/cbmc/proofs/s2n_hmac_update/Makefile b/tests/cbmc/proofs/s2n_hmac_update/Makefile index 0098edb5975..b8fe9882ce7 100644 --- a/tests/cbmc/proofs/s2n_hmac_update/Makefile +++ b/tests/cbmc/proofs/s2n_hmac_update/Makefile @@ -35,7 +35,6 @@ PROJECT_SOURCES += $(SRCDIR)/utils/s2n_safety.c # We abstract these functions because manual inspection demonstrates they are unreachable. REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_evp_hash_digest -REMOVE_FUNCTION_BODY += __CPROVER_file_local_s2n_hash_c_s2n_low_level_hash_digest UNWINDSET += diff --git a/tests/cbmc/sources/make_common_datastructures.c b/tests/cbmc/sources/make_common_datastructures.c index 59c7590f10d..1def13949a9 100644 --- a/tests/cbmc/sources/make_common_datastructures.c +++ b/tests/cbmc/sources/make_common_datastructures.c @@ -396,8 +396,7 @@ struct s2n_config *cbmc_allocate_s2n_config() s2n_config->monotonic_clock_ctx = malloc(sizeof(*(s2n_config->monotonic_clock_ctx))); s2n_config->client_hello_cb = malloc(sizeof(*(s2n_config->client_hello_cb))); /* Function pointer. */ s2n_config->client_hello_cb_ctx = malloc(sizeof(*(s2n_config->client_hello_cb_ctx))); - s2n_config->ticket_keys = cbmc_allocate_s2n_set(); - s2n_config->ticket_key_hashes = cbmc_allocate_s2n_set(); + s2n_config->ticket_keys = cbmc_allocate_s2n_array(); s2n_config->cache_store_data = malloc(sizeof(*(s2n_config->cache_store_data))); s2n_config->cache_retrieve_data = malloc(sizeof(*(s2n_config->cache_retrieve_data))); s2n_config->cache_delete_data = malloc(sizeof(*(s2n_config->cache_delete_data))); diff --git a/tests/features/S2N_LIBCRYPTO_SUPPORTS_PROVIDERS.c b/tests/features/S2N_LIBCRYPTO_SUPPORTS_PROVIDERS.c index d115dad587f..0c73090282f 100644 --- a/tests/features/S2N_LIBCRYPTO_SUPPORTS_PROVIDERS.c +++ b/tests/features/S2N_LIBCRYPTO_SUPPORTS_PROVIDERS.c @@ -28,5 +28,8 @@ int main() EVP_MD *md = EVP_MD_fetch(NULL, NULL, NULL); EVP_MD_free(md); + /* Supports property queries for pkey context implicit fetching */ + EVP_PKEY_CTX *pkey_ctx = EVP_PKEY_CTX_new_from_pkey(NULL, NULL, NULL); + return 0; } diff --git a/tests/integrationv2/common.py b/tests/integrationv2/common.py index 2a210de658b..1f5ce5a5431 100644 --- a/tests/integrationv2/common.py +++ b/tests/integrationv2/common.py @@ -33,7 +33,9 @@ def data_bytes(n_bytes): def random_str(n): - return "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(n)) + return "".join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(n) + ) def pq_enabled(): @@ -54,14 +56,14 @@ def __init__(self, low=8000, high=30000): worker_count = 1 # If pytest is being run in parallel, worker processes will have # the WORKER_COUNT variable set. - parallel_workers = os.getenv('PYTEST_XDIST_WORKER_COUNT') + parallel_workers = os.getenv("PYTEST_XDIST_WORKER_COUNT") if parallel_workers is not None: worker_count = int(parallel_workers) chunk_size = int((high - low) / worker_count) # If xdist is being used, parse the workerid from the envvar. This can # be used to allocate unique ports to each worker. - worker = os.getenv('PYTEST_XDIST_WORKER') + worker = os.getenv("PYTEST_XDIST_WORKER") worker_id = 0 if worker is not None: worker_id = re.findall(r"gw(\d+)", worker) @@ -70,7 +72,7 @@ def __init__(self, low=8000, high=30000): # This is a naive way to allocate ports, but it allows us to cut # the run time in half without workers colliding. - worker_offset = (worker_id * chunk_size) + worker_offset = worker_id * chunk_size base_range = range(low + worker_offset, high) wrap_range = range(low, low + worker_offset) self.ports = iter(itertools.chain(base_range, wrap_range)) @@ -104,42 +106,44 @@ def __init__(self, name, prefix, location=TEST_CERT_DIRECTORY): self.name = name self.cert = location + prefix + "_cert.pem" self.key = location + prefix + "_key.pem" - self.algorithm = 'ANY' + self.algorithm = "ANY" self.curve = None - if 'ECDSA' in name: - self.algorithm = 'EC' + if "ECDSA" in name: + self.algorithm = "EC" self.curve = name[-3:] - elif 'RSA' in name: - self.algorithm = 'RSA' - if 'PSS' in name: - self.algorithm = 'RSAPSS' + elif "RSA" in name: + self.algorithm = "RSA" + if "PSS" in name: + self.algorithm = "RSAPSS" def compatible_with_cipher(self, cipher): if self.algorithm == cipher.algorithm: return True # TLS1.3 cipher suites do not specify auth method, so allow any auth method - if cipher.algorithm == 'ANY': + if cipher.algorithm == "ANY": return True - if self.algorithm == 'RSAPSS': + if self.algorithm == "RSAPSS": # RSA-PSS certs can only be used by ciphers with RSA auth - if cipher.algorithm != 'RSA': + if cipher.algorithm != "RSA": return False # RSA-PSS certs do not support RSA key exchange, only RSA auth # "DHE" here is intended to capture both "DHE" and "ECDHE" - if 'DHE' in cipher.name: + if "DHE" in cipher.name: return True return False def compatible_with_curve(self, curve): - if self.algorithm != 'EC': + if self.algorithm != "EC": return True return curve.name[-3:] == self.curve def compatible_with_sigalg(self, sigalg): if self.algorithm != sigalg.algorithm: return False - sig_alg_has_curve = sigalg.algorithm == 'EC' and sigalg.min_protocol == Protocols.TLS13 + sig_alg_has_curve = ( + sigalg.algorithm == "EC" and sigalg.min_protocol == Protocols.TLS13 + ) if sig_alg_has_curve and self.curve not in sigalg.name: return False return True @@ -152,6 +156,7 @@ class Certificates(object): """ When referencing certificates, use these values. """ + RSA_1024_SHA256 = Cert("RSA_1024_SHA256", "rsa_1024_sha256_client") RSA_1024_SHA384 = Cert("RSA_1024_SHA384", "rsa_1024_sha384_client") RSA_1024_SHA512 = Cert("RSA_1024_SHA512", "rsa_1024_sha512_client") @@ -170,9 +175,9 @@ class Certificates(object): ECDSA_521 = Cert("ECDSA_521", "ecdsa_p521") RSA_2048_SHA256_WILDCARD = Cert( - "RSA_2048_SHA256_WILDCARD", "rsa_2048_sha256_wildcard") - RSA_PSS_2048_SHA256 = Cert( - "RSA_PSS_2048_SHA256", "localhost_rsa_pss_2048_sha256") + "RSA_2048_SHA256_WILDCARD", "rsa_2048_sha256_wildcard" + ) + RSA_PSS_2048_SHA256 = Cert("RSA_PSS_2048_SHA256", "localhost_rsa_pss_2048_sha256") RSA_2048_PKCS1 = Cert("RSA_2048_PKCS1", "rsa_2048_pkcs1") @@ -212,6 +217,7 @@ class Protocols(object): protocols. Since this is hardcoded in S2N, it is not expected to change. """ + TLS13 = Protocol("TLS1.3", 34) TLS12 = Protocol("TLS1.2", 33) TLS11 = Protocol("TLS1.1", 32) @@ -221,7 +227,17 @@ class Protocols(object): class Cipher(object): - def __init__(self, name, min_version, openssl1_1_1, fips, parameters=None, iana_standard_name=None, s2n=False, pq=False): + def __init__( + self, + name, + min_version, + openssl1_1_1, + fips, + parameters=None, + iana_standard_name=None, + s2n=False, + pq=False, + ): self.name = name self.min_version = min_version self.openssl1_1_1 = openssl1_1_1 @@ -232,13 +248,13 @@ def __init__(self, name, min_version, openssl1_1_1, fips, parameters=None, iana_ self.pq = pq if self.min_version >= Protocols.TLS13: - self.algorithm = 'ANY' + self.algorithm = "ANY" elif iana_standard_name is None: - self.algorithm = 'ANY' - elif 'ECDSA' in iana_standard_name: - self.algorithm = 'EC' - elif 'RSA' in iana_standard_name: - self.algorithm = 'RSA' + self.algorithm = "ANY" + elif "ECDSA" in iana_standard_name: + self.algorithm = "EC" + elif "RSA" in iana_standard_name: + self.algorithm = "RSA" else: pytest.fail("Unknown signature algorithm on cipher") @@ -256,94 +272,259 @@ class Ciphers(object): """ When referencing ciphers, use these class values. """ - DHE_RSA_DES_CBC3_SHA = Cipher("DHE-RSA-DES-CBC3-SHA", Protocols.SSLv3, - False, False, iana_standard_name="SSL_DHE_RSA_WITH_3DES_EDE_CBC_SHA") - DHE_RSA_AES128_SHA = Cipher("DHE-RSA-AES128-SHA", Protocols.SSLv3, True, False, TEST_CERT_DIRECTORY + - 'dhparams_2048.pem', iana_standard_name="TLS_DHE_RSA_WITH_AES_128_CBC_SHA") - DHE_RSA_AES256_SHA = Cipher("DHE-RSA-AES256-SHA", Protocols.SSLv3, True, False, TEST_CERT_DIRECTORY + - 'dhparams_2048.pem', iana_standard_name="TLS_DHE_RSA_WITH_AES_256_CBC_SHA") - DHE_RSA_AES128_SHA256 = Cipher("DHE-RSA-AES128-SHA256", Protocols.TLS12, True, True, TEST_CERT_DIRECTORY + - 'dhparams_2048.pem', iana_standard_name="TLS_DHE_RSA_WITH_AES_128_CBC_SHA256") - DHE_RSA_AES256_SHA256 = Cipher("DHE-RSA-AES256-SHA256", Protocols.TLS12, True, True, TEST_CERT_DIRECTORY + - 'dhparams_2048.pem', iana_standard_name="TLS_DHE_RSA_WITH_AES_256_CBC_SHA256") - DHE_RSA_AES128_GCM_SHA256 = Cipher("DHE-RSA-AES128-GCM-SHA256", Protocols.TLS12, True, True, - TEST_CERT_DIRECTORY + 'dhparams_2048.pem', iana_standard_name="TLS_DHE_RSA_WITH_AES_128_GCM_SHA256") - DHE_RSA_AES256_GCM_SHA384 = Cipher("DHE-RSA-AES256-GCM-SHA384", Protocols.TLS12, True, True, - TEST_CERT_DIRECTORY + 'dhparams_2048.pem', iana_standard_name="TLS_DHE_RSA_WITH_AES_256_GCM_SHA384") - DHE_RSA_CHACHA20_POLY1305 = Cipher("DHE-RSA-CHACHA20-POLY1305", Protocols.TLS12, True, False, - TEST_CERT_DIRECTORY + 'dhparams_2048.pem', iana_standard_name="TLS_DHE_RSA_WITH_AES_256_GCM_SHA384") - - AES128_SHA = Cipher("AES128-SHA", Protocols.SSLv3, True, - True, iana_standard_name="TLS_RSA_WITH_AES_128_CBC_SHA") - AES256_SHA = Cipher("AES256-SHA", Protocols.SSLv3, True, - True, iana_standard_name="TLS_RSA_WITH_AES_256_CBC_SHA") - AES128_SHA256 = Cipher("AES128-SHA256", Protocols.TLS12, True, - True, iana_standard_name="TLS_RSA_WITH_AES_128_CBC_SHA256") - AES256_SHA256 = Cipher("AES256-SHA256", Protocols.TLS12, True, - True, iana_standard_name="TLS_RSA_WITH_AES_256_CBC_SHA256") - AES128_GCM_SHA256 = Cipher("TLS_AES_128_GCM_SHA256", Protocols.TLS13, - True, True, iana_standard_name="TLS_AES_128_GCM_SHA256") - AES256_GCM_SHA384 = Cipher("TLS_AES_256_GCM_SHA384", Protocols.TLS13, - True, True, iana_standard_name="TLS_AES_256_GCM_SHA384") - - ECDHE_ECDSA_AES128_SHA = Cipher("ECDHE-ECDSA-AES128-SHA", Protocols.SSLv3, - True, False, iana_standard_name="TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA") - ECDHE_ECDSA_AES256_SHA = Cipher("ECDHE-ECDSA-AES256-SHA", Protocols.SSLv3, - True, False, iana_standard_name="TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA") - ECDHE_ECDSA_AES128_SHA256 = Cipher("ECDHE-ECDSA-AES128-SHA256", Protocols.TLS12, - True, True, iana_standard_name="TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256") - ECDHE_ECDSA_AES256_SHA384 = Cipher("ECDHE-ECDSA-AES256-SHA384", Protocols.TLS12, - True, True, iana_standard_name="TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384") - ECDHE_ECDSA_AES128_GCM_SHA256 = Cipher("ECDHE-ECDSA-AES128-GCM-SHA256", Protocols.TLS12, - True, True, iana_standard_name="TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256") - ECDHE_ECDSA_AES256_GCM_SHA384 = Cipher("ECDHE-ECDSA-AES256-GCM-SHA384", Protocols.TLS12, - True, True, iana_standard_name="TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384") - ECDHE_ECDSA_CHACHA20_POLY1305 = Cipher("ECDHE-ECDSA-CHACHA20-POLY1305", Protocols.TLS12, - True, False, iana_standard_name="TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256") - - ECDHE_RSA_DES_CBC3_SHA = Cipher("ECDHE-RSA-DES-CBC3-SHA", Protocols.SSLv3, - False, False, iana_standard_name="TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA") - ECDHE_RSA_AES128_SHA = Cipher("ECDHE-RSA-AES128-SHA", Protocols.SSLv3, - True, False, iana_standard_name="TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA") - ECDHE_RSA_AES256_SHA = Cipher("ECDHE-RSA-AES256-SHA", Protocols.SSLv3, - True, False, iana_standard_name="TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA") - ECDHE_RSA_RC4_SHA = Cipher("ECDHE-RSA-RC4-SHA", Protocols.SSLv3, - False, False, iana_standard_name="TLS_ECDHE_RSA_WITH_RC4_128_SHA") - ECDHE_RSA_AES128_SHA256 = Cipher("ECDHE-RSA-AES128-SHA256", Protocols.TLS12, - True, True, iana_standard_name="TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256") - ECDHE_RSA_AES256_SHA384 = Cipher("ECDHE-RSA-AES256-SHA384", Protocols.TLS12, - True, True, iana_standard_name="TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384") - ECDHE_RSA_AES128_GCM_SHA256 = Cipher("ECDHE-RSA-AES128-GCM-SHA256", Protocols.TLS12, - True, True, iana_standard_name="TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256") - ECDHE_RSA_AES256_GCM_SHA384 = Cipher("ECDHE-RSA-AES256-GCM-SHA384", Protocols.TLS12, - True, True, iana_standard_name="TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384") - ECDHE_RSA_CHACHA20_POLY1305 = Cipher("ECDHE-RSA-CHACHA20-POLY1305", Protocols.TLS12, - True, False, iana_standard_name="TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256") - CHACHA20_POLY1305_SHA256 = Cipher("TLS_CHACHA20_POLY1305_SHA256", Protocols.TLS13, - True, False, iana_standard_name="TLS_CHACHA20_POLY1305_SHA256") + + DHE_RSA_DES_CBC3_SHA = Cipher( + "DHE-RSA-DES-CBC3-SHA", + Protocols.SSLv3, + False, + False, + iana_standard_name="SSL_DHE_RSA_WITH_3DES_EDE_CBC_SHA", + ) + DHE_RSA_AES128_SHA = Cipher( + "DHE-RSA-AES128-SHA", + Protocols.SSLv3, + True, + False, + TEST_CERT_DIRECTORY + "dhparams_2048.pem", + iana_standard_name="TLS_DHE_RSA_WITH_AES_128_CBC_SHA", + ) + DHE_RSA_AES256_SHA = Cipher( + "DHE-RSA-AES256-SHA", + Protocols.SSLv3, + True, + False, + TEST_CERT_DIRECTORY + "dhparams_2048.pem", + iana_standard_name="TLS_DHE_RSA_WITH_AES_256_CBC_SHA", + ) + DHE_RSA_AES128_SHA256 = Cipher( + "DHE-RSA-AES128-SHA256", + Protocols.TLS12, + True, + True, + TEST_CERT_DIRECTORY + "dhparams_2048.pem", + iana_standard_name="TLS_DHE_RSA_WITH_AES_128_CBC_SHA256", + ) + DHE_RSA_AES256_SHA256 = Cipher( + "DHE-RSA-AES256-SHA256", + Protocols.TLS12, + True, + True, + TEST_CERT_DIRECTORY + "dhparams_2048.pem", + iana_standard_name="TLS_DHE_RSA_WITH_AES_256_CBC_SHA256", + ) + DHE_RSA_AES128_GCM_SHA256 = Cipher( + "DHE-RSA-AES128-GCM-SHA256", + Protocols.TLS12, + True, + True, + TEST_CERT_DIRECTORY + "dhparams_2048.pem", + iana_standard_name="TLS_DHE_RSA_WITH_AES_128_GCM_SHA256", + ) + DHE_RSA_AES256_GCM_SHA384 = Cipher( + "DHE-RSA-AES256-GCM-SHA384", + Protocols.TLS12, + True, + True, + TEST_CERT_DIRECTORY + "dhparams_2048.pem", + iana_standard_name="TLS_DHE_RSA_WITH_AES_256_GCM_SHA384", + ) + DHE_RSA_CHACHA20_POLY1305 = Cipher( + "DHE-RSA-CHACHA20-POLY1305", + Protocols.TLS12, + True, + False, + TEST_CERT_DIRECTORY + "dhparams_2048.pem", + iana_standard_name="TLS_DHE_RSA_WITH_AES_256_GCM_SHA384", + ) + + AES128_SHA = Cipher( + "AES128-SHA", + Protocols.SSLv3, + True, + True, + iana_standard_name="TLS_RSA_WITH_AES_128_CBC_SHA", + ) + AES256_SHA = Cipher( + "AES256-SHA", + Protocols.SSLv3, + True, + True, + iana_standard_name="TLS_RSA_WITH_AES_256_CBC_SHA", + ) + AES128_SHA256 = Cipher( + "AES128-SHA256", + Protocols.TLS12, + True, + True, + iana_standard_name="TLS_RSA_WITH_AES_128_CBC_SHA256", + ) + AES256_SHA256 = Cipher( + "AES256-SHA256", + Protocols.TLS12, + True, + True, + iana_standard_name="TLS_RSA_WITH_AES_256_CBC_SHA256", + ) + AES128_GCM_SHA256 = Cipher( + "TLS_AES_128_GCM_SHA256", + Protocols.TLS13, + True, + True, + iana_standard_name="TLS_AES_128_GCM_SHA256", + ) + AES256_GCM_SHA384 = Cipher( + "TLS_AES_256_GCM_SHA384", + Protocols.TLS13, + True, + True, + iana_standard_name="TLS_AES_256_GCM_SHA384", + ) + + ECDHE_ECDSA_AES128_SHA = Cipher( + "ECDHE-ECDSA-AES128-SHA", + Protocols.SSLv3, + True, + False, + iana_standard_name="TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", + ) + ECDHE_ECDSA_AES256_SHA = Cipher( + "ECDHE-ECDSA-AES256-SHA", + Protocols.SSLv3, + True, + False, + iana_standard_name="TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", + ) + ECDHE_ECDSA_AES128_SHA256 = Cipher( + "ECDHE-ECDSA-AES128-SHA256", + Protocols.TLS12, + True, + True, + iana_standard_name="TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", + ) + ECDHE_ECDSA_AES256_SHA384 = Cipher( + "ECDHE-ECDSA-AES256-SHA384", + Protocols.TLS12, + True, + True, + iana_standard_name="TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384", + ) + ECDHE_ECDSA_AES128_GCM_SHA256 = Cipher( + "ECDHE-ECDSA-AES128-GCM-SHA256", + Protocols.TLS12, + True, + True, + iana_standard_name="TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", + ) + ECDHE_ECDSA_AES256_GCM_SHA384 = Cipher( + "ECDHE-ECDSA-AES256-GCM-SHA384", + Protocols.TLS12, + True, + True, + iana_standard_name="TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", + ) + ECDHE_ECDSA_CHACHA20_POLY1305 = Cipher( + "ECDHE-ECDSA-CHACHA20-POLY1305", + Protocols.TLS12, + True, + False, + iana_standard_name="TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", + ) + + ECDHE_RSA_DES_CBC3_SHA = Cipher( + "ECDHE-RSA-DES-CBC3-SHA", + Protocols.SSLv3, + False, + False, + iana_standard_name="TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", + ) + ECDHE_RSA_AES128_SHA = Cipher( + "ECDHE-RSA-AES128-SHA", + Protocols.SSLv3, + True, + False, + iana_standard_name="TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", + ) + ECDHE_RSA_AES256_SHA = Cipher( + "ECDHE-RSA-AES256-SHA", + Protocols.SSLv3, + True, + False, + iana_standard_name="TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", + ) + ECDHE_RSA_RC4_SHA = Cipher( + "ECDHE-RSA-RC4-SHA", + Protocols.SSLv3, + False, + False, + iana_standard_name="TLS_ECDHE_RSA_WITH_RC4_128_SHA", + ) + ECDHE_RSA_AES128_SHA256 = Cipher( + "ECDHE-RSA-AES128-SHA256", + Protocols.TLS12, + True, + True, + iana_standard_name="TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", + ) + ECDHE_RSA_AES256_SHA384 = Cipher( + "ECDHE-RSA-AES256-SHA384", + Protocols.TLS12, + True, + True, + iana_standard_name="TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384", + ) + ECDHE_RSA_AES128_GCM_SHA256 = Cipher( + "ECDHE-RSA-AES128-GCM-SHA256", + Protocols.TLS12, + True, + True, + iana_standard_name="TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + ) + ECDHE_RSA_AES256_GCM_SHA384 = Cipher( + "ECDHE-RSA-AES256-GCM-SHA384", + Protocols.TLS12, + True, + True, + iana_standard_name="TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", + ) + ECDHE_RSA_CHACHA20_POLY1305 = Cipher( + "ECDHE-RSA-CHACHA20-POLY1305", + Protocols.TLS12, + True, + False, + iana_standard_name="TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", + ) + CHACHA20_POLY1305_SHA256 = Cipher( + "TLS_CHACHA20_POLY1305_SHA256", + Protocols.TLS13, + True, + False, + iana_standard_name="TLS_CHACHA20_POLY1305_SHA256", + ) KMS_TLS_1_0_2018_10 = Cipher( - "KMS-TLS-1-0-2018-10", Protocols.TLS10, False, False, s2n=True) + "KMS-TLS-1-0-2018-10", Protocols.TLS10, False, False, s2n=True + ) PQ_TLS_1_0_2023_01 = Cipher( - "PQ-TLS-1-0-2023-01-24", Protocols.TLS10, False, False, s2n=True, pq=True) + "PQ-TLS-1-0-2023-01-24", Protocols.TLS10, False, False, s2n=True, pq=True + ) PQ_TLS_1_3_2023_06_01 = Cipher( - "PQ-TLS-1-3-2023-06-01", Protocols.TLS12, False, False, s2n=True, pq=True) + "PQ-TLS-1-3-2023-06-01", Protocols.TLS12, False, False, s2n=True, pq=True + ) SECURITY_POLICY_20210816 = Cipher( - "20210816", Protocols.TLS12, False, False, s2n=True, pq=False) + "20210816", Protocols.TLS12, False, False, s2n=True, pq=False + ) @staticmethod def from_iana(iana_name): ciphers = [ - cipher for attr in vars(Ciphers) + cipher + for attr in vars(Ciphers) if not callable(cipher := getattr(Ciphers, attr)) and not attr.startswith("_") and cipher.iana_standard_name ] - return { - cipher.iana_standard_name: cipher - for cipher in ciphers - }.get(iana_name) + return {cipher.iana_standard_name: cipher for cipher in ciphers}.get(iana_name) class Curve(object): @@ -360,6 +541,7 @@ class Curves(object): When referencing curves, use these class values. Don't hardcode curve names. """ + X25519 = Curve("X25519", Protocols.TLS13) P256 = Curve("P-256") # Our only SSLv3 provider doesn't support extensions @@ -373,15 +555,13 @@ class Curves(object): @staticmethod def from_name(name): curves = [ - curve for attr in vars(Curves) + curve + for attr in vars(Curves) if not callable(curve := getattr(Curves, attr)) and not attr.startswith("_") and curve.name ] - return { - curve.name: curve - for curve in curves - }.get(name) + return {curve.name: curve for curve in curves}.get(name) class KemGroup(object): @@ -406,19 +586,26 @@ class KemGroups(object): class Signature(object): - def __init__(self, name, min_protocol=Protocols.SSLv3, max_protocol=Protocols.TLS13, sig_type=None, sig_digest=None): + def __init__( + self, + name, + min_protocol=Protocols.SSLv3, + max_protocol=Protocols.TLS13, + sig_type=None, + sig_digest=None, + ): self.min_protocol = min_protocol self.max_protocol = max_protocol - if 'RSA' in name.upper(): - self.algorithm = 'RSA' - if 'PSS_PSS' in name.upper(): - self.algorithm = 'RSAPSS' - if 'EC' in name.upper() or 'ED' in name.upper(): - self.algorithm = 'EC' + if "RSA" in name.upper(): + self.algorithm = "RSA" + if "PSS_PSS" in name.upper(): + self.algorithm = "RSAPSS" + if "EC" in name.upper() or "ED" in name.upper(): + self.algorithm = "EC" - if not (sig_type or sig_digest) and '+' in name: - sig_type, sig_digest = name.split('+') + if not (sig_type or sig_digest) and "+" in name: + sig_type, sig_digest = name.split("+") self.name = name @@ -430,44 +617,47 @@ def __str__(self): class Signatures(object): - RSA_SHA1 = Signature('RSA+SHA1', max_protocol=Protocols.TLS12) - RSA_SHA224 = Signature('RSA+SHA224', max_protocol=Protocols.TLS12) - RSA_SHA256 = Signature('RSA+SHA256', max_protocol=Protocols.TLS12) - RSA_SHA384 = Signature('RSA+SHA384', max_protocol=Protocols.TLS12) - RSA_SHA512 = Signature('RSA+SHA512', max_protocol=Protocols.TLS12) - RSA_MD5_SHA1 = Signature('RSA+MD5_SHA1', max_protocol=Protocols.TLS11) - ECDSA_SHA224 = Signature('ECDSA+SHA224', max_protocol=Protocols.TLS12) - ECDSA_SHA256 = Signature('ECDSA+SHA256', max_protocol=Protocols.TLS12) - ECDSA_SHA384 = Signature('ECDSA+SHA384', max_protocol=Protocols.TLS12) - ECDSA_SHA512 = Signature('ECDSA+SHA512', max_protocol=Protocols.TLS12) - ECDSA_SHA1 = Signature('ECDSA+SHA1', max_protocol=Protocols.TLS12) + RSA_SHA1 = Signature("RSA+SHA1", max_protocol=Protocols.TLS12) + RSA_SHA224 = Signature("RSA+SHA224", max_protocol=Protocols.TLS12) + RSA_SHA256 = Signature("RSA+SHA256", max_protocol=Protocols.TLS12) + RSA_SHA384 = Signature("RSA+SHA384", max_protocol=Protocols.TLS12) + RSA_SHA512 = Signature("RSA+SHA512", max_protocol=Protocols.TLS12) + RSA_MD5_SHA1 = Signature("RSA+MD5_SHA1", max_protocol=Protocols.TLS11) + ECDSA_SHA224 = Signature("ECDSA+SHA224", max_protocol=Protocols.TLS12) + ECDSA_SHA256 = Signature("ECDSA+SHA256", max_protocol=Protocols.TLS12) + ECDSA_SHA384 = Signature("ECDSA+SHA384", max_protocol=Protocols.TLS12) + ECDSA_SHA512 = Signature("ECDSA+SHA512", max_protocol=Protocols.TLS12) + ECDSA_SHA1 = Signature("ECDSA+SHA1", max_protocol=Protocols.TLS12) RSA_PSS_RSAE_SHA256 = Signature( - 'RSA-PSS+SHA256', - sig_type='RSA-PSS-RSAE', - sig_digest='SHA256') + "RSA-PSS+SHA256", sig_type="RSA-PSS-RSAE", sig_digest="SHA256" + ) RSA_PSS_PSS_SHA256 = Signature( - 'rsa_pss_pss_sha256', + "rsa_pss_pss_sha256", min_protocol=Protocols.TLS12, - sig_type='RSA-PSS-PSS', - sig_digest='SHA256') + sig_type="RSA-PSS-PSS", + sig_digest="SHA256", + ) ECDSA_SECP256r1_SHA256 = Signature( - 'ecdsa_secp256r1_sha256', + "ecdsa_secp256r1_sha256", min_protocol=Protocols.TLS13, - sig_type='ECDSA', - sig_digest='SHA256') + sig_type="ECDSA", + sig_digest="SHA256", + ) ECDSA_SECP384r1_SHA384 = Signature( - 'ecdsa_secp384r1_sha384', + "ecdsa_secp384r1_sha384", min_protocol=Protocols.TLS13, - sig_type='ECDSA', - sig_digest='SHA384') + sig_type="ECDSA", + sig_digest="SHA384", + ) ECDSA_SECP521r1_SHA512 = Signature( - 'ecdsa_secp521r1_sha512', + "ecdsa_secp521r1_sha512", min_protocol=Protocols.TLS13, - sig_type='ECDSA', - sig_digest='SHA512') + sig_type="ECDSA", + sig_digest="SHA512", + ) class Results(object): @@ -488,7 +678,15 @@ class Results(object): # Any exception thrown while running the process exception = None - def __init__(self, stdout, stderr, exit_code, exception, expect_stderr=False, expect_nonzero_exit=False): + def __init__( + self, + stdout, + stderr, + exit_code, + exception, + expect_stderr=False, + expect_nonzero_exit=False, + ): self.stdout = stdout self.stderr = stderr self.exit_code = exit_code @@ -497,7 +695,9 @@ def __init__(self, stdout, stderr, exit_code, exception, expect_stderr=False, ex self.expect_nonzero_exit = expect_nonzero_exit def __str__(self): - return "Stdout: {}\nStderr: {}\nExit code: {}\nException: {}".format(self.stdout, self.stderr, self.exit_code, self.exception) + return "Stdout: {}\nStderr: {}\nExit code: {}\nException: {}".format( + self.stdout, self.stderr, self.exit_code, self.exception + ) def assert_success(self): assert self.exception is None, self.exception @@ -512,34 +712,33 @@ def output_streams(self): class ProviderOptions(object): def __init__( - self, - mode=None, - host=None, - port=None, - cipher=None, - curve=None, - key=None, - cert=None, - use_session_ticket=False, - insecure=False, - data_to_send=None, - use_client_auth=False, - extra_flags=None, - trust_store=None, - reconnects_before_exit=None, - reconnect=None, - verify_hostname=None, - server_name=None, - protocol=None, - use_mainline_version=None, - env_overrides=dict(), - enable_client_ocsp=False, - ocsp_response=None, - signature_algorithm=None, - record_size=None, - verbose=True + self, + mode=None, + host=None, + port=None, + cipher=None, + curve=None, + key=None, + cert=None, + use_session_ticket=False, + insecure=False, + data_to_send=None, + use_client_auth=False, + extra_flags=None, + trust_store=None, + reconnects_before_exit=None, + reconnect=None, + verify_hostname=None, + server_name=None, + protocol=None, + use_mainline_version=None, + env_overrides=dict(), + enable_client_ocsp=False, + ocsp_response=None, + signature_algorithm=None, + record_size=None, + verbose=True, ): - # Client or server self.mode = mode diff --git a/tests/integrationv2/configuration.py b/tests/integrationv2/configuration.py index e343373d5bc..d4639a81607 100644 --- a/tests/integrationv2/configuration.py +++ b/tests/integrationv2/configuration.py @@ -31,12 +31,7 @@ # List of all curves that will be tested. -ALL_TEST_CURVES = [ - Curves.X25519, - Curves.P256, - Curves.P384, - Curves.P521 -] +ALL_TEST_CURVES = [Curves.X25519, Curves.P256, Curves.P384, Curves.P521] # List of all certificates that will be tested. @@ -66,7 +61,7 @@ Certificates.RSA_4096_SHA256, Certificates.ECDSA_256, Certificates.ECDSA_384, - Certificates.RSA_PSS_2048_SHA256 + Certificates.RSA_PSS_2048_SHA256, ] @@ -79,14 +74,12 @@ Ciphers.DHE_RSA_AES128_GCM_SHA256, Ciphers.DHE_RSA_AES256_GCM_SHA384, Ciphers.DHE_RSA_CHACHA20_POLY1305, - Ciphers.AES128_SHA, Ciphers.AES256_SHA, Ciphers.AES128_SHA256, Ciphers.AES256_SHA256, Ciphers.AES128_GCM_SHA256, Ciphers.AES256_GCM_SHA384, - Ciphers.ECDHE_ECDSA_AES128_GCM_SHA256, Ciphers.ECDHE_ECDSA_AES256_GCM_SHA384, Ciphers.ECDHE_ECDSA_AES128_SHA256, @@ -94,7 +87,6 @@ Ciphers.ECDHE_ECDSA_AES128_SHA, Ciphers.ECDHE_ECDSA_AES256_SHA, Ciphers.ECDHE_ECDSA_CHACHA20_POLY1305, - Ciphers.ECDHE_RSA_AES128_SHA, Ciphers.ECDHE_RSA_AES256_SHA, Ciphers.ECDHE_RSA_AES128_SHA256, @@ -102,7 +94,6 @@ Ciphers.ECDHE_RSA_AES128_GCM_SHA256, Ciphers.ECDHE_RSA_AES256_GCM_SHA384, Ciphers.ECDHE_RSA_CHACHA20_POLY1305, - Ciphers.CHACHA20_POLY1305_SHA256, ] @@ -124,103 +115,109 @@ "alligator": ( TEST_SNI_CERT_DIRECTORY + "alligator_cert.pem", TEST_SNI_CERT_DIRECTORY + "alligator_key.pem", - ["www.alligator.com"] + ["www.alligator.com"], ), "second_alligator_rsa": ( TEST_SNI_CERT_DIRECTORY + "second_alligator_rsa_cert.pem", TEST_SNI_CERT_DIRECTORY + "second_alligator_rsa_key.pem", - ["www.alligator.com"] + ["www.alligator.com"], ), "alligator_ecdsa": ( TEST_SNI_CERT_DIRECTORY + "alligator_ecdsa_cert.pem", TEST_SNI_CERT_DIRECTORY + "alligator_ecdsa_key.pem", - ["www.alligator.com"] + ["www.alligator.com"], ), "beaver": ( TEST_SNI_CERT_DIRECTORY + "beaver_cert.pem", TEST_SNI_CERT_DIRECTORY + "beaver_key.pem", - ["www.beaver.com"] + ["www.beaver.com"], ), "many_animals": ( TEST_SNI_CERT_DIRECTORY + "many_animal_sans_rsa_cert.pem", TEST_SNI_CERT_DIRECTORY + "many_animal_sans_rsa_key.pem", - ["www.catfish.com", - "www.dolphin.com", - "www.elephant.com", - "www.falcon.com", - "www.gorilla.com", - "www.horse.com", - "www.impala.com", - # "Simple hostname" - "Jackal", - "k.e.e.l.b.i.l.l.e.d.t.o.u.c.a.n", - # SAN on this cert is actually "ladybug.ladybug" - # Verify case insensitivity works as expected. - "LADYBUG.LADYBUG", - "com.penguin.macaroni"] + [ + "www.catfish.com", + "www.dolphin.com", + "www.elephant.com", + "www.falcon.com", + "www.gorilla.com", + "www.horse.com", + "www.impala.com", + # "Simple hostname" + "Jackal", + "k.e.e.l.b.i.l.l.e.d.t.o.u.c.a.n", + # SAN on this cert is actually "ladybug.ladybug" + # Verify case insensitivity works as expected. + "LADYBUG.LADYBUG", + "com.penguin.macaroni", + ], ), "narwhal_cn": ( TEST_SNI_CERT_DIRECTORY + "narwhal_cn_cert.pem", TEST_SNI_CERT_DIRECTORY + "narwhal_cn_key.pem", - ["www.narwhal.com"] + ["www.narwhal.com"], ), "octopus_cn_platypus_san": ( TEST_SNI_CERT_DIRECTORY + "octopus_cn_platypus_san_cert.pem", TEST_SNI_CERT_DIRECTORY + "octopus_cn_platypus_san_key.pem", - ["www.platypus.com"] + ["www.platypus.com"], ), "quail_cn_rattlesnake_cn": ( TEST_SNI_CERT_DIRECTORY + "quail_cn_rattlesnake_cn_cert.pem", TEST_SNI_CERT_DIRECTORY + "quail_cn_rattlesnake_cn_key.pem", - ["www.quail.com", "www.rattlesnake.com"] + ["www.quail.com", "www.rattlesnake.com"], ), "many_animals_mixed_case": ( TEST_SNI_CERT_DIRECTORY + "many_animal_sans_mixed_case_rsa_cert.pem", TEST_SNI_CERT_DIRECTORY + "many_animal_sans_mixed_case_rsa_key.pem", - ["alligator.com", - "beaver.com", - "catFish.com", - "WWW.dolphin.COM", - "www.ELEPHANT.com", - "www.Falcon.Com", - "WWW.gorilla.COM", - "www.horse.com", - "WWW.IMPALA.COM", - "WwW.jAcKaL.cOm"] + [ + "alligator.com", + "beaver.com", + "catFish.com", + "WWW.dolphin.COM", + "www.ELEPHANT.com", + "www.Falcon.Com", + "WWW.gorilla.COM", + "www.horse.com", + "WWW.IMPALA.COM", + "WwW.jAcKaL.cOm", + ], ), "embedded_wildcard": ( TEST_SNI_CERT_DIRECTORY + "embedded_wildcard_rsa_cert.pem", TEST_SNI_CERT_DIRECTORY + "embedded_wildcard_rsa_key.pem", - ["www.labelstart*labelend.com"] + ["www.labelstart*labelend.com"], ), "non_empty_label_wildcard": ( TEST_SNI_CERT_DIRECTORY + "non_empty_label_wildcard_rsa_cert.pem", TEST_SNI_CERT_DIRECTORY + "non_empty_label_wildcard_rsa_key.pem", - ["WILD*.middle.end"] + ["WILD*.middle.end"], ), "trailing_wildcard": ( TEST_SNI_CERT_DIRECTORY + "trailing_wildcard_rsa_cert.pem", TEST_SNI_CERT_DIRECTORY + "trailing_wildcard_rsa_key.pem", - ["the.prefix.*"] + ["the.prefix.*"], ), "wildcard_insect": ( TEST_SNI_CERT_DIRECTORY + "wildcard_insect_rsa_cert.pem", TEST_SNI_CERT_DIRECTORY + "wildcard_insect_rsa_key.pem", - ["ant.insect.hexapod", - "BEE.insect.hexapod", - "wasp.INSECT.hexapod", - "butterfly.insect.hexapod"] + [ + "ant.insect.hexapod", + "BEE.insect.hexapod", + "wasp.INSECT.hexapod", + "butterfly.insect.hexapod", + ], ), "termite": ( TEST_SNI_CERT_DIRECTORY + "termite_rsa_cert.pem", TEST_SNI_CERT_DIRECTORY + "termite_rsa_key.pem", - ["termite.insect.hexapod"] + ["termite.insect.hexapod"], ), "underwing": ( TEST_SNI_CERT_DIRECTORY + "underwing_ecdsa_cert.pem", TEST_SNI_CERT_DIRECTORY + "underwing_ecdsa_key.pem", - ["underwing.insect.hexapod"] - ) + ["underwing.insect.hexapod"], + ), } @@ -228,74 +225,103 @@ # Test inputs: server certificates to load into s2nd, client SNI and capabilities, outputs are selected server cert # and negotiated cipher. MultiCertTest = collections.namedtuple( - 'MultiCertTest', 'description server_certs client_sni client_ciphers expected_cert expect_matching_hostname') + "MultiCertTest", + "description server_certs client_sni client_ciphers expected_cert expect_matching_hostname", +) MULTI_CERT_TEST_CASES = [ MultiCertTest( description="Test basic SNI match for default cert.", - server_certs=[SNI_CERTS["alligator"], - SNI_CERTS["beaver"], SNI_CERTS["alligator_ecdsa"]], + server_certs=[ + SNI_CERTS["alligator"], + SNI_CERTS["beaver"], + SNI_CERTS["alligator_ecdsa"], + ], client_sni="www.alligator.com", client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], expected_cert=SNI_CERTS["alligator"], - expect_matching_hostname=True), + expect_matching_hostname=True, + ), MultiCertTest( description="Test basic SNI matches for non-default cert.", - server_certs=[SNI_CERTS["alligator"], - SNI_CERTS["beaver"], SNI_CERTS["alligator_ecdsa"]], + server_certs=[ + SNI_CERTS["alligator"], + SNI_CERTS["beaver"], + SNI_CERTS["alligator_ecdsa"], + ], client_sni="www.beaver.com", client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], expected_cert=SNI_CERTS["beaver"], - expect_matching_hostname=True), + expect_matching_hostname=True, + ), MultiCertTest( description="Test default cert is selected when there are no SNI matches.", - server_certs=[SNI_CERTS["alligator"], - SNI_CERTS["beaver"], SNI_CERTS["alligator_ecdsa"]], + server_certs=[ + SNI_CERTS["alligator"], + SNI_CERTS["beaver"], + SNI_CERTS["alligator_ecdsa"], + ], client_sni="not.a.match", client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], expected_cert=SNI_CERTS["alligator"], - expect_matching_hostname=False), + expect_matching_hostname=False, + ), MultiCertTest( description="Test default cert is selected when no SNI is sent.", - server_certs=[SNI_CERTS["alligator"], - SNI_CERTS["beaver"], SNI_CERTS["alligator_ecdsa"]], + server_certs=[ + SNI_CERTS["alligator"], + SNI_CERTS["beaver"], + SNI_CERTS["alligator_ecdsa"], + ], client_sni=None, client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], expected_cert=SNI_CERTS["alligator"], - expect_matching_hostname=False), + expect_matching_hostname=False, + ), MultiCertTest( description="Test ECDSA cert is selected with matching domain and client only supports ECDSA.", - server_certs=[SNI_CERTS["alligator"], - SNI_CERTS["beaver"], SNI_CERTS["alligator_ecdsa"]], + server_certs=[ + SNI_CERTS["alligator"], + SNI_CERTS["beaver"], + SNI_CERTS["alligator_ecdsa"], + ], client_sni="www.alligator.com", client_ciphers=[Ciphers.ECDHE_ECDSA_AES128_SHA], expected_cert=SNI_CERTS["alligator_ecdsa"], - expect_matching_hostname=True), + expect_matching_hostname=True, + ), MultiCertTest( description="Test ECDSA cert selected when: domain matches for both ECDSA+RSA, client supports ECDSA+RSA " - " ciphers, ECDSA is higher priority on server side.", - server_certs=[SNI_CERTS["alligator"], - SNI_CERTS["beaver"], SNI_CERTS["alligator_ecdsa"]], + " ciphers, ECDSA is higher priority on server side.", + server_certs=[ + SNI_CERTS["alligator"], + SNI_CERTS["beaver"], + SNI_CERTS["alligator_ecdsa"], + ], client_sni="www.alligator.com", - client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA, - Ciphers.ECDHE_ECDSA_AES128_SHA], + client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA, Ciphers.ECDHE_ECDSA_AES128_SHA], expected_cert=SNI_CERTS["alligator_ecdsa"], - expect_matching_hostname=True), + expect_matching_hostname=True, + ), MultiCertTest( description="Test domain match is highest priority. Domain matching ECDSA certificate should be selected" - " even if domain mismatched RSA certificate is available and RSA cipher is higher priority.", + " even if domain mismatched RSA certificate is available and RSA cipher is higher priority.", server_certs=[SNI_CERTS["beaver"], SNI_CERTS["alligator_ecdsa"]], client_sni="www.alligator.com", - client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA256, - Ciphers.ECDHE_ECDSA_AES128_SHA256], + client_ciphers=[ + Ciphers.ECDHE_RSA_AES128_SHA256, + Ciphers.ECDHE_ECDSA_AES128_SHA256, + ], expected_cert=SNI_CERTS["alligator_ecdsa"], - expect_matching_hostname=True), + expect_matching_hostname=True, + ), MultiCertTest( description="Test certificate with single SAN entry matching is selected before mismatched multi SAN cert", server_certs=[SNI_CERTS["many_animals"], SNI_CERTS["alligator"]], client_sni="www.alligator.com", client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], expected_cert=SNI_CERTS["alligator"], - expect_matching_hostname=True), + expect_matching_hostname=True, + ), # many_animals was the first cert added MultiCertTest( description="Test default cert with multiple sans and no SNI sent.", @@ -303,119 +329,157 @@ client_sni=None, client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], expected_cert=SNI_CERTS["many_animals"], - expect_matching_hostname=False), + expect_matching_hostname=False, + ), MultiCertTest( description="Test certificate match with CN", server_certs=[SNI_CERTS["alligator"], SNI_CERTS["narwhal_cn"]], client_sni="www.narwhal.com", client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], expected_cert=SNI_CERTS["narwhal_cn"], - expect_matching_hostname=True), + expect_matching_hostname=True, + ), MultiCertTest( description="Test SAN+CN cert can match using SAN.", - server_certs=[SNI_CERTS["alligator"], - SNI_CERTS["octopus_cn_platypus_san"]], + server_certs=[SNI_CERTS["alligator"], SNI_CERTS["octopus_cn_platypus_san"]], client_sni="www.platypus.com", client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], expected_cert=SNI_CERTS["octopus_cn_platypus_san"], - expect_matching_hostname=True), + expect_matching_hostname=True, + ), MultiCertTest( description="Test that CN is not considered for matching if the certificate contains SANs.", - server_certs=[SNI_CERTS["alligator"], - SNI_CERTS["octopus_cn_platypus_san"]], + server_certs=[SNI_CERTS["alligator"], SNI_CERTS["octopus_cn_platypus_san"]], client_sni="www.octopus.com", client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], expected_cert=SNI_CERTS["alligator"], - expect_matching_hostname=False), + expect_matching_hostname=False, + ), MultiCertTest( description="Test certificate with multiple CNs can match.", - server_certs=[SNI_CERTS["alligator"], - SNI_CERTS["quail_cn_rattlesnake_cn"]], + server_certs=[SNI_CERTS["alligator"], SNI_CERTS["quail_cn_rattlesnake_cn"]], client_sni="www.rattlesnake.com", client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], expected_cert=SNI_CERTS["quail_cn_rattlesnake_cn"], - expect_matching_hostname=False), + expect_matching_hostname=False, + ), MultiCertTest( description="Test cert with embedded wildcard is not treated as a wildcard.", server_certs=[SNI_CERTS["alligator"], SNI_CERTS["embedded_wildcard"]], client_sni="www.labelstartWILDCARDlabelend.com", client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], expected_cert=SNI_CERTS["alligator"], - expect_matching_hostname=False), + expect_matching_hostname=False, + ), MultiCertTest( - description="Test non empty left label wildcard cert is not treated as a wildcard."\ - " s2n only supports wildcards with a single * as the left label", - server_certs=[SNI_CERTS["alligator"], - SNI_CERTS["non_empty_label_wildcard"]], + description="Test non empty left label wildcard cert is not treated as a wildcard." + " s2n only supports wildcards with a single * as the left label", + server_certs=[SNI_CERTS["alligator"], SNI_CERTS["non_empty_label_wildcard"]], client_sni="WILDCARD.middle.end", client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], expected_cert=SNI_CERTS["alligator"], - expect_matching_hostname=False), + expect_matching_hostname=False, + ), MultiCertTest( description="Test cert with trailing * is not treated as wildcard.", server_certs=[SNI_CERTS["alligator"], SNI_CERTS["trailing_wildcard"]], client_sni="the.prefix.WILDCARD", client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], expected_cert=SNI_CERTS["alligator"], - expect_matching_hostname=False), + expect_matching_hostname=False, + ), MultiCertTest( - description="Certificate with exact sni match(termite.insect.hexapod) is preferred over wildcard"\ - " *.insect.hexapod", - server_certs=[SNI_CERTS["wildcard_insect"], - SNI_CERTS["alligator"], SNI_CERTS["termite"]], + description="Certificate with exact sni match(termite.insect.hexapod) is preferred over wildcard" + " *.insect.hexapod", + server_certs=[ + SNI_CERTS["wildcard_insect"], + SNI_CERTS["alligator"], + SNI_CERTS["termite"], + ], client_sni="termite.insect.hexapod", client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], expected_cert=SNI_CERTS["termite"], - expect_matching_hostname=True), + expect_matching_hostname=True, + ), MultiCertTest( - description="ECDSA Certificate with exact sni match(underwing.insect.hexapod) is preferred over RSA wildcard"\ - " *.insect.hexapod when RSA ciphers are higher priority than ECDSA in server preferences.", - server_certs=[SNI_CERTS["wildcard_insect"], - SNI_CERTS["alligator"], SNI_CERTS["underwing"]], + description="ECDSA Certificate with exact sni match(underwing.insect.hexapod) is preferred over RSA wildcard" + " *.insect.hexapod when RSA ciphers are higher priority than ECDSA in server preferences.", + server_certs=[ + SNI_CERTS["wildcard_insect"], + SNI_CERTS["alligator"], + SNI_CERTS["underwing"], + ], client_sni="underwing.insect.hexapod", - client_ciphers=[Ciphers.ECDHE_RSA_AES128_GCM_SHA256, - Ciphers.ECDHE_ECDSA_AES128_GCM_SHA256], + client_ciphers=[ + Ciphers.ECDHE_RSA_AES128_GCM_SHA256, + Ciphers.ECDHE_ECDSA_AES128_GCM_SHA256, + ], expected_cert=SNI_CERTS["underwing"], - expect_matching_hostname=True), + expect_matching_hostname=True, + ), MultiCertTest( description="Firstly loaded matching certificate should be selected among certificates with the same domain names", - server_certs=[SNI_CERTS["alligator"], - SNI_CERTS["second_alligator_rsa"]], + server_certs=[SNI_CERTS["alligator"], SNI_CERTS["second_alligator_rsa"]], client_sni="www.alligator.com", client_ciphers=[Ciphers.AES128_GCM_SHA256], expected_cert=SNI_CERTS["alligator"], - expect_matching_hostname=True), + expect_matching_hostname=True, + ), MultiCertTest( description="Firstly loaded matching certificate should be selected among matching+non-matching certificates", - server_certs=[SNI_CERTS["beaver"], SNI_CERTS["alligator"], - SNI_CERTS["second_alligator_rsa"]], + server_certs=[ + SNI_CERTS["beaver"], + SNI_CERTS["alligator"], + SNI_CERTS["second_alligator_rsa"], + ], client_sni="www.alligator.com", client_ciphers=[Ciphers.AES128_GCM_SHA256], expected_cert=SNI_CERTS["alligator"], - expect_matching_hostname=True)] + expect_matching_hostname=True, + ), +] # Positive test for wildcard matches -MULTI_CERT_TEST_CASES.extend([MultiCertTest( - description="Test wildcard *.insect.hexapod matches subdomain " + specific_insect_domain, - server_certs=[SNI_CERTS["alligator"], SNI_CERTS["wildcard_insect"]], - client_sni=specific_insect_domain, - client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], - expected_cert=SNI_CERTS["wildcard_insect"], - expect_matching_hostname=True) for specific_insect_domain in SNI_CERTS["wildcard_insect"][2]]) +MULTI_CERT_TEST_CASES.extend( + [ + MultiCertTest( + description="Test wildcard *.insect.hexapod matches subdomain " + + specific_insect_domain, + server_certs=[SNI_CERTS["alligator"], SNI_CERTS["wildcard_insect"]], + client_sni=specific_insect_domain, + client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], + expected_cert=SNI_CERTS["wildcard_insect"], + expect_matching_hostname=True, + ) + for specific_insect_domain in SNI_CERTS["wildcard_insect"][2] + ] +) # Positive test for basic SAN matches -MULTI_CERT_TEST_CASES.extend([MultiCertTest( - description="Match SAN " + many_animal_domain + " in many_animals cert", - server_certs=[SNI_CERTS["alligator"], SNI_CERTS["many_animals"]], - client_sni=many_animal_domain, - client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], - expected_cert=SNI_CERTS["many_animals"], - expect_matching_hostname=True) for many_animal_domain in SNI_CERTS["many_animals"][2]]) +MULTI_CERT_TEST_CASES.extend( + [ + MultiCertTest( + description="Match SAN " + many_animal_domain + " in many_animals cert", + server_certs=[SNI_CERTS["alligator"], SNI_CERTS["many_animals"]], + client_sni=many_animal_domain, + client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], + expected_cert=SNI_CERTS["many_animals"], + expect_matching_hostname=True, + ) + for many_animal_domain in SNI_CERTS["many_animals"][2] + ] +) # Positive test for mixed cased SAN matches -MULTI_CERT_TEST_CASES.extend([MultiCertTest( - description="Match SAN " + many_animal_domain + - " in many_animals_mixed_case cert", - server_certs=[SNI_CERTS["alligator"], - SNI_CERTS["many_animals_mixed_case"]], - client_sni=many_animal_domain, - client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], - expected_cert=SNI_CERTS["many_animals_mixed_case"], - expect_matching_hostname=True) for many_animal_domain in SNI_CERTS["many_animals_mixed_case"][2]]) +MULTI_CERT_TEST_CASES.extend( + [ + MultiCertTest( + description="Match SAN " + + many_animal_domain + + " in many_animals_mixed_case cert", + server_certs=[SNI_CERTS["alligator"], SNI_CERTS["many_animals_mixed_case"]], + client_sni=many_animal_domain, + client_ciphers=[Ciphers.ECDHE_RSA_AES128_SHA], + expected_cert=SNI_CERTS["many_animals_mixed_case"], + expect_matching_hostname=True, + ) + for many_animal_domain in SNI_CERTS["many_animals_mixed_case"][2] + ] +) diff --git a/tests/integrationv2/conftest.py b/tests/integrationv2/conftest.py index 5bd3bcd7b81..026ec91a917 100644 --- a/tests/integrationv2/conftest.py +++ b/tests/integrationv2/conftest.py @@ -25,7 +25,7 @@ def available_providers(): bin_path = f"{expected_location}/{binary}" if not os.path.exists(bin_path): pytest.fail(f"unable to locate {binary}") - os.environ['PATH'] += os.pathsep + expected_location + os.environ["PATH"] += os.pathsep + expected_location providers.add(S2N) if os.path.exists("./bin/SSLSocketClient.class"): @@ -38,8 +38,14 @@ def available_providers(): def pytest_addoption(parser: pytest.Parser): - parser.addoption("--provider-version", action="store", dest="provider-version", - default=None, type=str, help="Set the version of the TLS provider") + parser.addoption( + "--provider-version", + action="store", + dest="provider-version", + default=None, + type=str, + help="Set the version of the TLS provider", + ) parser.addoption( "--best-effort-NOT-FOR-CI", action="store_true", @@ -58,13 +64,14 @@ def pytest_configure(config: pytest.Config): This is executed once per pytest session on process startup. """ config.addinivalue_line( - "markers", "uncollect_if(*, func): function to unselect tests from parametrization" + "markers", + "uncollect_if(*, func): function to unselect tests from parametrization", ) if config.getoption("--best-effort-NOT-FOR-CI"): config.stash[PATH_CONFIGURATION_KEY] = available_providers() - provider_version = config.getoption('provider-version', None) + provider_version = config.getoption("provider-version", None) # By default, any libcrypto with "fips" in its name should be in fips mode. # However, s2n-tls no longer supports fips mode with openssl-1.0.2-fips. if "fips" in provider_version and "openssl-1.0.2-fips" not in provider_version: @@ -79,9 +86,9 @@ def pytest_collection_modifyitems(config, items): removed = [] kept = [] for item in items: - m = item.get_closest_marker('uncollect_if') + m = item.get_closest_marker("uncollect_if") if m: - func = m.kwargs['func'] + func = m.kwargs["func"] if func(**item.callspec.params): removed.append(item) continue diff --git a/tests/integrationv2/constants.py b/tests/integrationv2/constants.py index 6c66589780e..6c46d26f952 100644 --- a/tests/integrationv2/constants.py +++ b/tests/integrationv2/constants.py @@ -4,6 +4,5 @@ TEST_SNI_CERT_DIRECTORY = "../pems/sni/" TEST_OCSP_DIRECTORY = "../pems/ocsp/" -TRUST_STORE_BUNDLE = TEST_CERT_DIRECTORY + 'trust-store/ca-bundle.crt' -TRUST_STORE_TRUSTED_BUNDLE = TEST_CERT_DIRECTORY + \ - 'trust-store/ca-bundle.trust.crt' +TRUST_STORE_BUNDLE = TEST_CERT_DIRECTORY + "trust-store/ca-bundle.crt" +TRUST_STORE_TRUSTED_BUNDLE = TEST_CERT_DIRECTORY + "trust-store/ca-bundle.trust.crt" diff --git a/tests/integrationv2/fixtures.py b/tests/integrationv2/fixtures.py index d5820ce52d9..6bad9d86765 100644 --- a/tests/integrationv2/fixtures.py +++ b/tests/integrationv2/fixtures.py @@ -26,8 +26,16 @@ def managed_process(request: pytest.FixtureRequest): # Indicates whether a launch was aborted. If so, non-graceful shutdown is allowed aborted = False - def _fn(provider_class: Provider, options: ProviderOptions, timeout=5, send_marker=None, close_marker=None, - expect_stderr=None, kill_marker=None, send_with_newline=None): + def _fn( + provider_class: Provider, + options: ProviderOptions, + timeout=5, + send_marker=None, + close_marker=None, + expect_stderr=None, + kill_marker=None, + send_with_newline=None, + ): best_effort_mode = request.config.getoption("--best-effort-NOT-FOR-CI") if best_effort_mode: # modify the `aborted` field in the generator object @@ -40,7 +48,11 @@ def _fn(provider_class: Provider, options: ProviderOptions, timeout=5, send_mark provider = provider_class(options) cmd_line = provider.get_cmd_line() - if best_effort_mode and provider_class is S2N and not (cmd_line[0] == "s2nc" or cmd_line[0] == "s2nd"): + if ( + best_effort_mode + and provider_class is S2N + and not (cmd_line[0] == "s2nc" or cmd_line[0] == "s2nd") + ): aborted = True pytest.skip("s2nc_head or s2nd_head not supported for best-effort") @@ -63,7 +75,7 @@ def _fn(provider_class: Provider, options: ProviderOptions, timeout=5, send_mark env_overrides=options.env_overrides, expect_stderr=expect_stderr, kill_marker=kill_marker, - send_with_newline=send_with_newline + send_with_newline=send_with_newline, ) processes.append(p) @@ -72,7 +84,8 @@ def _fn(provider_class: Provider, options: ProviderOptions, timeout=5, send_mark with provider._provider_ready_condition: # Don't continue processing until the provider has indicated it is ready. provider._provider_ready_condition.wait_for( - provider.is_provider_ready, timeout) + provider.is_provider_ready, timeout + ) return p try: @@ -96,13 +109,14 @@ def _swap_mtu(device, new_mtu): Return the original MTU so it can be reset later. """ cmd = ["ip", "link", "show", device] - p = subprocess.Popen(cmd, stdin=subprocess.PIPE, - stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p = subprocess.Popen( + cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) mtu = 65536 for line in p.stdout.readlines(): s = line.decode("utf-8") - pieces = s.split(' ') - if len(pieces) >= 4 and pieces[3] == 'mtu': + pieces = s.split(" ") + if len(pieces) >= 4 and pieces[3] == "mtu": mtu = int(pieces[4]) p.wait() @@ -112,7 +126,7 @@ def _swap_mtu(device, new_mtu): return int(mtu) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def custom_mtu(): """ This fixture will swap the loopback's MTU from the default @@ -126,6 +140,6 @@ def custom_mtu(): if os.geteuid() != 0: pytest.skip("Test needs root privileges to modify lo MTU") - original_mtu = _swap_mtu('lo', 1500) + original_mtu = _swap_mtu("lo", 1500) yield - _swap_mtu('lo', original_mtu) + _swap_mtu("lo", original_mtu) diff --git a/tests/integrationv2/global_flags.py b/tests/integrationv2/global_flags.py index 853cd3feb02..0b2837b7b25 100644 --- a/tests/integrationv2/global_flags.py +++ b/tests/integrationv2/global_flags.py @@ -5,11 +5,11 @@ # based on the environment. # If S2N is operating in FIPS mode -S2N_FIPS_MODE = 's2n_fips_mode' +S2N_FIPS_MODE = "s2n_fips_mode" # The version of provider being used # (set from the S2N_LIBCRYPTO env var, which is how the original integration test works) -S2N_PROVIDER_VERSION = 's2n_provider_version' +S2N_PROVIDER_VERSION = "s2n_provider_version" _flags = {} diff --git a/tests/integrationv2/processes.py b/tests/integrationv2/processes.py index 76dd7a00326..44af7b83224 100644 --- a/tests/integrationv2/processes.py +++ b/tests/integrationv2/processes.py @@ -11,7 +11,7 @@ _PopenSelector = selectors.PollSelector -_PIPE_BUF = getattr(select, 'PIPE_BUF', 512) +_PIPE_BUF = getattr(select, "PIPE_BUF", 512) _DEBUG_LEN = 80 @@ -63,8 +63,15 @@ def wait_for(self, wait_for_marker, timeout=None): return (stdout, stderr) - def communicate(self, input_data=None, send_marker_list=None, close_marker=None, kill_marker=None, - send_with_newline=False, timeout=None): + def communicate( + self, + input_data=None, + send_marker_list=None, + close_marker=None, + kill_marker=None, + send_with_newline=False, + timeout=None, + ): """ Communicates with the managed process. If send_marker_list is set, input_data will not be sent until the marker is seen. @@ -83,15 +90,32 @@ def communicate(self, input_data=None, send_marker_list=None, close_marker=None, close_marker, kill_marker, send_with_newline, - timeout + timeout, ) finally: self._communication_started = True return (stdout, stderr) - def _communicate(self, input_data=None, send_marker_list=None, close_marker=None, kill_marker=None, - send_with_newline=False, timeout=None): + # Helper function to print out debugging statements + def get_fd_name(self, proc, fileobj): + if fileobj == proc.stdout: + return "stdout" + elif fileobj == proc.stderr: + return "stderr" + elif fileobj == proc.stdin: + return "stdin" + return "unknown fd" + + def _communicate( + self, + input_data=None, + send_marker_list=None, + close_marker=None, + kill_marker=None, + send_with_newline=False, + timeout=None, + ): """ This method will read and write data to a subprocess in a non-blocking manner. The code is heavily based on Popen.communicate. There are a couple differences: @@ -147,42 +171,54 @@ def _communicate(self, input_data=None, send_marker_list=None, close_marker=None while selector.get_map(): timeout = self._remaining_time(endtime) if timeout is not None and timeout < 0: - self._check_timeout(endtime, orig_timeout, - stdout, stderr, - skip_check_and_raise=True) + self._check_timeout( + endtime, orig_timeout, stdout, stderr, skip_check_and_raise=True + ) raise RuntimeError( # Impossible :) - '_check_timeout(..., skip_check_and_raise=True) ' - 'failed to raise TimeoutExpired.') + "_check_timeout(..., skip_check_and_raise=True) " + "failed to raise TimeoutExpired." + ) ready = selector.select(timeout) self._check_timeout(endtime, orig_timeout, stdout, stderr) # (Key, events) tuple represents a single I/O operation - for key, events in ready: + for key, num_events in ready: # STDIN is only registered to receive events after the send_marker is found. if key.fileobj is self.proc.stdin: - print(f'{self.name}: stdin available') - chunk = input_view[input_data_offset: - input_data_offset + _PIPE_BUF] + print(f"{self.name}: stdin available") + chunk = input_view[ + input_data_offset : input_data_offset + _PIPE_BUF + ] try: input_data_offset += os.write(key.fd, chunk) - print(f'{self.name}: sent') + print(f"{self.name}: sent") except BrokenPipeError: - selector.unregister(key.fileobj) + print(f"{self.name}: Unregistering (stdin) BrokenPipeError") + selector.unregister(self.proc.stdin) else: if input_data_offset >= input_data_len: - selector.unregister(key.fileobj) + print( + f"{self.name}: Unregistering (stdin) Input_data_offset >= input_data_len" + ) + selector.unregister(self.proc.stdin) input_data_sent = True input_data_offset = 0 if send_marker_list: send_marker = send_marker_list.pop(0) - print(f'{self.name}: next send_marker is {send_marker}') + print( + f"{self.name}: next send_marker is {send_marker}" + ) elif key.fileobj in (self.proc.stdout, self.proc.stderr): - print(f'{self.name}: stdout available') + fd_name = self.get_fd_name(self.proc, key.fileobj) + print(f"{self.name}: {fd_name} available") + # 32 KB (32 × 1024 = 32,768 bytes), read 32KB from the file descriptor data = os.read(key.fd, 32768) if not data: selector.unregister(key.fileobj) + print(f"{self.name}: Unregistering: {fd_name} No Data") + data_str = str(data) # Prepends n - 1 bytes of previously-seen stdout to the chunk we'll be searching @@ -192,11 +228,15 @@ def _communicate(self, input_data=None, send_marker_list=None, close_marker=None stored_stdout_list = self._fileobj2output[key.fileobj] send_marker_len = len(send_marker) - 1 if len(stored_stdout_list) > 0: - data_str = str(stored_stdout_list[-1][-send_marker_len:] + data) + data_str = str( + stored_stdout_list[-1][-send_marker_len:] + data + ) data_debug = data_str[:_DEBUG_LEN] if len(data_str) > _DEBUG_LEN: - data_debug += f' ...({len(data_str) - _DEBUG_LEN} more bytes)' + data_debug += ( + f" ...({len(data_str) - _DEBUG_LEN} more bytes)" + ) # fileobj2output[key.fileobj] is a list of data chunks # that get joined later @@ -206,57 +246,75 @@ def _communicate(self, input_data=None, send_marker_list=None, close_marker=None # register STDIN to receive events. If there is no data to send, # just mark input_send as true so we can close out STDIN. if send_marker: - print(f'{self.name}: looking for send_marker {send_marker} in {data_debug}') + print( + f"{self.name}: looking for send_marker {send_marker} in {data_debug}" + ) if send_marker is not None and send_marker in data_str: - print(f'{self.name}: found {send_marker}') + print(f"{self.name}: found {send_marker}") send_marker = None if self.proc.stdin and input_data: selector.register( - self.proc.stdin, selectors.EVENT_WRITE) + self.proc.stdin, selectors.EVENT_WRITE + ) message = input_data.pop(0) if send_with_newline: - message += b'\n' + message += b"\n" # Data destined for stdin is stored in a memoryview input_view = memoryview(message) input_data_len = len(message) input_data_sent = False - print(f'{self.name}: will send {message}') + print(f"{self.name}: will send {message}") else: input_data_sent = True - print(f'{self.name}: will send nothing') + print(f"{self.name}: will send nothing") if self.wait_for_marker: - print(f'{self.name}: looking for wait_for_marker {self.wait_for_marker} in {data_debug}') - if self.wait_for_marker is not None and self.wait_for_marker in data_str: + print( + f"{self.name}: looking for wait_for_marker {self.wait_for_marker} in {data_debug}" + ) + if ( + self.wait_for_marker is not None + and self.wait_for_marker in data_str + ): + print( + f"{self.name}: Unregistering (stdout + stderr), found wait_for_marker" + ) selector.unregister(self.proc.stdout) selector.unregister(self.proc.stderr) return None, None if kill_marker: - print(f'{self.name}: looking for kill_marker {kill_marker} in {data}') + print( + f"{self.name}: looking for kill_marker {kill_marker} in {data}" + ) if kill_marker is not None and kill_marker in data: + print( + f"{self.name}: Unregistering (stdout + stderr), found kill_marker" + ) selector.unregister(self.proc.stdout) selector.unregister(self.proc.stderr) self.proc.kill() - # If we have finished sending all our input, and have received the - # ready-to-send marker, we can close out stdin. - if self.proc.stdin and input_data_sent and not input_data: - print(f'{self.name}: finished sending') - if close_marker: - print(f'{self.name}: looking for close_marker {close_marker} in {data_debug}') - if close_marker is None or (close_marker and close_marker in data_str): - print(f'{self.name}: closing stdin') - input_data_sent = None - self.proc.stdin.close() + if self.proc.stdin and input_data_sent and not input_data: + print(f"{self.name}: finished sending") + if close_marker: + print( + f"{self.name}: looking for close_marker {close_marker} in {data_debug}" + ) + if close_marker is None or ( + close_marker and close_marker in data_str + ): + print(f"{self.name} Found close marker: closing stdin") + input_data_sent = None + self.proc.stdin.close() self.proc.wait(timeout=self._remaining_time(endtime)) # All data exchanged. Translate lists into strings. if stdout is not None: - stdout = b''.join(stdout) + stdout = b"".join(stdout) if stderr is not None: - stderr = b''.join(stderr) + stderr = b"".join(stderr) return (stdout, stderr) @@ -267,8 +325,9 @@ def _remaining_time(self, endtime): else: return endtime - _time() - def _check_timeout(self, endtime, orig_timeout, stdout_seq, stderr_seq, - skip_check_and_raise=False): + def _check_timeout( + self, endtime, orig_timeout, stdout_seq, stderr_seq, skip_check_and_raise=False + ): """ Convenience for checking if a timeout has expired. @@ -279,9 +338,11 @@ def _check_timeout(self, endtime, orig_timeout, stdout_seq, stderr_seq, return if skip_check_and_raise or _time() > endtime: raise subprocess.TimeoutExpired( - self.proc.args, orig_timeout, - output=b''.join(stdout_seq) if stdout_seq else None, - stderr=b''.join(stderr_seq) if stderr_seq else None) + self.proc.args, + orig_timeout, + output=b"".join(stdout_seq) if stdout_seq else None, + stderr=b"".join(stderr_seq) if stderr_seq else None, + ) class ManagedProcess(threading.Thread): @@ -293,9 +354,20 @@ class ManagedProcess(threading.Thread): are made available to the caller. """ - def __init__(self, cmd_line, provider_set_ready_condition, wait_for_marker=None, send_marker_list=None, - close_marker=None, timeout=5, data_source=None, env_overrides=dict(), expect_stderr=False, - kill_marker=None, send_with_newline=False): + def __init__( + self, + cmd_line, + provider_set_ready_condition, + wait_for_marker=None, + send_marker_list=None, + close_marker=None, + timeout=5, + data_source=None, + env_overrides=dict(), + expect_stderr=False, + kill_marker=None, + send_with_newline=False, + ): threading.Thread.__init__(self) proc_env = os.environ.copy() @@ -341,12 +413,17 @@ def __init__(self, cmd_line, provider_set_ready_condition, wait_for_marker=None, def run(self): with self.results_condition: try: - proc = subprocess.Popen(self.cmd_line, env=self.proc_env, stdin=subprocess.PIPE, - stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True) + proc = subprocess.Popen( + self.cmd_line, + env=self.proc_env, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + close_fds=True, + ) self.proc = proc except Exception as ex: - self.results = Results( - None, None, None, ex, self.expect_stderr) + self.results = Results(None, None, None, ex, self.expect_stderr) raise ex communicator = _processCommunicator(proc, self.cmd_line[0]) @@ -374,7 +451,7 @@ def run(self): proc.returncode, None, expect_stderr=self.expect_stderr, - expect_nonzero_exit=self.kill_marker is not None + expect_nonzero_exit=self.kill_marker is not None, ) except subprocess.TimeoutExpired as ex: proc.kill() @@ -383,16 +460,28 @@ def run(self): # Read any remaining output proc_results = communicator.communicate() self.results = Results( - proc_results[0], proc_results[1], proc.returncode, wrapped_ex, self.expect_stderr) + proc_results[0], + proc_results[1], + proc.returncode, + wrapped_ex, + self.expect_stderr, + ) except Exception as ex: self.results = Results( - proc_results[0], proc_results[1], proc.returncode, ex, self.expect_stderr) + proc_results[0], + proc_results[1], + proc.returncode, + ex, + self.expect_stderr, + ) raise ex finally: # This data is dumped to stdout so we capture this # information no matter where a test fails. print("###############################################################") - print(f"####################### {self.cmd_line[0]} #######################") + print( + f"####################### {self.cmd_line[0]} #######################" + ) print("###############################################################") print(f"Command line:\n\t{' '.join(self.cmd_line)}") @@ -436,7 +525,8 @@ def get_results(self, send_data=None): """ with self.results_condition: result = self.results_condition.wait_for( - self._results_ready, timeout=self.timeout) + self._results_ready, timeout=self.timeout + ) if result is False: raise Exception("Timeout") diff --git a/tests/integrationv2/providers.py b/tests/integrationv2/providers.py index 8453b3c5982..cf5ebb3461a 100644 --- a/tests/integrationv2/providers.py +++ b/tests/integrationv2/providers.py @@ -112,26 +112,27 @@ def __init__(self, options: ProviderOptions): Provider.__init__(self, options) def setup_client(self): - self.ready_to_test_marker = 'listening on lo' + self.ready_to_test_marker = "listening on lo" tcpdump_filter = "dst port {}".format(self.options.port) - cmd_line = ["tcpdump", - # Line buffer the output - "-l", - - # Only read 10 packets before exiting. This is enough to find a large - # packet, and still exit before the timeout. - "-c", "10", - - # Watch the loopback device - "-i", "lo", - - # Don't resolve IP addresses - "-nn", - - # Set the buffer size to 1k - "-B", "1024", - tcpdump_filter] + cmd_line = [ + "tcpdump", + # Line buffer the output + "-l", + # Only read 10 packets before exiting. This is enough to find a large + # packet, and still exit before the timeout. + "-c", + "10", + # Watch the loopback device + "-i", + "lo", + # Don't resolve IP addresses + "-nn", + # Set the buffer size to 1k + "-B", + "1024", + tcpdump_filter, + ] return cmd_line @@ -148,18 +149,14 @@ def __init__(self, options: ProviderOptions): @classmethod def get_send_marker(cls): - return 's2n is ready' + return "s2n is ready" @classmethod def _pss_supported(cls): # RSA-PSS is unsupported for openssl-1.0 # libressl and boringssl are disabled because of configuration issues # see https://github.com/aws/s2n-tls/issues/3250 - PSS_UNSUPPORTED_LIBCRYPTOS = { - "libressl", - "boringssl", - "openssl-1.0" - } + PSS_UNSUPPORTED_LIBCRYPTOS = {"libressl", "boringssl", "openssl-1.0"} for libcrypto in PSS_UNSUPPORTED_LIBCRYPTOS: # e.g. "openssl-1.0" in "openssl-1.0.2-fips" if libcrypto in get_flag(S2N_PROVIDER_VERSION): @@ -168,7 +165,7 @@ def _pss_supported(cls): @classmethod def supports_certificate(cls, cert: Cert): - if not cls._pss_supported() and cert.algorithm == 'RSAPSS': + if not cls._pss_supported() and cert.algorithm == "RSAPSS": return False return True @@ -178,11 +175,13 @@ def supports_protocol(cls, protocol): return False # SSLv3 cannot be negotiated in FIPS mode with libcryptos other than AWS-LC. - if all([ - protocol == Protocols.SSLv3, - get_flag(S2N_FIPS_MODE), - "awslc" not in get_flag(S2N_PROVIDER_VERSION) - ]): + if all( + [ + protocol == Protocols.SSLv3, + get_flag(S2N_FIPS_MODE), + "awslc" not in get_flag(S2N_PROVIDER_VERSION), + ] + ): return False return True @@ -196,7 +195,10 @@ def supports_cipher(cls, cipher, with_curve=None): "RC4": ["openssl-3"], } - for unsupported_cipher, unsupported_libcryptos in unsupported_configurations.items(): + for ( + unsupported_cipher, + unsupported_libcryptos, + ) in unsupported_configurations.items(): # the queried cipher has some libcrypto's that don't support it # e.g. "RC4" in "TLS_ECDHE_RSA_WITH_RC4_128_SHA" if unsupported_cipher in cipher.name: @@ -210,14 +212,15 @@ def supports_cipher(cls, cipher, with_curve=None): @classmethod def supports_signature(cls, signature): # Disable RSA_PSS_RSAE_SHA256 in unsupported libcryptos - if any([ - libcrypto in get_flag(S2N_PROVIDER_VERSION) - for libcrypto in [ - "openssl-1.0.2", - "libressl", - "boringssl" - ] - ]) and signature == Signatures.RSA_PSS_RSAE_SHA256: + if ( + any( + [ + libcrypto in get_flag(S2N_PROVIDER_VERSION) + for libcrypto in ["openssl-1.0.2", "libressl", "boringssl"] + ] + ) + and signature == Signatures.RSA_PSS_RSAE_SHA256 + ): return False return True @@ -228,44 +231,44 @@ def setup_client(self): """ cmd_line = [] if self.options.use_mainline_version is True: - cmd_line.append('s2nc_head') + cmd_line.append("s2nc_head") else: - cmd_line.append('s2nc') - cmd_line.append('--non-blocking') + cmd_line.append("s2nc") + cmd_line.append("--non-blocking") # Tests requiring reconnects can't wait on echo data, # but all other tests can. if self.options.reconnect is not True: - cmd_line.append('-e') + cmd_line.append("-e") if self.options.use_session_ticket is False: - cmd_line.append('-T') + cmd_line.append("-T") if self.options.insecure is True: - cmd_line.append('--insecure') + cmd_line.append("--insecure") elif self.options.trust_store: - cmd_line.extend(['-f', self.options.trust_store]) + cmd_line.extend(["-f", self.options.trust_store]) elif self.options.cert: - cmd_line.extend(['-f', self.options.cert]) + cmd_line.extend(["-f", self.options.cert]) if self.options.reconnect is True: - cmd_line.append('-r') + cmd_line.append("-r") # If the test provided a cipher (security policy) that is compatible with # s2n, we'll use it. Otherwise, default to the appropriate `test_all` policy. - cipher_prefs = 'test_all_tls12' + cipher_prefs = "test_all_tls12" if self.options.protocol is Protocols.TLS13: - cipher_prefs = 'test_all' + cipher_prefs = "test_all" if self.options.cipher and self.options.cipher.s2n: cipher_prefs = self.options.cipher.name - cmd_line.extend(['-c', cipher_prefs]) + cmd_line.extend(["-c", cipher_prefs]) if self.options.use_client_auth: if self.options.key: - cmd_line.extend(['--key', self.options.key]) + cmd_line.extend(["--key", self.options.key]) if self.options.cert: - cmd_line.extend(['--cert', self.options.cert]) + cmd_line.extend(["--cert", self.options.cert]) if get_flag(S2N_FIPS_MODE): cmd_line.append("--enter-fips-mode") @@ -285,49 +288,50 @@ def setup_client(self): def setup_server(self): # s2nd prints this message after it begins listening for connections - self.ready_to_test_marker = 'Listening on' + self.ready_to_test_marker = "Listening on" """ Using the passed ProviderOptions, create a command line. """ cmd_line = [] if self.options.use_mainline_version is True: - cmd_line.append('s2nd_head') + cmd_line.append("s2nd_head") else: - cmd_line.append('s2nd') - cmd_line.extend(['-X', '--self-service-blinding', '--non-blocking']) + cmd_line.append("s2nd") + cmd_line.extend(["-X", "--self-service-blinding", "--non-blocking"]) if self.options.key is not None: - cmd_line.extend(['--key', self.options.key]) + cmd_line.extend(["--key", self.options.key]) if self.options.cert is not None: - cmd_line.extend(['--cert', self.options.cert]) + cmd_line.extend(["--cert", self.options.cert]) if self.options.insecure is True: - cmd_line.append('--insecure') + cmd_line.append("--insecure") elif self.options.trust_store: - cmd_line.extend(['-t', self.options.trust_store]) + cmd_line.extend(["-t", self.options.trust_store]) elif self.options.cert: - cmd_line.extend(['-t', self.options.cert]) + cmd_line.extend(["-t", self.options.cert]) # If the test provided a cipher (security policy) that is compatible with # s2n, we'll use it. Otherwise, default to the appropriate `test_all` policy. - cipher_prefs = 'test_all_tls12' + cipher_prefs = "test_all_tls12" if self.options.protocol is Protocols.TLS13: - cipher_prefs = 'test_all' + cipher_prefs = "test_all" if self.options.cipher and self.options.cipher.s2n: cipher_prefs = self.options.cipher.name - cmd_line.extend(['-c', cipher_prefs]) + cmd_line.extend(["-c", cipher_prefs]) if self.options.use_client_auth is True: - cmd_line.append('-m') + cmd_line.append("-m") if self.options.use_session_ticket is False: - cmd_line.append('-T') + cmd_line.append("-T") if self.options.reconnects_before_exit is not None: cmd_line.append( - '--max-conns={}'.format(self.options.reconnects_before_exit)) + "--max-conns={}".format(self.options.reconnects_before_exit) + ) if get_flag(S2N_FIPS_MODE): cmd_line.append("--enter-fips-mode") @@ -363,7 +367,7 @@ def __init__(self, options: ProviderOptions): @classmethod def get_send_marker(cls): - return 'Verify return code' + return "Verify return code" def _join_ciphers(self, ciphers): """ @@ -375,7 +379,7 @@ def _join_ciphers(self, ciphers): for c in ciphers: cipher_list.append(c.name) - ciphers = ':'.join(cipher_list) + ciphers = ":".join(cipher_list) return ciphers @@ -386,23 +390,29 @@ def _cipher_to_cmdline(self, cipher): if type(cipher) is list: # In the case of a cipher list we need to be sure TLS13 specific ciphers aren't # mixed with ciphers from previous versions - is_tls13_or_above = (cipher[0].min_version >= Protocols.TLS13) - mismatch = [c for c in cipher if ( - c.min_version >= Protocols.TLS13) != is_tls13_or_above] + is_tls13_or_above = cipher[0].min_version >= Protocols.TLS13 + mismatch = [ + c + for c in cipher + if (c.min_version >= Protocols.TLS13) != is_tls13_or_above + ] if len(mismatch) > 0: - raise Exception("Cannot combine ciphers for TLS1.3 or above with older ciphers: {}".format( - [c.name for c in cipher])) + raise Exception( + "Cannot combine ciphers for TLS1.3 or above with older ciphers: {}".format( + [c.name for c in cipher] + ) + ) ciphers.append(self._join_ciphers(cipher)) else: - is_tls13_or_above = (cipher.min_version >= Protocols.TLS13) + is_tls13_or_above = cipher.min_version >= Protocols.TLS13 ciphers.append(cipher.name) if is_tls13_or_above: - cmdline.append('-ciphersuites') + cmdline.append("-ciphersuites") else: - cmdline.append('-cipher') + cmdline.append("-cipher") return cmdline + ciphers @@ -444,61 +454,61 @@ def at_least_openssl_1_1(self) -> None: raise FileNotFoundError(f"Openssl version returned {OpenSSL.get_version()}, expected at least 1.1.x.") def setup_client(self): - cmd_line = ['openssl', 's_client'] + cmd_line = ["openssl", "s_client"] cmd_line.extend( - ['-connect', '{}:{}'.format(self.options.host, self.options.port)]) + ["-connect", "{}:{}".format(self.options.host, self.options.port)] + ) # Additional debugging that will be captured incase of failure if self.options.verbose: - cmd_line.append('-debug') + cmd_line.append("-debug") - cmd_line.extend(['-tlsextdebug', '-state']) + cmd_line.extend(["-tlsextdebug", "-state"]) if self.options.key is not None: - cmd_line.extend(['-key', self.options.key]) + cmd_line.extend(["-key", self.options.key]) # Unlike s2n, OpenSSL allows us to be much more specific about which TLS # protocol to use. if self.options.protocol == Protocols.TLS13: - cmd_line.append('-tls1_3') + cmd_line.append("-tls1_3") elif self.options.protocol == Protocols.TLS12: - cmd_line.append('-tls1_2') + cmd_line.append("-tls1_2") elif self.options.protocol == Protocols.TLS11: - cmd_line.append('-tls1_1') + cmd_line.append("-tls1_1") elif self.options.protocol == Protocols.TLS10: - cmd_line.append('-tls1') + cmd_line.append("-tls1") elif self.options.protocol == Protocols.SSLv3: - cmd_line.append('-ssl3') + cmd_line.append("-ssl3") if self.options.cipher is not None: cmd_line.extend(self._cipher_to_cmdline(self.options.cipher)) if self.options.curve is not None: - cmd_line.extend(['-curves', str(self.options.curve)]) + cmd_line.extend(["-curves", str(self.options.curve)]) if self.options.use_client_auth: if self.options.key: - cmd_line.extend(['-key', self.options.key]) + cmd_line.extend(["-key", self.options.key]) if self.options.cert: - cmd_line.extend(['-cert', self.options.cert]) + cmd_line.extend(["-cert", self.options.cert]) if self.options.reconnect is True: - cmd_line.append('-reconnect') + cmd_line.append("-reconnect") if self.options.extra_flags is not None: cmd_line.extend(self.options.extra_flags) if self.options.server_name is not None: - cmd_line.extend(['-servername', self.options.server_name]) + cmd_line.extend(["-servername", self.options.server_name]) if self.options.verify_hostname is not None: - cmd_line.extend(['-verify_hostname', self.options.server_name]) + cmd_line.extend(["-verify_hostname", self.options.server_name]) if self.options.enable_client_ocsp: cmd_line.append("-status") if self.options.signature_algorithm is not None: - cmd_line.extend( - ["-sigalgs", self.options.signature_algorithm.name]) + cmd_line.extend(["-sigalgs", self.options.signature_algorithm.name]) if self.options.record_size is not None: cmd_line.extend(["-max_send_frag", str(self.options.record_size)]) @@ -510,60 +520,58 @@ def setup_client(self): def setup_server(self): # s_server prints this message before it is ready to send/receive data - self.ready_to_test_marker = 'ACCEPT' + self.ready_to_test_marker = "ACCEPT" - cmd_line = ['openssl', 's_server'] - cmd_line.extend(['-accept', '{}'.format(self.options.port)]) + cmd_line = ["openssl", "s_server"] + cmd_line.extend(["-accept", "{}".format(self.options.port)]) if self.options.reconnects_before_exit is not None: # If the user request a specific reconnection count, set it here - cmd_line.extend( - ['-naccept', str(self.options.reconnects_before_exit)]) + cmd_line.extend(["-naccept", str(self.options.reconnects_before_exit)]) else: # Exit after the first connection by default - cmd_line.extend(['-naccept', '1']) + cmd_line.extend(["-naccept", "1"]) # Additional debugging that will be captured incase of failure if self.options.verbose: - cmd_line.append('-debug') + cmd_line.append("-debug") - cmd_line.extend(['-tlsextdebug', '-state']) + cmd_line.extend(["-tlsextdebug", "-state"]) if self.options.cert is not None: - cmd_line.extend(['-cert', self.options.cert]) + cmd_line.extend(["-cert", self.options.cert]) if self.options.key is not None: - cmd_line.extend(['-key', self.options.key]) + cmd_line.extend(["-key", self.options.key]) # Unlike s2n, OpenSSL allows us to be much more specific about which TLS # protocol to use. if self.options.protocol == Protocols.TLS13: - cmd_line.append('-tls1_3') + cmd_line.append("-tls1_3") elif self.options.protocol == Protocols.TLS12: - cmd_line.append('-tls1_2') + cmd_line.append("-tls1_2") elif self.options.protocol == Protocols.TLS11: - cmd_line.append('-tls1_1') + cmd_line.append("-tls1_1") elif self.options.protocol == Protocols.TLS10: - cmd_line.append('-tls1') + cmd_line.append("-tls1") elif self.options.protocol == Protocols.SSLv3: - cmd_line.append('-ssl3') + cmd_line.append("-ssl3") if self.options.cipher is not None: cmd_line.extend(self._cipher_to_cmdline(self.options.cipher)) if self.options.cipher.parameters is not None: - cmd_line.extend(['-dhparam', self.options.cipher.parameters]) + cmd_line.extend(["-dhparam", self.options.cipher.parameters]) if self.options.curve is not None: - cmd_line.extend(['-curves', str(self.options.curve)]) + cmd_line.extend(["-curves", str(self.options.curve)]) if self.options.use_client_auth is True: # We use "Verify" instead of "verify" to require a client cert - cmd_line.extend(['-Verify', '1']) + cmd_line.extend(["-Verify", "1"]) if self.options.ocsp_response is not None: cmd_line.extend(["-status_file", self.options.ocsp_response]) if self.options.signature_algorithm is not None: - cmd_line.extend( - ["-sigalgs", self.options.signature_algorithm.name]) + cmd_line.extend(["-sigalgs", self.options.signature_algorithm.name]) if self.options.extra_flags is not None: cmd_line.extend(self.options.extra_flags) @@ -607,7 +615,11 @@ def get_send_marker(cls): @classmethod def supports_protocol(cls, protocol): # https://aws.amazon.com/blogs/opensource/tls-1-0-1-1-changes-in-openjdk-and-amazon-corretto/ - if protocol is Protocols.SSLv3 or protocol is Protocols.TLS10 or protocol is Protocols.TLS11: + if ( + protocol is Protocols.SSLv3 + or protocol is Protocols.TLS10 + or protocol is Protocols.TLS11 + ): return False return True @@ -615,16 +627,16 @@ def supports_protocol(cls, protocol): @classmethod def supports_cipher(cls, cipher, with_curve=None): # Java SSL does not support CHACHA20 - if 'CHACHA20' in cipher.name: + if "CHACHA20" in cipher.name: return False return True def setup_server(self): - pytest.skip('JavaSSL does not support server mode at this time') + pytest.skip("JavaSSL does not support server mode at this time") def setup_client(self): - cmd_line = ['java', "-classpath", "bin", "SSLSocketClient"] + cmd_line = ["java", "-classpath", "bin", "SSLSocketClient"] if self.options.port is not None: cmd_line.extend([self.options.port]) @@ -661,29 +673,30 @@ def __init__(self, options: ProviderOptions): @classmethod def get_send_marker(cls): - return 'Cert issuer:' + return "Cert issuer:" def setup_server(self): - cmd_line = ['bssl', 's_server'] - cmd_line.extend(['-accept', self.options.port]) + cmd_line = ["bssl", "s_server"] + cmd_line.extend(["-accept", self.options.port]) if self.options.cert is not None: - cmd_line.extend(['-cert', self.options.cert]) + cmd_line.extend(["-cert", self.options.cert]) if self.options.key is not None: - cmd_line.extend(['-key', self.options.key]) + cmd_line.extend(["-key", self.options.key]) if self.options.curve is not None: if self.options.curve == Curves.P256: - cmd_line.extend(['-curves', 'P-256']) + cmd_line.extend(["-curves", "P-256"]) elif self.options.curve == Curves.P384: - cmd_line.extend(['-curves', 'P-384']) + cmd_line.extend(["-curves", "P-384"]) elif self.options.curve == Curves.P521: - cmd_line.extend(['-curves', 'P-521']) + cmd_line.extend(["-curves", "P-521"]) elif self.options.curve == Curves.SecP256r1Kyber768Draft00: - cmd_line.extend(['-curves', 'SecP256r1Kyber768Draft00']) + cmd_line.extend(["-curves", "SecP256r1Kyber768Draft00"]) elif self.options.curve == Curves.X25519Kyber768Draft00: - cmd_line.extend(['-curves', 'X25519Kyber768Draft00']) + cmd_line.extend(["-curves", "X25519Kyber768Draft00"]) elif self.options.curve == Curves.X25519: - pytest.skip('BoringSSL does not support curve {}'.format( - self.options.curve)) + pytest.skip( + "BoringSSL does not support curve {}".format(self.options.curve) + ) if self.options.extra_flags is not None: cmd_line.extend(self.options.extra_flags) @@ -691,37 +704,42 @@ def setup_server(self): return cmd_line def setup_client(self): - cmd_line = ['bssl', 's_client'] + cmd_line = ["bssl", "s_client"] cmd_line.extend( - ['-connect', '{}:{}'.format(self.options.host, self.options.port)]) + ["-connect", "{}:{}".format(self.options.host, self.options.port)] + ) if self.options.cert is not None: - cmd_line.extend(['-cert', self.options.cert]) + cmd_line.extend(["-cert", self.options.cert]) if self.options.key is not None: - cmd_line.extend(['-key', self.options.key]) + cmd_line.extend(["-key", self.options.key]) if self.options.cipher is not None: if self.options.cipher == Ciphersuites.TLS_CHACHA20_POLY1305_SHA256: cmd_line.extend( - ['-cipher', 'TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256']) + ["-cipher", "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"] + ) elif self.options.cipher == Ciphersuites.TLS_AES_128_GCM_256: - pytest.skip('BoringSSL does not support Cipher {}'.format( - self.options.cipher)) + pytest.skip( + "BoringSSL does not support Cipher {}".format(self.options.cipher) + ) elif self.options.cipher == Ciphersuites.TLS_AES_256_GCM_384: - pytest.skip('BoringSSL does not support Cipher {}'.format( - self.options.cipher)) + pytest.skip( + "BoringSSL does not support Cipher {}".format(self.options.cipher) + ) if self.options.curve is not None: if self.options.curve == Curves.P256: - cmd_line.extend(['-curves', 'P-256']) + cmd_line.extend(["-curves", "P-256"]) elif self.options.curve == Curves.P384: - cmd_line.extend(['-curves', 'P-384']) + cmd_line.extend(["-curves", "P-384"]) elif self.options.curve == Curves.P521: - cmd_line.extend(['-curves', 'P-521']) + cmd_line.extend(["-curves", "P-521"]) elif self.options.curve == Curves.SecP256r1Kyber768Draft00: - cmd_line.extend(['-curves', 'SecP256r1Kyber768Draft00']) + cmd_line.extend(["-curves", "SecP256r1Kyber768Draft00"]) elif self.options.curve == Curves.X25519Kyber768Draft00: - cmd_line.extend(['-curves', 'X25519Kyber768Draft00']) + cmd_line.extend(["-curves", "X25519Kyber768Draft00"]) elif self.options.curve == Curves.X25519: - pytest.skip('BoringSSL does not support curve {}'.format( - self.options.curve)) + pytest.skip( + "BoringSSL does not support curve {}".format(self.options.curve) + ) if self.options.extra_flags is not None: cmd_line.extend(self.options.extra_flags) @@ -742,35 +760,32 @@ def __init__(self, options: ProviderOptions): @staticmethod def cipher_to_priority_str(cipher): return { - Ciphers.DHE_RSA_AES128_SHA: "DHE-RSA:+AES-128-CBC:+SHA1", - Ciphers.DHE_RSA_AES256_SHA: "DHE-RSA:+AES-256-CBC:+SHA1", - Ciphers.DHE_RSA_AES128_SHA256: "DHE-RSA:+AES-128-CBC:+SHA256", - Ciphers.DHE_RSA_AES256_SHA256: "DHE-RSA:+AES-256-CBC:+SHA256", - Ciphers.DHE_RSA_AES128_GCM_SHA256: "DHE-RSA:+AES-128-GCM:+AEAD", - Ciphers.DHE_RSA_AES256_GCM_SHA384: "DHE-RSA:+AES-256-GCM:+AEAD", - Ciphers.DHE_RSA_CHACHA20_POLY1305: "DHE-RSA:+CHACHA20-POLY1305:+AEAD", - - Ciphers.AES128_SHA: "RSA:+AES-128-CBC:+SHA1", - Ciphers.AES256_SHA: "RSA:+AES-256-CBC:+SHA1", - Ciphers.AES128_SHA256: "RSA:+AES-128-CBC:+SHA256", - Ciphers.AES256_SHA256: "RSA:+AES-256-CBC:+SHA256", - Ciphers.AES128_GCM_SHA256: "RSA:+AES-128-GCM:+AEAD", - Ciphers.AES256_GCM_SHA384: "RSA:+AES-256-GCM:+AEAD", - - Ciphers.ECDHE_ECDSA_AES128_SHA: "ECDHE-ECDSA:+AES-128-CBC:+SHA1", - Ciphers.ECDHE_ECDSA_AES256_SHA: "ECDHE-ECDSA:+AES-256-CBC:+SHA1", - Ciphers.ECDHE_ECDSA_AES128_SHA256: "ECDHE-ECDSA:+AES-128-CBC:+SHA256", - Ciphers.ECDHE_ECDSA_AES256_SHA384: "ECDHE-ECDSA:+AES-256-CBC:+SHA384", - Ciphers.ECDHE_ECDSA_AES128_GCM_SHA256: "ECDHE-ECDSA:+AES-128-GCM:+AEAD", - Ciphers.ECDHE_ECDSA_AES256_GCM_SHA384: "ECDHE-ECDSA:+AES-256-GCM:+AEAD", - - Ciphers.ECDHE_RSA_AES128_SHA: "ECDHE-RSA:+AES-128-CBC:+SHA1", - Ciphers.ECDHE_RSA_AES256_SHA: "ECDHE-RSA:+AES-256-CBC:+SHA1", - Ciphers.ECDHE_RSA_AES128_SHA256: "ECDHE-RSA:+AES-128-CBC:+SHA256", - Ciphers.ECDHE_RSA_AES256_SHA384: "ECDHE-RSA:+AES-256-CBC:+SHA384", - Ciphers.ECDHE_RSA_AES128_GCM_SHA256: "ECDHE-RSA:+AES-128-GCM:+AEAD", - Ciphers.ECDHE_RSA_AES256_GCM_SHA384: "ECDHE-RSA:+AES-256-GCM:+AEAD", - Ciphers.ECDHE_RSA_CHACHA20_POLY1305: "ECDHE-RSA:+CHACHA20-POLY1305:+AEAD" + Ciphers.DHE_RSA_AES128_SHA: "DHE-RSA:+AES-128-CBC:+SHA1", + Ciphers.DHE_RSA_AES256_SHA: "DHE-RSA:+AES-256-CBC:+SHA1", + Ciphers.DHE_RSA_AES128_SHA256: "DHE-RSA:+AES-128-CBC:+SHA256", + Ciphers.DHE_RSA_AES256_SHA256: "DHE-RSA:+AES-256-CBC:+SHA256", + Ciphers.DHE_RSA_AES128_GCM_SHA256: "DHE-RSA:+AES-128-GCM:+AEAD", + Ciphers.DHE_RSA_AES256_GCM_SHA384: "DHE-RSA:+AES-256-GCM:+AEAD", + Ciphers.DHE_RSA_CHACHA20_POLY1305: "DHE-RSA:+CHACHA20-POLY1305:+AEAD", + Ciphers.AES128_SHA: "RSA:+AES-128-CBC:+SHA1", + Ciphers.AES256_SHA: "RSA:+AES-256-CBC:+SHA1", + Ciphers.AES128_SHA256: "RSA:+AES-128-CBC:+SHA256", + Ciphers.AES256_SHA256: "RSA:+AES-256-CBC:+SHA256", + Ciphers.AES128_GCM_SHA256: "RSA:+AES-128-GCM:+AEAD", + Ciphers.AES256_GCM_SHA384: "RSA:+AES-256-GCM:+AEAD", + Ciphers.ECDHE_ECDSA_AES128_SHA: "ECDHE-ECDSA:+AES-128-CBC:+SHA1", + Ciphers.ECDHE_ECDSA_AES256_SHA: "ECDHE-ECDSA:+AES-256-CBC:+SHA1", + Ciphers.ECDHE_ECDSA_AES128_SHA256: "ECDHE-ECDSA:+AES-128-CBC:+SHA256", + Ciphers.ECDHE_ECDSA_AES256_SHA384: "ECDHE-ECDSA:+AES-256-CBC:+SHA384", + Ciphers.ECDHE_ECDSA_AES128_GCM_SHA256: "ECDHE-ECDSA:+AES-128-GCM:+AEAD", + Ciphers.ECDHE_ECDSA_AES256_GCM_SHA384: "ECDHE-ECDSA:+AES-256-GCM:+AEAD", + Ciphers.ECDHE_RSA_AES128_SHA: "ECDHE-RSA:+AES-128-CBC:+SHA1", + Ciphers.ECDHE_RSA_AES256_SHA: "ECDHE-RSA:+AES-256-CBC:+SHA1", + Ciphers.ECDHE_RSA_AES128_SHA256: "ECDHE-RSA:+AES-128-CBC:+SHA256", + Ciphers.ECDHE_RSA_AES256_SHA384: "ECDHE-RSA:+AES-256-CBC:+SHA384", + Ciphers.ECDHE_RSA_AES128_GCM_SHA256: "ECDHE-RSA:+AES-128-GCM:+AEAD", + Ciphers.ECDHE_RSA_AES256_GCM_SHA384: "ECDHE-RSA:+AES-256-GCM:+AEAD", + Ciphers.ECDHE_RSA_CHACHA20_POLY1305: "ECDHE-RSA:+CHACHA20-POLY1305:+AEAD", }.get(cipher) @staticmethod @@ -781,25 +796,25 @@ def protocol_to_priority_str(protocol): Protocols.TLS10.value: "VERS-TLS1.0", Protocols.TLS11.value: "VERS-TLS1.1", Protocols.TLS12.value: "VERS-TLS1.2", - Protocols.TLS13.value: "VERS-TLS1.3" + Protocols.TLS13.value: "VERS-TLS1.3", }.get(protocol.value) @staticmethod def curve_to_priority_str(curve): return { - Curves.P256: "CURVE-SECP256R1", - Curves.P384: "CURVE-SECP384R1", - Curves.P521: "CURVE-SECP521R1", - Curves.X25519: "CURVE-X25519" + Curves.P256: "CURVE-SECP256R1", + Curves.P384: "CURVE-SECP384R1", + Curves.P521: "CURVE-SECP521R1", + Curves.X25519: "CURVE-X25519", }.get(curve) @staticmethod def sigalg_to_priority_str(sigalg): return { - Signatures.RSA_SHA1: "SIGN-RSA-SHA1", - Signatures.RSA_SHA256: "SIGN-RSA-SHA256", - Signatures.RSA_SHA384: "SIGN-RSA-SHA384", - Signatures.RSA_SHA512: "SIGN-RSA-SHA512", + Signatures.RSA_SHA1: "SIGN-RSA-SHA1", + Signatures.RSA_SHA256: "SIGN-RSA-SHA256", + Signatures.RSA_SHA384: "SIGN-RSA-SHA384", + Signatures.RSA_SHA512: "SIGN-RSA-SHA512", }.get(sigalg) @classmethod @@ -827,7 +842,9 @@ def create_priority_str(self): else: priority_str += ":+GROUP-ALL" - sigalg_to_priority_str = self.sigalg_to_priority_str(self.options.signature_algorithm) + sigalg_to_priority_str = self.sigalg_to_priority_str( + self.options.signature_algorithm + ) if sigalg_to_priority_str: priority_str += ":+" + sigalg_to_priority_str else: @@ -849,9 +866,11 @@ def setup_client(self): cmd_line = [ "gnutls-cli", - "--port", str(self.options.port), + "--port", + str(self.options.port), self.options.host, - "--debug", "9999" + "--debug", + "9999", ] if self.options.verbose: @@ -885,7 +904,7 @@ def setup_server(self): "gnutls-serv", f"--port={self.options.port}", "--echo", - "--debug=9999" + "--debug=9999", ] if self.options.cert is not None: diff --git a/tests/integrationv2/pyproject.toml b/tests/integrationv2/pyproject.toml index ec06a19a9ef..c04014f9df4 100644 --- a/tests/integrationv2/pyproject.toml +++ b/tests/integrationv2/pyproject.toml @@ -9,5 +9,6 @@ dependencies = [ "pytest>=8.3.4", "pytest-rerunfailures>=15.0", "pytest-xdist>=3.6.1", + "ruff>=0.9.7", "sslyze>=6.1.0", ] diff --git a/tests/integrationv2/test_buffered_send.py b/tests/integrationv2/test_buffered_send.py index d10cef65d14..d06b644a07b 100644 --- a/tests/integrationv2/test_buffered_send.py +++ b/tests/integrationv2/test_buffered_send.py @@ -2,13 +2,18 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -from configuration import available_ports, PROTOCOLS, ALL_TEST_CIPHERS, MINIMAL_TEST_CERTS +from configuration import ( + available_ports, + PROTOCOLS, + ALL_TEST_CIPHERS, + MINIMAL_TEST_CERTS, +) from common import ProviderOptions, data_bytes from fixtures import managed_process # lgtm [py/unused-import] from providers import Provider, S2N, OpenSSL, GnuTLS from utils import invalid_test_parameters, get_parameter_name, to_bytes, to_string -SEND_DATA_SIZE = 2 ** 14 +SEND_DATA_SIZE = 2**14 # CLOSE_MARKER must a substring of SEND_DATA exactly once, and must be its suffix CLOSE_MARKER = "unique-suffix-close-marker" @@ -28,28 +33,26 @@ SEND_BUFFER_SIZE_MIN_RECOMMENDED, SEND_BUFFER_SIZE_MULTI_RECORD, SEND_BUFFER_SIZE_PREFER_THROUGHPUT, - SEND_BUFFER_SIZE_HUGE + SEND_BUFFER_SIZE_HUGE, ] -FRAGMENT_PREFERENCE = [ - None, - "--prefer-low-latency", - "--prefer-throughput" -] +FRAGMENT_PREFERENCE = [None, "--prefer-low-latency", "--prefer-throughput"] def test_SEND_BUFFER_SIZE_MIN_is_s2ns_min_buffer_size(managed_process): port = next(available_ports) - s2n_options = ProviderOptions(mode=Provider.ServerMode, - port=port, - data_to_send="test", - extra_flags=['--buffered-send', SEND_BUFFER_SIZE_MIN]) + s2n_options = ProviderOptions( + mode=Provider.ServerMode, + port=port, + data_to_send="test", + extra_flags=["--buffered-send", SEND_BUFFER_SIZE_MIN], + ) s2nd = managed_process(S2N, s2n_options) s2n_options.mode = Provider.ClientMode - s2n_options.extra_flags = ['--buffered-send', SEND_BUFFER_SIZE_MIN - 1] + s2n_options.extra_flags = ["--buffered-send", SEND_BUFFER_SIZE_MIN - 1] s2nc = managed_process(S2N, s2n_options) for results in s2nc.get_results(): @@ -66,9 +69,18 @@ def test_SEND_BUFFER_SIZE_MIN_is_s2ns_min_buffer_size(managed_process): @pytest.mark.parametrize("protocol", PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name) @pytest.mark.parametrize("buffer_size", SEND_BUFFER_SIZES, ids=get_parameter_name) -@pytest.mark.parametrize("fragment_preference", FRAGMENT_PREFERENCE, ids=get_parameter_name) -def test_s2n_server_buffered_send(managed_process, cipher, provider, protocol, certificate, buffer_size, - fragment_preference): +@pytest.mark.parametrize( + "fragment_preference", FRAGMENT_PREFERENCE, ids=get_parameter_name +) +def test_s2n_server_buffered_send( + managed_process, + cipher, + provider, + protocol, + certificate, + buffer_size, + fragment_preference, +): # Communication Timeline # Client [S2N|OpenSSL|GnuTLS] | Server [S2N] # Handshake | Handshake @@ -83,9 +95,10 @@ def test_s2n_server_buffered_send(managed_process, cipher, provider, protocol, c data_to_send=None, insecure=True, protocol=protocol, - verbose=False) + verbose=False, + ) - extra_flags = ['--buffered-send', buffer_size] + extra_flags = ["--buffered-send", buffer_size] if fragment_preference is not None: extra_flags.append(fragment_preference) @@ -98,10 +111,15 @@ def test_s2n_server_buffered_send(managed_process, cipher, provider, protocol, c protocol=protocol, key=certificate.key, cert=certificate.cert, - extra_flags=extra_flags) + extra_flags=extra_flags, + ) - server = managed_process(S2N, s2n_server_options, send_marker=[S2N.get_send_marker()]) - client = managed_process(provider, provider_client_options, close_marker=CLOSE_MARKER) + server = managed_process( + S2N, s2n_server_options, send_marker=[S2N.get_send_marker()] + ) + client = managed_process( + provider, provider_client_options, close_marker=CLOSE_MARKER + ) for results in client.get_results(): assert SEND_DATA_STRING in str(results.stdout) @@ -119,9 +137,18 @@ def test_s2n_server_buffered_send(managed_process, cipher, provider, protocol, c @pytest.mark.parametrize("protocol", PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name) @pytest.mark.parametrize("buffer_size", SEND_BUFFER_SIZES, ids=get_parameter_name) -@pytest.mark.parametrize("fragment_preference", FRAGMENT_PREFERENCE, ids=get_parameter_name) -def test_s2n_client_buffered_send(managed_process, cipher, provider, protocol, certificate, buffer_size, - fragment_preference): +@pytest.mark.parametrize( + "fragment_preference", FRAGMENT_PREFERENCE, ids=get_parameter_name +) +def test_s2n_client_buffered_send( + managed_process, + cipher, + provider, + protocol, + certificate, + buffer_size, + fragment_preference, +): # Communication Timeline # Client [S2N] | Server [S2N|OpenSSL] # Handshake | Handshake @@ -129,7 +156,7 @@ def test_s2n_client_buffered_send(managed_process, cipher, provider, protocol, c # Close | Close on CLOSE_MARKER port = next(available_ports) - extra_flags = ['--buffered-send', buffer_size] + extra_flags = ["--buffered-send", buffer_size] if fragment_preference is not None: extra_flags.append(fragment_preference) @@ -140,7 +167,8 @@ def test_s2n_client_buffered_send(managed_process, cipher, provider, protocol, c data_to_send=SEND_DATA, insecure=True, protocol=protocol, - extra_flags=extra_flags) + extra_flags=extra_flags, + ) provider_server_options = ProviderOptions( mode=Provider.ServerMode, @@ -150,10 +178,12 @@ def test_s2n_client_buffered_send(managed_process, cipher, provider, protocol, c protocol=protocol, key=certificate.key, cert=certificate.cert, - verbose=False) + verbose=False, + ) - server = managed_process(provider, provider_server_options, - close_marker=CLOSE_MARKER) + server = managed_process( + provider, provider_server_options, close_marker=CLOSE_MARKER + ) client = managed_process(S2N, s2n_client_options) for results in client.get_results(): diff --git a/tests/integrationv2/test_client_authentication.py b/tests/integrationv2/test_client_authentication.py index 87563c997da..50b91e1d243 100644 --- a/tests/integrationv2/test_client_authentication.py +++ b/tests/integrationv2/test_client_authentication.py @@ -3,11 +3,16 @@ import copy import pytest -from configuration import (available_ports, ALL_TEST_CIPHERS, PROTOCOLS) +from configuration import available_ports, ALL_TEST_CIPHERS, PROTOCOLS from common import Certificates, ProviderOptions, Protocols, data_bytes, Signatures from fixtures import managed_process # lgtm [py/unused-import] from providers import Provider, S2N, GnuTLS, OpenSSL -from utils import invalid_test_parameters, get_parameter_name, get_expected_s2n_version, to_bytes +from utils import ( + invalid_test_parameters, + get_parameter_name, + get_expected_s2n_version, + to_bytes, +) # If we test every available cert, the test takes too long. # Choose a good representative subset. @@ -22,20 +27,27 @@ def assert_openssl_handshake_complete(results, is_complete=True): if is_complete: - assert b'read finished' in results.stderr - assert b'write finished' in results.stderr + assert b"read finished" in results.stderr + assert b"write finished" in results.stderr else: - assert b'read finished' not in results.stderr or b'write finished' not in results.stderr + assert ( + b"read finished" not in results.stderr + or b"write finished" not in results.stderr + ) def assert_s2n_handshake_complete(results, protocol, provider, is_complete=True): expected_version = get_expected_s2n_version(protocol, provider) if is_complete: - assert to_bytes("Actual protocol version: {}".format( - expected_version)) in results.stdout + assert ( + to_bytes("Actual protocol version: {}".format(expected_version)) + in results.stdout + ) else: - assert to_bytes("Actual protocol version: {}".format( - expected_version)) not in results.stdout + assert ( + to_bytes("Actual protocol version: {}".format(expected_version)) + not in results.stdout + ) @pytest.mark.uncollect_if(func=invalid_test_parameters) @@ -45,8 +57,15 @@ def assert_s2n_handshake_complete(results, protocol, provider, is_complete=True) @pytest.mark.parametrize("cipher", ALL_TEST_CIPHERS, ids=get_parameter_name) @pytest.mark.parametrize("certificate", CERTS_TO_TEST, ids=get_parameter_name) @pytest.mark.parametrize("client_certificate", CERTS_TO_TEST, ids=get_parameter_name) -def test_client_auth_with_s2n_server(managed_process, provider, other_provider, protocol, cipher, certificate, - client_certificate): +def test_client_auth_with_s2n_server( + managed_process, + provider, + other_provider, + protocol, + cipher, + certificate, + client_certificate, +): port = next(available_ports) random_bytes = data_bytes(64) @@ -60,7 +79,8 @@ def test_client_auth_with_s2n_server(managed_process, provider, other_provider, cert=client_certificate.cert, trust_store=certificate.cert, insecure=False, - protocol=protocol) + protocol=protocol, + ) server_options = copy.copy(client_options) server_options.data_to_send = None @@ -75,8 +95,8 @@ def test_client_auth_with_s2n_server(managed_process, provider, other_provider, # Openssl should send a client certificate and complete the handshake for results in client.get_results(): results.assert_success() - assert b'write client certificate' in results.stderr - assert b'write certificate verify' in results.stderr + assert b"write client certificate" in results.stderr + assert b"write certificate verify" in results.stderr assert_openssl_handshake_complete(results) # S2N should successfully connect @@ -93,21 +113,29 @@ def test_client_auth_with_s2n_server(managed_process, provider, other_provider, @pytest.mark.parametrize("cipher", ALL_TEST_CIPHERS, ids=get_parameter_name) @pytest.mark.parametrize("certificate", CERTS_TO_TEST, ids=get_parameter_name) @pytest.mark.parametrize("client_certificate", CERTS_TO_TEST, ids=get_parameter_name) -def test_client_auth_with_s2n_server_using_nonmatching_certs(managed_process, provider, other_provider, protocol, - cipher, certificate, client_certificate): +def test_client_auth_with_s2n_server_using_nonmatching_certs( + managed_process, + provider, + other_provider, + protocol, + cipher, + certificate, + client_certificate, +): port = next(available_ports) client_options = ProviderOptions( mode=Provider.ClientMode, port=port, cipher=cipher, - data_to_send=b'', + data_to_send=b"", use_client_auth=True, key=client_certificate.key, cert=client_certificate.cert, trust_store=certificate.cert, insecure=False, - protocol=protocol) + protocol=protocol, + ) server_options = copy.copy(client_options) server_options.data_to_send = None @@ -124,8 +152,8 @@ def test_client_auth_with_s2n_server_using_nonmatching_certs(managed_process, pr # Openssl should tell us that a certificate was sent, but the handshake did not complete for results in client.get_results(): assert results.exception is None - assert b'write client certificate' in results.stderr - assert b'write certificate verify' in results.stderr + assert b"write client certificate" in results.stderr + assert b"write certificate verify" in results.stderr # TLS1.3 OpenSSL fails after the handshake, but pre-TLS1.3 fails during if protocol is not Protocols.TLS13: assert results.exit_code != 0 @@ -135,8 +163,8 @@ def test_client_auth_with_s2n_server_using_nonmatching_certs(managed_process, pr for results in server.get_results(): assert results.exception is None assert results.exit_code != 0 - assert b'Certificate is untrusted' in results.stderr - assert b'Error: Mutual Auth was required, but not negotiated' in results.stderr + assert b"Certificate is untrusted" in results.stderr + assert b"Error: Mutual Auth was required, but not negotiated" in results.stderr assert_s2n_handshake_complete(results, protocol, provider, False) @@ -146,7 +174,9 @@ def test_client_auth_with_s2n_server_using_nonmatching_certs(managed_process, pr @pytest.mark.parametrize("protocol", PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("cipher", ALL_TEST_CIPHERS, ids=get_parameter_name) @pytest.mark.parametrize("certificate", CERTS_TO_TEST, ids=get_parameter_name) -def test_client_auth_with_s2n_client_no_cert(managed_process, provider, other_provider, protocol, cipher, certificate): +def test_client_auth_with_s2n_client_no_cert( + managed_process, provider, other_provider, protocol, cipher, certificate +): port = next(available_ports) random_bytes = data_bytes(64) @@ -158,7 +188,8 @@ def test_client_auth_with_s2n_client_no_cert(managed_process, provider, other_pr use_client_auth=True, trust_store=certificate.cert, insecure=False, - protocol=protocol) + protocol=protocol, + ) server_options = copy.copy(client_options) server_options.data_to_send = None @@ -172,8 +203,8 @@ def test_client_auth_with_s2n_client_no_cert(managed_process, provider, other_pr # Openssl should tell us that a cert was requested but not received for results in server.get_results(): results.assert_success() - assert b'write certificate request' in results.stderr - assert b'read client certificate' not in results.stderr + assert b"write certificate request" in results.stderr + assert b"read client certificate" not in results.stderr assert b"peer did not return a certificate" in results.stderr assert_openssl_handshake_complete(results, False) @@ -181,7 +212,7 @@ def test_client_auth_with_s2n_client_no_cert(managed_process, provider, other_pr assert results.exception is None # TLS1.3 OpenSSL fails after the handshake, but pre-TLS1.3 fails during if protocol is not Protocols.TLS13: - assert (results.exit_code != 0) + assert results.exit_code != 0 assert b"Failed to negotiate: 'TLS alert received'" in results.stderr assert_s2n_handshake_complete(results, protocol, provider, False) @@ -193,8 +224,15 @@ def test_client_auth_with_s2n_client_no_cert(managed_process, provider, other_pr @pytest.mark.parametrize("cipher", ALL_TEST_CIPHERS, ids=get_parameter_name) @pytest.mark.parametrize("certificate", CERTS_TO_TEST, ids=get_parameter_name) @pytest.mark.parametrize("client_certificate", CERTS_TO_TEST, ids=get_parameter_name) -def test_client_auth_with_s2n_client_with_cert(managed_process, provider, other_provider, protocol, cipher, certificate, - client_certificate): +def test_client_auth_with_s2n_client_with_cert( + managed_process, + provider, + other_provider, + protocol, + cipher, + certificate, + client_certificate, +): port = next(available_ports) random_bytes = data_bytes(64) @@ -208,7 +246,8 @@ def test_client_auth_with_s2n_client_with_cert(managed_process, provider, other_ cert=client_certificate.cert, trust_store=certificate.cert, insecure=False, - protocol=protocol) + protocol=protocol, + ) server_options = copy.copy(client_options) server_options.data_to_send = None @@ -229,8 +268,8 @@ def test_client_auth_with_s2n_client_with_cert(managed_process, provider, other_ for results in server.get_results(): results.assert_success() assert random_bytes[1:] in results.stdout - assert b'read client certificate' in results.stderr - assert b'read certificate verify' in results.stderr + assert b"read client certificate" in results.stderr + assert b"read certificate verify" in results.stderr assert_openssl_handshake_complete(results) @@ -245,7 +284,11 @@ def test_client_auth_with_s2n_client_with_cert(managed_process, provider, other_ """ -@pytest.mark.parametrize("certificate", [Certificates.RSA_2048_PKCS1, Certificates.ECDSA_256], ids=get_parameter_name) +@pytest.mark.parametrize( + "certificate", + [Certificates.RSA_2048_PKCS1, Certificates.ECDSA_256], + ids=get_parameter_name, +) def test_tls_12_client_auth_downgrade(managed_process, certificate): port = next(available_ports) @@ -306,9 +349,11 @@ def test_tls_12_client_auth_downgrade(managed_process, certificate): for results in server.get_results(): results.assert_success() - assert to_bytes( - f"Actual protocol version: {expected_protocol_version}" - ) in results.stdout - assert to_bytes( - f"Client signature negotiated: {expected_signature_type}" - ) in results.stdout + assert ( + to_bytes(f"Actual protocol version: {expected_protocol_version}") + in results.stdout + ) + assert ( + to_bytes(f"Client signature negotiated: {expected_signature_type}") + in results.stdout + ) diff --git a/tests/integrationv2/test_cross_compatibility.py b/tests/integrationv2/test_cross_compatibility.py index cfbb99d60c3..c3c49f88da8 100644 --- a/tests/integrationv2/test_cross_compatibility.py +++ b/tests/integrationv2/test_cross_compatibility.py @@ -4,7 +4,12 @@ import copy import os -from configuration import available_ports, ALL_TEST_CIPHERS, ALL_TEST_CURVES, ALL_TEST_CERTS +from configuration import ( + available_ports, + ALL_TEST_CIPHERS, + ALL_TEST_CURVES, + ALL_TEST_CERTS, +) from common import ProviderOptions, Protocols, data_bytes from fixtures import managed_process # lgtm [py/unused-import] from providers import Provider, S2N, OpenSSL @@ -13,7 +18,7 @@ S2N_RESUMPTION_MARKER = to_bytes("Resumed session") CLOSE_MARKER_BYTES = data_bytes(10) -TICKET_FILE = 'ticket' +TICKET_FILE = "ticket" RESUMPTION_PROTOCOLS = [Protocols.TLS12, Protocols.TLS13] @@ -30,8 +35,16 @@ @pytest.mark.parametrize("protocol", RESUMPTION_PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) -def test_s2n_old_server_new_ticket(managed_process, tmp_path, cipher, curve, certificate, protocol, provider, - other_provider): +def test_s2n_old_server_new_ticket( + managed_process, + tmp_path, + cipher, + curve, + certificate, + protocol, + provider, + other_provider, +): ticket_file = str(tmp_path / TICKET_FILE) assert not os.path.exists(ticket_file) @@ -46,7 +59,7 @@ def test_s2n_old_server_new_ticket(managed_process, tmp_path, cipher, curve, cer client_options = copy.copy(options) client_options.mode = Provider.ClientMode - client_options.extra_flags = ['-sess_out', ticket_file] + client_options.extra_flags = ["-sess_out", ticket_file] server_options = copy.copy(options) server_options.mode = Provider.ServerMode @@ -54,10 +67,10 @@ def test_s2n_old_server_new_ticket(managed_process, tmp_path, cipher, curve, cer server_options.cert = certificate.cert server_options.data_to_send = CLOSE_MARKER_BYTES - s2n_server = managed_process( - S2N, server_options, send_marker=S2N.get_send_marker()) - client = managed_process(provider, client_options, - close_marker=str(CLOSE_MARKER_BYTES)) + s2n_server = managed_process(S2N, server_options, send_marker=S2N.get_send_marker()) + client = managed_process( + provider, client_options, close_marker=str(CLOSE_MARKER_BYTES) + ) for results in client.get_results(): results.assert_success() @@ -66,13 +79,13 @@ def test_s2n_old_server_new_ticket(managed_process, tmp_path, cipher, curve, cer results.assert_success() assert os.path.exists(ticket_file) - client_options.extra_flags = ['-sess_in', ticket_file] + client_options.extra_flags = ["-sess_in", ticket_file] server_options.use_mainline_version = True - s2n_server = managed_process( - S2N, server_options, send_marker=S2N.get_send_marker()) - client = managed_process(provider, client_options, - close_marker=str(CLOSE_MARKER_BYTES)) + s2n_server = managed_process(S2N, server_options, send_marker=S2N.get_send_marker()) + client = managed_process( + provider, client_options, close_marker=str(CLOSE_MARKER_BYTES) + ) for results in client.get_results(): results.assert_success() @@ -95,8 +108,16 @@ def test_s2n_old_server_new_ticket(managed_process, tmp_path, cipher, curve, cer @pytest.mark.parametrize("protocol", RESUMPTION_PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) -def test_s2n_new_server_old_ticket(managed_process, tmp_path, cipher, curve, certificate, protocol, provider, - other_provider): +def test_s2n_new_server_old_ticket( + managed_process, + tmp_path, + cipher, + curve, + certificate, + protocol, + provider, + other_provider, +): ticket_file = str(tmp_path / TICKET_FILE) assert not os.path.exists(ticket_file) @@ -111,7 +132,7 @@ def test_s2n_new_server_old_ticket(managed_process, tmp_path, cipher, curve, cer client_options = copy.copy(options) client_options.mode = Provider.ClientMode - client_options.extra_flags = ['-sess_out', ticket_file] + client_options.extra_flags = ["-sess_out", ticket_file] server_options = copy.copy(options) server_options.mode = Provider.ServerMode @@ -120,10 +141,10 @@ def test_s2n_new_server_old_ticket(managed_process, tmp_path, cipher, curve, cer server_options.cert = certificate.cert server_options.data_to_send = CLOSE_MARKER_BYTES - s2n_server = managed_process( - S2N, server_options, send_marker=S2N.get_send_marker()) - client = managed_process(provider, client_options, - close_marker=str(CLOSE_MARKER_BYTES)) + s2n_server = managed_process(S2N, server_options, send_marker=S2N.get_send_marker()) + client = managed_process( + provider, client_options, close_marker=str(CLOSE_MARKER_BYTES) + ) for results in client.get_results(): results.assert_success() @@ -132,13 +153,13 @@ def test_s2n_new_server_old_ticket(managed_process, tmp_path, cipher, curve, cer results.assert_success() assert os.path.exists(ticket_file) - client_options.extra_flags = ['-sess_in', ticket_file] + client_options.extra_flags = ["-sess_in", ticket_file] server_options.use_mainline_version = False - s2n_server = managed_process( - S2N, server_options, send_marker=S2N.get_send_marker()) - client = managed_process(provider, client_options, - close_marker=str(CLOSE_MARKER_BYTES)) + s2n_server = managed_process(S2N, server_options, send_marker=S2N.get_send_marker()) + client = managed_process( + provider, client_options, close_marker=str(CLOSE_MARKER_BYTES) + ) for results in client.get_results(): results.assert_success() @@ -162,8 +183,16 @@ def test_s2n_new_server_old_ticket(managed_process, tmp_path, cipher, curve, cer @pytest.mark.parametrize("protocol", RESUMPTION_PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("provider", [S2N], ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) -def test_s2n_old_client_new_ticket(managed_process, tmp_path, cipher, curve, certificate, protocol, provider, - other_provider): +def test_s2n_old_client_new_ticket( + managed_process, + tmp_path, + cipher, + curve, + certificate, + protocol, + provider, + other_provider, +): ticket_file = str(tmp_path / TICKET_FILE) assert not os.path.exists(ticket_file) @@ -178,7 +207,7 @@ def test_s2n_old_client_new_ticket(managed_process, tmp_path, cipher, curve, cer client_options = copy.copy(options) client_options.mode = Provider.ClientMode - client_options.extra_flags = ['--ticket-out', ticket_file] + client_options.extra_flags = ["--ticket-out", ticket_file] server_options = copy.copy(options) server_options.mode = Provider.ServerMode @@ -195,7 +224,7 @@ def test_s2n_old_client_new_ticket(managed_process, tmp_path, cipher, curve, cer results.assert_success() assert os.path.exists(ticket_file) - client_options.extra_flags = ['--ticket-in', ticket_file] + client_options.extra_flags = ["--ticket-in", ticket_file] client_options.use_mainline_version = True server = managed_process(provider, server_options) @@ -223,8 +252,16 @@ def test_s2n_old_client_new_ticket(managed_process, tmp_path, cipher, curve, cer @pytest.mark.parametrize("protocol", RESUMPTION_PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("provider", [S2N], ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) -def test_s2n_new_client_old_ticket(managed_process, tmp_path, cipher, curve, certificate, protocol, provider, - other_provider): +def test_s2n_new_client_old_ticket( + managed_process, + tmp_path, + cipher, + curve, + certificate, + protocol, + provider, + other_provider, +): ticket_file = str(tmp_path / TICKET_FILE) assert not os.path.exists(ticket_file) @@ -239,7 +276,7 @@ def test_s2n_new_client_old_ticket(managed_process, tmp_path, cipher, curve, cer client_options = copy.copy(options) client_options.mode = Provider.ClientMode - client_options.extra_flags = ['--ticket-out', ticket_file] + client_options.extra_flags = ["--ticket-out", ticket_file] client_options.use_mainline_version = True server_options = copy.copy(options) @@ -257,7 +294,7 @@ def test_s2n_new_client_old_ticket(managed_process, tmp_path, cipher, curve, cer results.assert_success() assert os.path.exists(ticket_file) - client_options.extra_flags = ['--ticket-in', ticket_file] + client_options.extra_flags = ["--ticket-in", ticket_file] client_options.use_mainline_version = False server = managed_process(provider, server_options) diff --git a/tests/integrationv2/test_dynamic_record_sizes.py b/tests/integrationv2/test_dynamic_record_sizes.py index 6af004c5b7e..5924db5c010 100644 --- a/tests/integrationv2/test_dynamic_record_sizes.py +++ b/tests/integrationv2/test_dynamic_record_sizes.py @@ -3,11 +3,22 @@ import copy import pytest -from configuration import available_ports, ALL_TEST_CIPHERS, ALL_TEST_CURVES, ALL_TEST_CERTS, PROTOCOLS +from configuration import ( + available_ports, + ALL_TEST_CIPHERS, + ALL_TEST_CURVES, + ALL_TEST_CERTS, + PROTOCOLS, +) from common import ProviderOptions, data_bytes from fixtures import custom_mtu, managed_process # lgtm [py/unused-import] from providers import Provider, S2N, OpenSSL, Tcpdump -from utils import invalid_test_parameters, get_parameter_name, get_expected_s2n_version, to_bytes +from utils import ( + invalid_test_parameters, + get_parameter_name, + get_expected_s2n_version, + to_bytes, +) def find_fragmented_packet(results): @@ -19,15 +30,15 @@ def find_fragmented_packet(results): to the output line. This happens even when using `-nn` on the command line. That is why we need two ways to detect the length of the packet. """ - for line in results.decode('utf-8').split('\n'): - pieces = line.split(' ') + for line in results.decode("utf-8").split("\n"): + pieces = line.split(" ") if len(pieces) < 3: continue packet_len = 0 - if pieces[-2] == 'length': + if pieces[-2] == "length": packet_len = int(pieces[-1]) - elif pieces[-3] == 'length': + elif pieces[-3] == "length": # In this case the length has a colon `1234:`, so we must trim it. packet_len = int(pieces[-2][:-1]) @@ -44,8 +55,16 @@ def find_fragmented_packet(results): @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) @pytest.mark.parametrize("protocol", PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("certificate", ALL_TEST_CERTS, ids=get_parameter_name) -def test_s2n_client_dynamic_record(custom_mtu, managed_process, cipher, curve, provider, other_provider, protocol, - certificate): +def test_s2n_client_dynamic_record( + custom_mtu, + managed_process, + cipher, + curve, + provider, + other_provider, + protocol, + certificate, +): port = next(available_ports) # 16384 bytes is enough to reliably get a packet that will exceed the MTU @@ -56,7 +75,8 @@ def test_s2n_client_dynamic_record(custom_mtu, managed_process, cipher, curve, p cipher=cipher, data_to_send=bytes_to_send, insecure=True, - protocol=protocol) + protocol=protocol, + ) server_options = copy.copy(client_options) server_options.data_to_send = None @@ -74,8 +94,10 @@ def test_s2n_client_dynamic_record(custom_mtu, managed_process, cipher, curve, p for results in client.get_results(): results.assert_success() - assert to_bytes("Actual protocol version: {}".format( - expected_version)) in results.stdout + assert ( + to_bytes("Actual protocol version: {}".format(expected_version)) + in results.stdout + ) for results in server.get_results(): results.assert_success() diff --git a/tests/integrationv2/test_early_data.py b/tests/integrationv2/test_early_data.py index 9872941480c..ba3ac788723 100644 --- a/tests/integrationv2/test_early_data.py +++ b/tests/integrationv2/test_early_data.py @@ -4,7 +4,12 @@ import os import pytest -from configuration import available_ports, ALL_TEST_CURVES, ALL_TEST_CERTS, TLS13_CIPHERS +from configuration import ( + available_ports, + ALL_TEST_CURVES, + ALL_TEST_CERTS, + TLS13_CIPHERS, +) from common import ProviderOptions, Protocols, Curves, data_bytes from fixtures import managed_process # lgtm [py/unused-import] from providers import Provider, S2N as S2NBase, OpenSSL as OpenSSLBase @@ -12,8 +17,8 @@ from test_hello_retry_requests import S2N_HRR_MARKER -TICKET_FILE = 'ticket' -EARLY_DATA_FILE = 'early_data' +TICKET_FILE = "ticket" +EARLY_DATA_FILE = "early_data" MAX_EARLY_DATA = 500 # Arbitrary largish number DATA_TO_SEND = data_bytes(500) # Arbitrary large number @@ -23,19 +28,17 @@ S2N_DEFAULT_CURVE = Curves.X25519 # We have no plans to support this curve any time soon -S2N_UNSUPPORTED_CURVE = 'X448' -S2N_HRR_CURVES = list( - curve for curve in ALL_TEST_CURVES if curve != S2N_DEFAULT_CURVE) +S2N_UNSUPPORTED_CURVE = "X448" +S2N_HRR_CURVES = list(curve for curve in ALL_TEST_CURVES if curve != S2N_DEFAULT_CURVE) S2N_EARLY_DATA_MARKER = to_bytes("WITH_EARLY_DATA") S2N_EARLY_DATA_RECV_MARKER = "Early Data received: " S2N_EARLY_DATA_STATUS_MARKER = "Early Data status: {status}" -S2N_EARLY_DATA_ACCEPTED_MARKER = S2N_EARLY_DATA_STATUS_MARKER.format( - status="ACCEPTED") -S2N_EARLY_DATA_REJECTED_MARKER = S2N_EARLY_DATA_STATUS_MARKER.format( - status="REJECTED") +S2N_EARLY_DATA_ACCEPTED_MARKER = S2N_EARLY_DATA_STATUS_MARKER.format(status="ACCEPTED") +S2N_EARLY_DATA_REJECTED_MARKER = S2N_EARLY_DATA_STATUS_MARKER.format(status="REJECTED") S2N_EARLY_DATA_NOT_REQUESTED_MARKER = S2N_EARLY_DATA_STATUS_MARKER.format( - status="NOT REQUESTED") + status="NOT REQUESTED" +) class S2N(S2NBase): @@ -46,12 +49,12 @@ def setup_client(self): cmd_line = S2NBase.setup_client(self) early_data_file = self.options.early_data_file if early_data_file and os.path.exists(early_data_file): - cmd_line.extend(['--early-data', early_data_file]) + cmd_line.extend(["--early-data", early_data_file]) return cmd_line def setup_server(self): cmd_line = S2NBase.setup_server(self) - cmd_line.extend(['--max-early-data', self.options.max_early_data]) + cmd_line.extend(["--max-early-data", self.options.max_early_data]) return cmd_line @@ -63,19 +66,19 @@ def setup_client(self): cmd_line = OpenSSLBase.setup_client(self) early_data_file = self.options.early_data_file if early_data_file and os.path.exists(early_data_file): - cmd_line.extend(['-early_data', early_data_file]) + cmd_line.extend(["-early_data", early_data_file]) ticket_file = self.options.ticket_file if ticket_file: if os.path.exists(ticket_file): - cmd_line.extend(['-sess_in', ticket_file]) + cmd_line.extend(["-sess_in", ticket_file]) else: - cmd_line.extend(['-sess_out', self.options.ticket_file]) + cmd_line.extend(["-sess_out", self.options.ticket_file]) return cmd_line def setup_server(self): cmd_line = OpenSSLBase.setup_server(self) if self.options.max_early_data > 0: - cmd_line.extend(['-early_data']) + cmd_line.extend(["-early_data"]) return cmd_line @@ -89,7 +92,7 @@ def setup_server(self): def get_early_data_bytes(file_path, early_data_size): early_data = data_bytes(early_data_size) - with open(file_path, 'wb') as fout: + with open(file_path, "wb") as fout: fout.write(early_data) return early_data @@ -118,16 +121,10 @@ def get_ticket_from_s2n_server(options, managed_process, provider, certificate): assert not os.path.exists(options.ticket_file) s2n_server = managed_process( - S2N, - server_options, - send_marker=S2N.get_send_marker(), - timeout=10 + S2N, server_options, send_marker=S2N.get_send_marker(), timeout=10 ) client = managed_process( - provider, - client_options, - close_marker=str(close_marker_bytes), - timeout=10 + provider, client_options, close_marker=str(close_marker_bytes), timeout=10 ) for results in s2n_server.get_results(): @@ -163,9 +160,21 @@ def test_nothing(): @pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name) @pytest.mark.parametrize("provider", CLIENT_PROVIDERS, ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) -@pytest.mark.parametrize("early_data_size", [int(MAX_EARLY_DATA/2), int(MAX_EARLY_DATA-1), MAX_EARLY_DATA, 1]) -def test_s2n_server_with_early_data(managed_process, tmp_path, cipher, curve, certificate, protocol, provider, - other_provider, early_data_size): +@pytest.mark.parametrize( + "early_data_size", + [int(MAX_EARLY_DATA / 2), int(MAX_EARLY_DATA - 1), MAX_EARLY_DATA, 1], +) +def test_s2n_server_with_early_data( + managed_process, + tmp_path, + cipher, + curve, + certificate, + protocol, + provider, + other_provider, + early_data_size, +): ticket_file = str(tmp_path / TICKET_FILE) early_data_file = str(tmp_path / EARLY_DATA_FILE) early_data = get_early_data_bytes(early_data_file, early_data_size) @@ -200,8 +209,7 @@ def test_s2n_server_with_early_data(managed_process, tmp_path, cipher, curve, ce for results in s2n_server.get_results(): results.assert_success() assert S2N_EARLY_DATA_MARKER in results.stdout - assert (to_bytes(S2N_EARLY_DATA_RECV_MARKER) + - early_data) in results.stdout + assert (to_bytes(S2N_EARLY_DATA_RECV_MARKER) + early_data) in results.stdout assert to_bytes(S2N_EARLY_DATA_ACCEPTED_MARKER) in results.stdout assert DATA_TO_SEND in results.stdout @@ -220,9 +228,20 @@ def test_s2n_server_with_early_data(managed_process, tmp_path, cipher, curve, ce @pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name) @pytest.mark.parametrize("provider", SERVER_PROVIDERS, ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) -@pytest.mark.parametrize("early_data_size", [int(MAX_EARLY_DATA/2), int(MAX_EARLY_DATA-1), MAX_EARLY_DATA, 1]) -def test_s2n_client_with_early_data(managed_process, tmp_path, cipher, certificate, protocol, provider, other_provider, - early_data_size): +@pytest.mark.parametrize( + "early_data_size", + [int(MAX_EARLY_DATA / 2), int(MAX_EARLY_DATA - 1), MAX_EARLY_DATA, 1], +) +def test_s2n_client_with_early_data( + managed_process, + tmp_path, + cipher, + certificate, + protocol, + provider, + other_provider, + early_data_size, +): early_data_file = str(tmp_path / EARLY_DATA_FILE) early_data = get_early_data_bytes(early_data_file, early_data_size) @@ -253,8 +272,10 @@ def test_s2n_client_with_early_data(managed_process, tmp_path, cipher, certifica for results in s2n_client.get_results(): results.assert_success() assert S2N_EARLY_DATA_MARKER in results.stdout - assert results.stdout.count( - to_bytes(S2N_EARLY_DATA_ACCEPTED_MARKER)) == NUM_RESUMES + assert ( + results.stdout.count(to_bytes(S2N_EARLY_DATA_ACCEPTED_MARKER)) + == NUM_RESUMES + ) for results in server.get_results(): results.assert_success() @@ -275,8 +296,9 @@ def test_s2n_client_with_early_data(managed_process, tmp_path, cipher, certifica @pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name) @pytest.mark.parametrize("provider", SERVER_PROVIDERS, ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) -def test_s2n_client_without_early_data(managed_process, tmp_path, cipher, certificate, protocol, provider, - other_provider): +def test_s2n_client_without_early_data( + managed_process, tmp_path, cipher, certificate, protocol, provider, other_provider +): early_data_file = str(tmp_path / EARLY_DATA_FILE) early_data = get_early_data_bytes(early_data_file, MAX_EARLY_DATA) @@ -311,8 +333,10 @@ def test_s2n_client_without_early_data(managed_process, tmp_path, cipher, certif for results in s2n_client.get_results(): results.assert_success() assert S2N_EARLY_DATA_MARKER not in results.stdout - assert results.stdout.count( - to_bytes(S2N_EARLY_DATA_NOT_REQUESTED_MARKER)) == NUM_CONNECTIONS + assert ( + results.stdout.count(to_bytes(S2N_EARLY_DATA_NOT_REQUESTED_MARKER)) + == NUM_CONNECTIONS + ) """ @@ -334,9 +358,21 @@ def test_s2n_client_without_early_data(managed_process, tmp_path, cipher, certif @pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name) @pytest.mark.parametrize("provider", CLIENT_PROVIDERS, ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) -@pytest.mark.parametrize("early_data_size", [int(MAX_EARLY_DATA/2), int(MAX_EARLY_DATA-1), MAX_EARLY_DATA, 1]) -def test_s2n_server_with_early_data_rejected(managed_process, tmp_path, cipher, curve, certificate, protocol, provider, - other_provider, early_data_size): +@pytest.mark.parametrize( + "early_data_size", + [int(MAX_EARLY_DATA / 2), int(MAX_EARLY_DATA - 1), MAX_EARLY_DATA, 1], +) +def test_s2n_server_with_early_data_rejected( + managed_process, + tmp_path, + cipher, + curve, + certificate, + protocol, + provider, + other_provider, + early_data_size, +): ticket_file = str(tmp_path / TICKET_FILE) early_data_file = str(tmp_path / EARLY_DATA_FILE) get_early_data_bytes(early_data_file, early_data_size) @@ -393,12 +429,25 @@ def test_s2n_server_with_early_data_rejected(managed_process, tmp_path, cipher, @pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name) @pytest.mark.parametrize("provider", SERVER_PROVIDERS, ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) -@pytest.mark.parametrize("early_data_size", [int(MAX_EARLY_DATA/2), int(MAX_EARLY_DATA-1), MAX_EARLY_DATA, 1]) -def test_s2n_client_with_early_data_rejected_via_hrr(managed_process, tmp_path, cipher, curve, certificate, protocol, - provider, other_provider, early_data_size): +@pytest.mark.parametrize( + "early_data_size", + [int(MAX_EARLY_DATA / 2), int(MAX_EARLY_DATA - 1), MAX_EARLY_DATA, 1], +) +def test_s2n_client_with_early_data_rejected_via_hrr( + managed_process, + tmp_path, + cipher, + curve, + certificate, + protocol, + provider, + other_provider, + early_data_size, +): if provider == S2N: pytest.skip( - "S2N does not respect ProviderOptions.curve, so does not trigger a retry") + "S2N does not respect ProviderOptions.curve, so does not trigger a retry" + ) early_data_file = str(tmp_path / EARLY_DATA_FILE) early_data = get_early_data_bytes(early_data_file, early_data_size) @@ -432,8 +481,10 @@ def test_s2n_client_with_early_data_rejected_via_hrr(managed_process, tmp_path, results.assert_success() assert S2N_EARLY_DATA_MARKER not in results.stdout assert S2N_HRR_MARKER in results.stdout - assert results.stdout.count( - to_bytes(S2N_EARLY_DATA_REJECTED_MARKER)) == NUM_RESUMES + assert ( + results.stdout.count(to_bytes(S2N_EARLY_DATA_REJECTED_MARKER)) + == NUM_RESUMES + ) for results in server.get_results(): results.assert_success() @@ -455,9 +506,21 @@ def test_s2n_client_with_early_data_rejected_via_hrr(managed_process, tmp_path, @pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name) @pytest.mark.parametrize("provider", CLIENT_PROVIDERS, ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) -@pytest.mark.parametrize("early_data_size", [int(MAX_EARLY_DATA/2), int(MAX_EARLY_DATA-1), MAX_EARLY_DATA, 1]) -def test_s2n_server_with_early_data_rejected_via_hrr(managed_process, tmp_path, cipher, curve, certificate, protocol, - provider, other_provider, early_data_size): +@pytest.mark.parametrize( + "early_data_size", + [int(MAX_EARLY_DATA / 2), int(MAX_EARLY_DATA - 1), MAX_EARLY_DATA, 1], +) +def test_s2n_server_with_early_data_rejected_via_hrr( + managed_process, + tmp_path, + cipher, + curve, + certificate, + protocol, + provider, + other_provider, + early_data_size, +): ticket_file = str(tmp_path / TICKET_FILE) early_data_file = str(tmp_path / EARLY_DATA_FILE) early_data = get_early_data_bytes(early_data_file, early_data_size) @@ -512,12 +575,22 @@ def test_s2n_server_with_early_data_rejected_via_hrr(managed_process, tmp_path, @pytest.mark.parametrize("provider", CLIENT_PROVIDERS, ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) @pytest.mark.parametrize("excess_early_data", [1, 10, MAX_EARLY_DATA]) -def test_s2n_server_with_early_data_max_exceeded(managed_process, tmp_path, cipher, curve, certificate, protocol, - provider, other_provider, excess_early_data): +def test_s2n_server_with_early_data_max_exceeded( + managed_process, + tmp_path, + cipher, + curve, + certificate, + protocol, + provider, + other_provider, + excess_early_data, +): ticket_file = str(tmp_path / TICKET_FILE) early_data_file = str(tmp_path / EARLY_DATA_FILE) early_data = get_early_data_bytes( - early_data_file, MAX_EARLY_DATA + excess_early_data) + early_data_file, MAX_EARLY_DATA + excess_early_data + ) options = ProviderOptions( port=next(available_ports), @@ -560,6 +633,7 @@ def test_s2n_server_with_early_data_max_exceeded(managed_process, tmp_path, ciph # Full early data should not be reported assert early_data not in results.stdout # Partial early data should be reported - assert (to_bytes(S2N_EARLY_DATA_RECV_MARKER) + - early_data[:MAX_EARLY_DATA]) in results.stdout + assert ( + to_bytes(S2N_EARLY_DATA_RECV_MARKER) + early_data[:MAX_EARLY_DATA] + ) in results.stdout assert to_bytes("Bad message encountered") in results.stderr diff --git a/tests/integrationv2/test_external_psk.py b/tests/integrationv2/test_external_psk.py index f7aaf06a4ff..f62566dd799 100644 --- a/tests/integrationv2/test_external_psk.py +++ b/tests/integrationv2/test_external_psk.py @@ -2,7 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -from configuration import available_ports, TLS13_CIPHERS, ALL_TEST_CURVES, ALL_TEST_CERTS +from configuration import ( + available_ports, + TLS13_CIPHERS, + ALL_TEST_CURVES, + ALL_TEST_CERTS, +) from common import ProviderOptions, Protocols, data_bytes from fixtures import managed_process # lgtm [py/unused-import] from providers import S2N, OpenSSL @@ -10,17 +15,19 @@ from enum import Enum, auto # Known value test vectors from https://tools.ietf.org/html/rfc8448#section-4 -known_psk_identity = '2c035d829359ee5ff7af4ec900000000262a6494dc486d2c8a34cb33fa90bf1b00'\ - '70ad3c498883c9367c09a2be785abc55cd226097a3a982117283f82a03a143efd3'\ - 'ff5dd36d64e861be7fd61d2827db279cce145077d454a3664d4e6da4d29ee03725'\ - 'a6a4dafcd0fc67d2aea70529513e3da2677fa5906c5b3f7d8f92f228bda40dda72'\ - '1470f9fbf297b5aea617646fac5c03272e970727c621a79141ef5f7de6505e5bfb'\ - 'c388e93343694093934ae4d357' -known_psk_secret = '4ecd0eb6ec3b4d87f5d6028f922ca4c5851a277fd41311c9e62d2c9492e1c4f3' +known_psk_identity = ( + "2c035d829359ee5ff7af4ec900000000262a6494dc486d2c8a34cb33fa90bf1b00" + "70ad3c498883c9367c09a2be785abc55cd226097a3a982117283f82a03a143efd3" + "ff5dd36d64e861be7fd61d2827db279cce145077d454a3664d4e6da4d29ee03725" + "a6a4dafcd0fc67d2aea70529513e3da2677fa5906c5b3f7d8f92f228bda40dda72" + "1470f9fbf297b5aea617646fac5c03272e970727c621a79141ef5f7de6505e5bfb" + "c388e93343694093934ae4d357" +) +known_psk_secret = "4ecd0eb6ec3b4d87f5d6028f922ca4c5851a277fd41311c9e62d2c9492e1c4f3" # Arbitrary test vectors -PSK_IDENTITY_LIST = [known_psk_identity, 'psk_identity', 'test_psk_identity'] -PSK_SECRET_LIST = [known_psk_secret, 'a6dadae4567876', 'a64dafcd0fc67d2a'] +PSK_IDENTITY_LIST = [known_psk_identity, "psk_identity", "test_psk_identity"] +PSK_SECRET_LIST = [known_psk_secret, "a6dadae4567876", "a64dafcd0fc67d2a"] PSK_IDENTITY_NO_MATCH = "PSK_IDENTITY_NO_MATCH" PSK_SECRET_NO_MATCH = "e9492e1c" PSK_IDENTITY_NO_MATCH_2 = "PSK_IDENTITY_NO_MATCH_2" @@ -37,14 +44,16 @@ class Outcome(Enum): def setup_s2n_psk_params(psk_identity, psk_secret, psk_hash_alg): - return ['--psk', psk_identity + ',' + psk_secret + ',' + psk_hash_alg] + return ["--psk", psk_identity + "," + psk_secret + "," + psk_hash_alg] def setup_openssl_psk_params(psk_identity, psk_secret): - return ['-psk_identity', psk_identity, '--psk', psk_secret] + return ["-psk_identity", psk_identity, "--psk", psk_secret] -def setup_provider_options(mode, port, cipher, curve, certificate, data_to_send, client_psk_params): +def setup_provider_options( + mode, port, cipher, curve, certificate, data_to_send, client_psk_params +): options = ProviderOptions( host="localhost", port=port, @@ -54,7 +63,8 @@ def setup_provider_options(mode, port, cipher, curve, certificate, data_to_send, protocol=Protocols.TLS13, data_to_send=data_to_send, mode=mode, - extra_flags=client_psk_params) + extra_flags=client_psk_params, + ) if certificate: options.key = certificate.key options.cert = certificate.cert @@ -64,10 +74,10 @@ def setup_provider_options(mode, port, cipher, curve, certificate, data_to_send, def get_psk_hash_alg_from_cipher(cipher): # S2N supports only SHA256 and SHA384 PSK Hash Algorithms - if 'SHA256' in cipher.name: - return 'SHA256' - elif 'SHA384' in cipher.name: - return 'SHA384' + if "SHA256" in cipher.name: + return "SHA256" + elif "SHA384" in cipher.name: + return "SHA384" else: return None @@ -79,29 +89,33 @@ def skip_invalid_psk_tests(provider, psk_hash_alg): # In OpenSSL, PSK works only with TLS1.3 ciphersuites based on SHA256 hash algorithm which includes # all TLS1.3 ciphersuites supported by S2N except TLS_AES_256_GCM_SHA384. - if provider == OpenSSL and psk_hash_alg == 'SHA384': + if provider == OpenSSL and psk_hash_alg == "SHA384": pytest.skip() def validate_negotiated_psk_s2n(outcome, psk_identity, results): if outcome == Outcome.psk_connection: - assert to_bytes("Negotiated PSK identity: {}".format( - psk_identity)) in results.stdout + assert ( + to_bytes("Negotiated PSK identity: {}".format(psk_identity)) + in results.stdout + ) elif outcome == Outcome.full_handshake: - assert to_bytes("Negotiated PSK identity: {}".format( - psk_identity)) not in results.stdout + assert ( + to_bytes("Negotiated PSK identity: {}".format(psk_identity)) + not in results.stdout + ) else: assert results.exit_code != 0 - assert to_bytes( - "Failed to negotiate: 'TLS alert received'") in results.stderr + assert to_bytes("Failed to negotiate: 'TLS alert received'") in results.stderr def validate_negotiated_psk_openssl(outcome, results): if outcome == Outcome.psk_connection: - assert to_bytes("extension \"psk\"") in results.stdout + assert to_bytes('extension "psk"') in results.stdout elif outcome == Outcome.full_handshake: - assert to_bytes( - "SSL_connect:SSLv3/TLS read server certificate") in results.stderr + assert ( + to_bytes("SSL_connect:SSLv3/TLS read server certificate") in results.stderr + ) else: assert to_bytes("SSL_accept:error in error") in results.stderr @@ -130,42 +144,49 @@ def test_nothing(): @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) @pytest.mark.parametrize("psk_identity", PSK_IDENTITY_LIST, ids=get_parameter_name) @pytest.mark.parametrize("psk_secret", PSK_SECRET_LIST, ids=get_parameter_name) -def test_s2n_server_psk_connection(managed_process, cipher, curve, protocol, provider, other_provider, psk_identity, - psk_secret): +def test_s2n_server_psk_connection( + managed_process, + cipher, + curve, + protocol, + provider, + other_provider, + psk_identity, + psk_secret, +): port = next(available_ports) random_bytes = data_bytes(10) psk_hash_alg = get_psk_hash_alg_from_cipher(cipher) skip_invalid_psk_tests(provider, psk_hash_alg) if provider == S2N: - client_psk_params = setup_s2n_psk_params( - psk_identity, psk_secret, psk_hash_alg) + client_psk_params = setup_s2n_psk_params(psk_identity, psk_secret, psk_hash_alg) else: client_psk_params = setup_openssl_psk_params(psk_identity, psk_secret) client_options = setup_provider_options( - provider.ClientMode, port, cipher, curve, None, random_bytes, client_psk_params) + provider.ClientMode, port, cipher, curve, None, random_bytes, client_psk_params + ) - server_psk_params = setup_s2n_psk_params( - psk_identity, psk_secret, psk_hash_alg) + server_psk_params = setup_s2n_psk_params(psk_identity, psk_secret, psk_hash_alg) server_options = setup_provider_options( - S2N.ServerMode, port, cipher, curve, None, None, server_psk_params) + S2N.ServerMode, port, cipher, curve, None, None, server_psk_params + ) server = managed_process( - S2N, server_options, timeout=5, close_marker=str(random_bytes)) + S2N, server_options, timeout=5, close_marker=str(random_bytes) + ) client = managed_process(provider, client_options, timeout=5) for results in client.get_results(): results.assert_success() if provider == S2N: - validate_negotiated_psk_s2n( - Outcome.psk_connection, psk_identity, results) + validate_negotiated_psk_s2n(Outcome.psk_connection, psk_identity, results) else: validate_negotiated_psk_openssl(Outcome.psk_connection, results) for results in server.get_results(): results.assert_success() - validate_negotiated_psk_s2n( - Outcome.psk_connection, psk_identity, results) + validate_negotiated_psk_s2n(Outcome.psk_connection, psk_identity, results) assert random_bytes in results.stdout @@ -184,8 +205,16 @@ def test_s2n_server_psk_connection(managed_process, cipher, curve, protocol, pro @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) @pytest.mark.parametrize("psk_identity", PSK_IDENTITY_LIST, ids=get_parameter_name) @pytest.mark.parametrize("psk_secret", PSK_SECRET_LIST, ids=get_parameter_name) -def test_s2n_server_multiple_psks(managed_process, cipher, curve, protocol, provider, other_provider, psk_identity, - psk_secret): +def test_s2n_server_multiple_psks( + managed_process, + cipher, + curve, + protocol, + provider, + other_provider, + psk_identity, + psk_secret, +): port = next(available_ports) random_bytes = data_bytes(10) psk_hash_alg = get_psk_hash_alg_from_cipher(cipher) @@ -197,41 +226,48 @@ def test_s2n_server_multiple_psks(managed_process, cipher, curve, protocol, prov OpenSSL Provider does not support multiple PSKs in the same connection, the last psk parameter is the psk parameter used in the connection. """ - client_psk_params.extend(setup_openssl_psk_params( - PSK_IDENTITY_NO_MATCH, PSK_SECRET_NO_MATCH)) client_psk_params.extend( - setup_openssl_psk_params(psk_identity, psk_secret)) + setup_openssl_psk_params(PSK_IDENTITY_NO_MATCH, PSK_SECRET_NO_MATCH) + ) + client_psk_params.extend(setup_openssl_psk_params(psk_identity, psk_secret)) else: - client_psk_params.extend(setup_s2n_psk_params( - PSK_IDENTITY_NO_MATCH, PSK_SECRET_NO_MATCH, psk_hash_alg)) - client_psk_params.extend(setup_s2n_psk_params( - psk_identity, psk_secret, psk_hash_alg)) + client_psk_params.extend( + setup_s2n_psk_params( + PSK_IDENTITY_NO_MATCH, PSK_SECRET_NO_MATCH, psk_hash_alg + ) + ) + client_psk_params.extend( + setup_s2n_psk_params(psk_identity, psk_secret, psk_hash_alg) + ) client_options = setup_provider_options( - provider.ClientMode, port, cipher, curve, None, random_bytes, client_psk_params) - - server_psk_params = setup_s2n_psk_params( - psk_identity, psk_secret, psk_hash_alg) - server_psk_params.extend(setup_s2n_psk_params( - PSK_IDENTITY_NO_MATCH_2, PSK_SECRET_NO_MATCH_2, psk_hash_alg)) + provider.ClientMode, port, cipher, curve, None, random_bytes, client_psk_params + ) + + server_psk_params = setup_s2n_psk_params(psk_identity, psk_secret, psk_hash_alg) + server_psk_params.extend( + setup_s2n_psk_params( + PSK_IDENTITY_NO_MATCH_2, PSK_SECRET_NO_MATCH_2, psk_hash_alg + ) + ) server_options = setup_provider_options( - S2N.ServerMode, port, cipher, curve, None, None, server_psk_params) + S2N.ServerMode, port, cipher, curve, None, None, server_psk_params + ) server = managed_process( - S2N, server_options, timeout=5, close_marker=str(random_bytes)) + S2N, server_options, timeout=5, close_marker=str(random_bytes) + ) client = managed_process(provider, client_options, timeout=5) for results in client.get_results(): results.assert_success() if provider == S2N: - validate_negotiated_psk_s2n( - Outcome.psk_connection, psk_identity, results) + validate_negotiated_psk_s2n(Outcome.psk_connection, psk_identity, results) else: validate_negotiated_psk_openssl(Outcome.psk_connection, results) for results in server.get_results(): results.assert_success() - validate_negotiated_psk_s2n( - Outcome.psk_connection, psk_identity, results) + validate_negotiated_psk_s2n(Outcome.psk_connection, psk_identity, results) assert random_bytes in results.stdout @@ -254,43 +290,63 @@ def test_s2n_server_multiple_psks(managed_process, cipher, curve, protocol, prov @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) @pytest.mark.parametrize("psk_identity", PSK_IDENTITY_LIST, ids=get_parameter_name) @pytest.mark.parametrize("psk_secret", PSK_SECRET_LIST, ids=get_parameter_name) -@pytest.mark.parametrize("certificate", ALL_TEST_CERTS_WITH_EMPTY_CERT, ids=get_parameter_name) -def test_s2n_server_full_handshake(managed_process, cipher, curve, protocol, provider, other_provider, psk_identity, - psk_secret, certificate): +@pytest.mark.parametrize( + "certificate", ALL_TEST_CERTS_WITH_EMPTY_CERT, ids=get_parameter_name +) +def test_s2n_server_full_handshake( + managed_process, + cipher, + curve, + protocol, + provider, + other_provider, + psk_identity, + psk_secret, + certificate, +): port = next(available_ports) random_bytes = data_bytes(10) psk_hash_alg = get_psk_hash_alg_from_cipher(cipher) skip_invalid_psk_tests(provider, psk_hash_alg) if provider == S2N: - client_psk_params = setup_s2n_psk_params( - psk_identity, psk_secret, psk_hash_alg) + client_psk_params = setup_s2n_psk_params(psk_identity, psk_secret, psk_hash_alg) else: client_psk_params = setup_openssl_psk_params(psk_identity, psk_secret) client_options = setup_provider_options( - provider.ClientMode, port, cipher, curve, certificate, random_bytes, client_psk_params) + provider.ClientMode, + port, + cipher, + curve, + certificate, + random_bytes, + client_psk_params, + ) server_psk_params = setup_s2n_psk_params( - PSK_IDENTITY_NO_MATCH, PSK_SECRET_NO_MATCH, psk_hash_alg) + PSK_IDENTITY_NO_MATCH, PSK_SECRET_NO_MATCH, psk_hash_alg + ) server_options = setup_provider_options( - S2N.ServerMode, port, cipher, curve, certificate, None, server_psk_params) + S2N.ServerMode, port, cipher, curve, certificate, None, server_psk_params + ) server = managed_process( - S2N, server_options, timeout=5, close_marker=str(random_bytes)) + S2N, server_options, timeout=5, close_marker=str(random_bytes) + ) client = managed_process(provider, client_options, timeout=5) for results in client.get_results(): results.assert_success() if provider == S2N: - validate_negotiated_psk_s2n( - Outcome.full_handshake, psk_identity, results) + validate_negotiated_psk_s2n(Outcome.full_handshake, psk_identity, results) else: validate_negotiated_psk_openssl(Outcome.full_handshake, results) for results in server.get_results(): results.assert_success() validate_negotiated_psk_s2n( - Outcome.full_handshake, PSK_IDENTITY_NO_MATCH, results) + Outcome.full_handshake, PSK_IDENTITY_NO_MATCH, results + ) assert random_bytes in results.stdout @@ -309,41 +365,48 @@ def test_s2n_server_full_handshake(managed_process, cipher, curve, protocol, pro @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) @pytest.mark.parametrize("psk_identity", PSK_IDENTITY_LIST, ids=get_parameter_name) @pytest.mark.parametrize("psk_secret", PSK_SECRET_LIST, ids=get_parameter_name) -def test_s2n_client_psk_connection(managed_process, cipher, curve, protocol, provider, other_provider, psk_identity, - psk_secret): +def test_s2n_client_psk_connection( + managed_process, + cipher, + curve, + protocol, + provider, + other_provider, + psk_identity, + psk_secret, +): port = next(available_ports) random_bytes = data_bytes(10) psk_hash_alg = get_psk_hash_alg_from_cipher(cipher) skip_invalid_psk_tests(provider, psk_hash_alg) - client_psk_params = setup_s2n_psk_params( - psk_identity, psk_secret, psk_hash_alg) + client_psk_params = setup_s2n_psk_params(psk_identity, psk_secret, psk_hash_alg) client_options = setup_provider_options( - S2N.ClientMode, port, cipher, curve, None, random_bytes, client_psk_params) + S2N.ClientMode, port, cipher, curve, None, random_bytes, client_psk_params + ) if provider == S2N: - server_psk_params = setup_s2n_psk_params( - psk_identity, psk_secret, psk_hash_alg) + server_psk_params = setup_s2n_psk_params(psk_identity, psk_secret, psk_hash_alg) else: server_psk_params = setup_openssl_psk_params(psk_identity, psk_secret) - server_psk_params += ['-nocert'] + server_psk_params += ["-nocert"] server_options = setup_provider_options( - provider.ServerMode, port, cipher, curve, None, None, server_psk_params) + provider.ServerMode, port, cipher, curve, None, None, server_psk_params + ) - server = managed_process(provider, server_options, - timeout=5, close_marker=str(random_bytes)) + server = managed_process( + provider, server_options, timeout=5, close_marker=str(random_bytes) + ) client = managed_process(S2N, client_options, timeout=5) for results in client.get_results(): results.assert_success() - validate_negotiated_psk_s2n( - Outcome.psk_connection, psk_identity, results) + validate_negotiated_psk_s2n(Outcome.psk_connection, psk_identity, results) for results in server.get_results(): results.assert_success() if provider == S2N: - validate_negotiated_psk_s2n( - Outcome.psk_connection, psk_identity, results) + validate_negotiated_psk_s2n(Outcome.psk_connection, psk_identity, results) else: validate_negotiated_psk_openssl(Outcome.psk_connection, results) assert random_bytes in results.stdout @@ -364,19 +427,28 @@ def test_s2n_client_psk_connection(managed_process, cipher, curve, protocol, pro @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) @pytest.mark.parametrize("psk_identity", PSK_IDENTITY_LIST, ids=get_parameter_name) @pytest.mark.parametrize("psk_secret", PSK_SECRET_LIST, ids=get_parameter_name) -def test_s2n_client_multiple_psks(managed_process, cipher, curve, protocol, provider, other_provider, psk_identity, - psk_secret): +def test_s2n_client_multiple_psks( + managed_process, + cipher, + curve, + protocol, + provider, + other_provider, + psk_identity, + psk_secret, +): port = next(available_ports) random_bytes = data_bytes(10) psk_hash_alg = get_psk_hash_alg_from_cipher(cipher) skip_invalid_psk_tests(provider, psk_hash_alg) - client_psk_params = setup_s2n_psk_params( - psk_identity, psk_secret, psk_hash_alg) - client_psk_params.extend(setup_s2n_psk_params( - PSK_IDENTITY_NO_MATCH, PSK_SECRET_NO_MATCH, psk_hash_alg)) + client_psk_params = setup_s2n_psk_params(psk_identity, psk_secret, psk_hash_alg) + client_psk_params.extend( + setup_s2n_psk_params(PSK_IDENTITY_NO_MATCH, PSK_SECRET_NO_MATCH, psk_hash_alg) + ) client_options = setup_provider_options( - S2N.ClientMode, port, cipher, curve, None, random_bytes, client_psk_params) + S2N.ClientMode, port, cipher, curve, None, random_bytes, client_psk_params + ) server_psk_params = [] if provider == OpenSSL: @@ -384,33 +456,37 @@ def test_s2n_client_multiple_psks(managed_process, cipher, curve, protocol, prov OpenSSL Provider does not support multiple PSKs in the same connection, the last psk params is the final psk used in the connection. """ - server_psk_params.extend(setup_openssl_psk_params( - PSK_IDENTITY_NO_MATCH_2, PSK_SECRET_NO_MATCH_2)) server_psk_params.extend( - setup_openssl_psk_params(psk_identity, psk_secret)) - server_psk_params += ['-nocert'] + setup_openssl_psk_params(PSK_IDENTITY_NO_MATCH_2, PSK_SECRET_NO_MATCH_2) + ) + server_psk_params.extend(setup_openssl_psk_params(psk_identity, psk_secret)) + server_psk_params += ["-nocert"] else: - server_psk_params.extend(setup_s2n_psk_params( - PSK_IDENTITY_NO_MATCH_2, PSK_SECRET_NO_MATCH_2, psk_hash_alg)) - server_psk_params.extend(setup_s2n_psk_params( - psk_identity, psk_secret, psk_hash_alg)) + server_psk_params.extend( + setup_s2n_psk_params( + PSK_IDENTITY_NO_MATCH_2, PSK_SECRET_NO_MATCH_2, psk_hash_alg + ) + ) + server_psk_params.extend( + setup_s2n_psk_params(psk_identity, psk_secret, psk_hash_alg) + ) server_options = setup_provider_options( - provider.ServerMode, port, cipher, curve, None, None, server_psk_params) + provider.ServerMode, port, cipher, curve, None, None, server_psk_params + ) - server = managed_process(provider, server_options, - timeout=5, close_marker=str(random_bytes)) + server = managed_process( + provider, server_options, timeout=5, close_marker=str(random_bytes) + ) client = managed_process(S2N, client_options, timeout=5) for results in client.get_results(): results.assert_success() - validate_negotiated_psk_s2n( - Outcome.psk_connection, psk_identity, results) + validate_negotiated_psk_s2n(Outcome.psk_connection, psk_identity, results) for results in server.get_results(): results.assert_success() if provider == S2N: - validate_negotiated_psk_s2n( - Outcome.psk_connection, psk_identity, results) + validate_negotiated_psk_s2n(Outcome.psk_connection, psk_identity, results) else: validate_negotiated_psk_openssl(Outcome.psk_connection, results) assert random_bytes in results.stdout @@ -432,32 +508,35 @@ def test_s2n_client_multiple_psks(managed_process, cipher, curve, protocol, prov @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) @pytest.mark.parametrize("psk_identity", PSK_IDENTITY_LIST, ids=get_parameter_name) @pytest.mark.parametrize("psk_secret", PSK_SECRET_LIST, ids=get_parameter_name) -def test_s2n_client_psk_handshake_failure(managed_process, cipher, curve, protocol, provider, psk_identity, psk_secret): +def test_s2n_client_psk_handshake_failure( + managed_process, cipher, curve, protocol, provider, psk_identity, psk_secret +): port = next(available_ports) random_bytes = data_bytes(10) psk_hash_alg = get_psk_hash_alg_from_cipher(cipher) skip_invalid_psk_tests(provider, psk_hash_alg) - client_psk_params = setup_s2n_psk_params( - psk_identity, psk_secret, psk_hash_alg) + client_psk_params = setup_s2n_psk_params(psk_identity, psk_secret, psk_hash_alg) client_options = setup_provider_options( - S2N.ClientMode, port, cipher, curve, None, random_bytes, client_psk_params) + S2N.ClientMode, port, cipher, curve, None, random_bytes, client_psk_params + ) server_psk_params = setup_openssl_psk_params( - PSK_IDENTITY_NO_MATCH, PSK_SECRET_NO_MATCH) - server_psk_params += ['-nocert'] + PSK_IDENTITY_NO_MATCH, PSK_SECRET_NO_MATCH + ) + server_psk_params += ["-nocert"] server_options = setup_provider_options( - provider.ServerMode, port, cipher, curve, None, None, server_psk_params) + provider.ServerMode, port, cipher, curve, None, None, server_psk_params + ) - server = managed_process(provider, server_options, - timeout=5, close_marker=str(random_bytes)) + server = managed_process( + provider, server_options, timeout=5, close_marker=str(random_bytes) + ) client = managed_process(S2N, client_options, timeout=5) for results in client.get_results(): - assert to_bytes( - "Failed to negotiate: 'TLS alert received'") in results.stderr - validate_negotiated_psk_s2n( - Outcome.handshake_failed, psk_identity, results) + assert to_bytes("Failed to negotiate: 'TLS alert received'") in results.stderr + validate_negotiated_psk_s2n(Outcome.handshake_failed, psk_identity, results) for results in server.get_results(): assert to_bytes("SSL_accept:error in error") in results.stderr diff --git a/tests/integrationv2/test_fragmentation.py b/tests/integrationv2/test_fragmentation.py index 2195c85840e..73151a79ab7 100644 --- a/tests/integrationv2/test_fragmentation.py +++ b/tests/integrationv2/test_fragmentation.py @@ -7,19 +7,21 @@ from common import ProviderOptions, Ciphers, Certificates, data_bytes from fixtures import managed_process # lgtm [py/unused-import] from providers import Provider, S2N, OpenSSL, GnuTLS -from utils import invalid_test_parameters, get_parameter_name, get_expected_s2n_version, to_bytes +from utils import ( + invalid_test_parameters, + get_parameter_name, + get_expected_s2n_version, + to_bytes, +) CIPHERS_TO_TEST = [ Ciphers.AES256_SHA, Ciphers.ECDHE_ECDSA_AES256_SHA, - Ciphers.AES256_GCM_SHA384 + Ciphers.AES256_GCM_SHA384, ] -CERTIFICATES_TO_TEST = [ - Certificates.RSA_4096_SHA384, - Certificates.ECDSA_384 -] +CERTIFICATES_TO_TEST = [Certificates.RSA_4096_SHA384, Certificates.ECDSA_384] @pytest.mark.uncollect_if(func=invalid_test_parameters) @@ -28,10 +30,13 @@ @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) @pytest.mark.parametrize("protocol", PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("certificate", CERTIFICATES_TO_TEST, ids=get_parameter_name) -def test_s2n_server_low_latency(managed_process, cipher, provider, other_provider, protocol, certificate): - if provider is OpenSSL and 'openssl-1.0.2' in provider.get_version(): +def test_s2n_server_low_latency( + managed_process, cipher, provider, other_provider, protocol, certificate +): + if provider is OpenSSL and "openssl-1.0.2" in provider.get_version(): pytest.skip( - '{} does not allow setting max fragmentation for packets'.format(provider)) + "{} does not allow setting max fragmentation for packets".format(provider) + ) port = next(available_ports) @@ -42,12 +47,13 @@ def test_s2n_server_low_latency(managed_process, cipher, provider, other_provide cipher=cipher, data_to_send=random_bytes, insecure=True, - protocol=protocol) + protocol=protocol, + ) server_options = copy.copy(client_options) server_options.data_to_send = None server_options.mode = Provider.ServerMode - server_options.extra_flags = ['--prefer-low-latency'] + server_options.extra_flags = ["--prefer-low-latency"] server_options.key = certificate.key server_options.cert = certificate.cert server_options.cipher = None @@ -62,8 +68,10 @@ def test_s2n_server_low_latency(managed_process, cipher, provider, other_provide for results in server.get_results(): results.assert_success() - assert to_bytes("Actual protocol version: {}".format( - expected_version)) in results.stdout + assert ( + to_bytes("Actual protocol version: {}".format(expected_version)) + in results.stdout + ) assert random_bytes in results.stdout @@ -85,12 +93,16 @@ def invalid_test_parameters_frag_len(*args, **kwargs): @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) @pytest.mark.parametrize("protocol", PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("certificate", CERTIFICATES_TO_TEST, ids=get_parameter_name) -@pytest.mark.parametrize("frag_len", [512, 2048, 8192, 12345, 16384], ids=get_parameter_name) -def test_s2n_server_framented_data(managed_process, cipher, provider, other_provider, protocol, certificate, - frag_len): - if provider is OpenSSL and 'openssl-1.0.2' in provider.get_version(): +@pytest.mark.parametrize( + "frag_len", [512, 2048, 8192, 12345, 16384], ids=get_parameter_name +) +def test_s2n_server_framented_data( + managed_process, cipher, provider, other_provider, protocol, certificate, frag_len +): + if provider is OpenSSL and "openssl-1.0.2" in provider.get_version(): pytest.skip( - '{} does not allow setting max fragmentation for packets'.format(provider)) + "{} does not allow setting max fragmentation for packets".format(provider) + ) port = next(available_ports) @@ -102,7 +114,7 @@ def test_s2n_server_framented_data(managed_process, cipher, provider, other_prov data_to_send=random_bytes, insecure=True, record_size=frag_len, - protocol=protocol + protocol=protocol, ) server_options = copy.copy(client_options) @@ -123,8 +135,10 @@ def test_s2n_server_framented_data(managed_process, cipher, provider, other_prov for server_results in server.get_results(): server_results.assert_success() - assert to_bytes("Actual protocol version: {}".format( - expected_version)) in server_results.stdout + assert ( + to_bytes("Actual protocol version: {}".format(expected_version)) + in server_results.stdout + ) if provider == GnuTLS: # GnuTLS ignores data sent through stdin past frag_len up to the application data diff --git a/tests/integrationv2/test_happy_path.py b/tests/integrationv2/test_happy_path.py index b264e815078..63e68351b92 100644 --- a/tests/integrationv2/test_happy_path.py +++ b/tests/integrationv2/test_happy_path.py @@ -3,11 +3,22 @@ import copy import pytest -from configuration import available_ports, ALL_TEST_CIPHERS, ALL_TEST_CURVES, ALL_TEST_CERTS, PROTOCOLS +from configuration import ( + available_ports, + ALL_TEST_CIPHERS, + ALL_TEST_CURVES, + ALL_TEST_CERTS, + PROTOCOLS, +) from common import ProviderOptions, data_bytes from fixtures import managed_process # lgtm [py/unused-import] from providers import Provider, S2N, OpenSSL, JavaSSL, GnuTLS, SSLv3Provider -from utils import invalid_test_parameters, get_parameter_name, get_expected_s2n_version, to_bytes +from utils import ( + invalid_test_parameters, + get_parameter_name, + get_expected_s2n_version, + to_bytes, +) @pytest.mark.flaky(reruns=5, reruns_delay=2) @@ -17,7 +28,9 @@ @pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name) @pytest.mark.parametrize("protocol", PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("certificate", ALL_TEST_CERTS, ids=get_parameter_name) -def test_s2n_server_happy_path(managed_process, cipher, provider, curve, protocol, certificate): +def test_s2n_server_happy_path( + managed_process, cipher, provider, curve, protocol, certificate +): port = next(available_ports) # s2nd can receive large amounts of data because all the data is @@ -34,7 +47,7 @@ def test_s2n_server_happy_path(managed_process, cipher, provider, curve, protoco curve=curve, data_to_send=random_bytes, insecure=True, - protocol=protocol + protocol=protocol, ) server_options = copy.copy(client_options) @@ -60,13 +73,17 @@ def test_s2n_server_happy_path(managed_process, cipher, provider, curve, protoco # the stdout reliably. for server_results in server.get_results(): server_results.assert_success() - assert to_bytes("Actual protocol version: {}".format( - expected_version)) in server_results.stdout + assert ( + to_bytes("Actual protocol version: {}".format(expected_version)) + in server_results.stdout + ) assert random_bytes in server_results.stdout if provider is not S2N: - assert to_bytes("Cipher negotiated: {}".format( - cipher.name)) in server_results.stdout + assert ( + to_bytes("Cipher negotiated: {}".format(cipher.name)) + in server_results.stdout + ) @pytest.mark.flaky(reruns=5, reruns_delay=2) @@ -76,7 +93,9 @@ def test_s2n_server_happy_path(managed_process, cipher, provider, curve, protoco @pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name) @pytest.mark.parametrize("protocol", PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("certificate", ALL_TEST_CERTS, ids=get_parameter_name) -def test_s2n_client_happy_path(managed_process, cipher, provider, curve, protocol, certificate): +def test_s2n_client_happy_path( + managed_process, cipher, provider, curve, protocol, certificate +): port = next(available_ports) # We can only send 4096 - 1 (\n at the end) bytes here because of the @@ -108,8 +127,9 @@ def test_s2n_client_happy_path(managed_process, cipher, provider, curve, protoco # Passing the type of client and server as a parameter will # allow us to use a fixture to enumerate all possibilities. - server = managed_process(provider, server_options, - timeout=5, kill_marker=kill_marker) + server = managed_process( + provider, server_options, timeout=5, kill_marker=kill_marker + ) client = managed_process(S2N, client_options, timeout=5) expected_version = get_expected_s2n_version(protocol, provider) @@ -118,8 +138,10 @@ def test_s2n_client_happy_path(managed_process, cipher, provider, curve, protoco # the stdout reliably. for client_results in client.get_results(): client_results.assert_success() - assert to_bytes("Actual protocol version: {}".format( - expected_version)) in client_results.stdout + assert ( + to_bytes("Actual protocol version: {}".format(expected_version)) + in client_results.stdout + ) # The server will be one of all supported providers. We # just want to make sure there was no exception and that @@ -128,4 +150,5 @@ def test_s2n_client_happy_path(managed_process, cipher, provider, curve, protoco server_results.assert_success() # Avoid debugging information that sometimes gets inserted after the first character. assert any( - [random_bytes[1:] in stream for stream in server_results.output_streams()]) + [random_bytes[1:] in stream for stream in server_results.output_streams()] + ) diff --git a/tests/integrationv2/test_hello_retry_requests.py b/tests/integrationv2/test_hello_retry_requests.py index f44cedfc940..4de68982143 100644 --- a/tests/integrationv2/test_hello_retry_requests.py +++ b/tests/integrationv2/test_hello_retry_requests.py @@ -4,7 +4,12 @@ import pytest import re -from configuration import available_ports, TLS13_CIPHERS, ALL_TEST_CURVES, ALL_TEST_CERTS +from configuration import ( + available_ports, + TLS13_CIPHERS, + ALL_TEST_CURVES, + ALL_TEST_CERTS, +) from common import ProviderOptions, Protocols, data_bytes, Curves from fixtures import managed_process # lgtm [py/unused-import] from providers import Provider, S2N, OpenSSL @@ -18,7 +23,7 @@ "X25519": "x25519", "P-256": "secp256r1", "P-384": "secp384r1", - "P-521": "secp521r1" + "P-521": "secp521r1", } @@ -38,7 +43,9 @@ def test_nothing(): @pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name) @pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name) @pytest.mark.parametrize("certificate", ALL_TEST_CERTS, ids=get_parameter_name) -def test_hrr_with_s2n_as_client(managed_process, cipher, provider, other_provider, curve, protocol, certificate): +def test_hrr_with_s2n_as_client( + managed_process, cipher, provider, other_provider, curve, protocol, certificate +): if curve == S2N_DEFAULT_CURVE: pytest.skip("No retry if server curve matches client curve") @@ -51,7 +58,8 @@ def test_hrr_with_s2n_as_client(managed_process, cipher, provider, other_provide cipher=cipher, data_to_send=random_bytes, insecure=True, - protocol=protocol) + protocol=protocol, + ) server_options = copy.copy(client_options) server_options.data_to_send = None @@ -69,8 +77,7 @@ def test_hrr_with_s2n_as_client(managed_process, cipher, provider, other_provide # The client should connect and return without error for results in client.get_results(): results.assert_success() - assert to_bytes("Curve: {}".format( - CURVE_NAMES[curve.name])) in results.stdout + assert to_bytes("Curve: {}".format(CURVE_NAMES[curve.name])) in results.stdout assert S2N_HRR_MARKER in results.stdout marker_part1 = b"cf 21 ad 74 e5" @@ -80,9 +87,17 @@ def test_hrr_with_s2n_as_client(managed_process, cipher, provider, other_provide results.assert_success() assert marker_part1 in results.stdout and marker_part2 in results.stdout # The "test_all" s2n security policy includes draft Hybrid PQ groups that Openssl server prints as hex values - assert re.search(b'Supported Elliptic Groups: [x0-9A-F:]*X25519:P-256:P-384', results.stdout) is not None - assert to_bytes("Shared Elliptic groups: {}".format( - server_options.curve)) in results.stdout + assert ( + re.search( + b"Supported Elliptic Groups: [x0-9A-F:]*X25519:P-256:P-384", + results.stdout, + ) + is not None + ) + assert ( + to_bytes("Shared Elliptic groups: {}".format(server_options.curve)) + in results.stdout + ) assert random_bytes in results.stdout @@ -93,7 +108,9 @@ def test_hrr_with_s2n_as_client(managed_process, cipher, provider, other_provide @pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name) @pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name) @pytest.mark.parametrize("certificate", ALL_TEST_CERTS, ids=get_parameter_name) -def test_hrr_with_s2n_as_server(managed_process, cipher, provider, other_provider, curve, protocol, certificate): +def test_hrr_with_s2n_as_server( + managed_process, cipher, provider, other_provider, curve, protocol, certificate +): port = next(available_ports) random_bytes = data_bytes(64) @@ -104,8 +121,9 @@ def test_hrr_with_s2n_as_server(managed_process, cipher, provider, other_provide data_to_send=random_bytes, insecure=True, curve=curve, - extra_flags=['-msg', '-curves', 'X448:'+str(curve)], - protocol=protocol) + extra_flags=["-msg", "-curves", "X448:" + str(curve)], + protocol=protocol, + ) server_options = copy.copy(client_options) server_options.data_to_send = None @@ -123,8 +141,7 @@ def test_hrr_with_s2n_as_server(managed_process, cipher, provider, other_provide for results in server.get_results(): results.assert_success() assert random_bytes in results.stdout - assert to_bytes("Curve: {}".format( - CURVE_NAMES[curve.name])) in results.stdout + assert to_bytes("Curve: {}".format(CURVE_NAMES[curve.name])) in results.stdout assert random_bytes in results.stdout assert S2N_HRR_MARKER in results.stdout @@ -137,9 +154,9 @@ def test_hrr_with_s2n_as_server(managed_process, cipher, provider, other_provide for results in client.get_results(): results.assert_success() assert marker in results.stdout - client_hello_count = results.stdout.count(b'ClientHello') - server_hello_count = results.stdout.count(b'ServerHello') - finished_count = results.stdout.count(b'Finished') + client_hello_count = results.stdout.count(b"ClientHello") + server_hello_count = results.stdout.count(b"ServerHello") + finished_count = results.stdout.count(b"Finished") assert client_hello_count == 2 assert server_hello_count == 2 @@ -157,7 +174,9 @@ def test_hrr_with_s2n_as_server(managed_process, cipher, provider, other_provide @pytest.mark.parametrize("curve", TEST_CURVES, ids=get_parameter_name) @pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name) @pytest.mark.parametrize("certificate", ALL_TEST_CERTS, ids=get_parameter_name) -def test_hrr_with_default_keyshare(managed_process, cipher, provider, other_provider, curve, protocol, certificate): +def test_hrr_with_default_keyshare( + managed_process, cipher, provider, other_provider, curve, protocol, certificate +): port = next(available_ports) random_bytes = data_bytes(64) @@ -167,7 +186,8 @@ def test_hrr_with_default_keyshare(managed_process, cipher, provider, other_prov cipher=cipher, data_to_send=random_bytes, insecure=True, - protocol=protocol) + protocol=protocol, + ) server_options = copy.copy(client_options) server_options.data_to_send = None @@ -185,8 +205,7 @@ def test_hrr_with_default_keyshare(managed_process, cipher, provider, other_prov # The client should connect and return without error for results in client.get_results(): results.assert_success() - assert to_bytes("Curve: {}".format( - CURVE_NAMES[curve.name])) in results.stdout + assert to_bytes("Curve: {}".format(CURVE_NAMES[curve.name])) in results.stdout assert S2N_HRR_MARKER in results.stdout marker_part1 = b"cf 21 ad 74 e5" @@ -196,7 +215,15 @@ def test_hrr_with_default_keyshare(managed_process, cipher, provider, other_prov results.assert_success() assert marker_part1 in results.stdout and marker_part2 in results.stdout # The "test_all" s2n security policy includes draft Hybrid PQ groups that Openssl server prints as hex values - assert re.search(b'Supported Elliptic Groups: [x0-9A-F:]*X25519:P-256:P-384', results.stdout) is not None - assert to_bytes("Shared Elliptic groups: {}".format( - server_options.curve)) in results.stdout + assert ( + re.search( + b"Supported Elliptic Groups: [x0-9A-F:]*X25519:P-256:P-384", + results.stdout, + ) + is not None + ) + assert ( + to_bytes("Shared Elliptic groups: {}".format(server_options.curve)) + in results.stdout + ) assert random_bytes in results.stdout diff --git a/tests/integrationv2/test_key_update.py b/tests/integrationv2/test_key_update.py index e14da5d2571..3bb5e70dd8e 100644 --- a/tests/integrationv2/test_key_update.py +++ b/tests/integrationv2/test_key_update.py @@ -28,7 +28,9 @@ def test_nothing(): @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) @pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name) -def test_s2n_server_key_update(managed_process, cipher, provider, other_provider, protocol): +def test_s2n_server_key_update( + managed_process, cipher, provider, other_provider, protocol +): host = "localhost" port = next(available_ports) @@ -55,9 +57,7 @@ def test_s2n_server_key_update(managed_process, cipher, provider, other_provider server_options.cert = "../pems/ecdsa_p384_pkcs1_cert.pem" server_options.data_to_send = [SERVER_DATA.encode()] - server = managed_process( - S2N, server_options, send_marker=CLIENT_DATA, timeout=30 - ) + server = managed_process(S2N, server_options, send_marker=CLIENT_DATA, timeout=30) client = managed_process( provider, client_options, @@ -82,7 +82,9 @@ def test_s2n_server_key_update(managed_process, cipher, provider, other_provider @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) @pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name) -def test_s2n_client_key_update(managed_process, cipher, provider, other_provider, protocol): +def test_s2n_client_key_update( + managed_process, cipher, provider, other_provider, protocol +): host = "localhost" port = next(available_ports) diff --git a/tests/integrationv2/test_npn.py b/tests/integrationv2/test_npn.py index ddb29914e55..6d082d89f4c 100644 --- a/tests/integrationv2/test_npn.py +++ b/tests/integrationv2/test_npn.py @@ -3,7 +3,13 @@ import copy import pytest -from configuration import available_ports, ALL_TEST_CIPHERS, ALL_TEST_CURVES, MINIMAL_TEST_CERTS, PROTOCOLS +from configuration import ( + available_ports, + ALL_TEST_CIPHERS, + ALL_TEST_CURVES, + MINIMAL_TEST_CERTS, + PROTOCOLS, +) from common import ProviderOptions, Protocols from fixtures import managed_process # lgtm [py/unused-import] from providers import OpenSSL, S2N, Provider @@ -24,12 +30,14 @@ OPENSSL_CLIENT_NPN_NO_OVERLAP_MARKER = "Next protocol: (2) " # Test lists -PROTOCOL_LIST = 'http/1.1,h2,h3' -PROTOCOL_LIST_ALT_ORDER = 'h2,h3,http/1.1' -PROTOCOL_LIST_NO_OVERLAP = 'spdy' +PROTOCOL_LIST = "http/1.1,h2,h3" +PROTOCOL_LIST_ALT_ORDER = "h2,h3,http/1.1" +PROTOCOL_LIST_NO_OVERLAP = "spdy" -def s2n_client_npn_handshake(managed_process, cipher, curve, certificate, protocol, provider, server_list): +def s2n_client_npn_handshake( + managed_process, cipher, curve, certificate, protocol, provider, server_list +): options = ProviderOptions( port=next(available_ports), cipher=cipher, @@ -43,12 +51,12 @@ def s2n_client_npn_handshake(managed_process, cipher, curve, certificate, protoc client_options = copy.copy(options) client_options.mode = Provider.ClientMode # Flags to turn on NPN for s2nc - client_options.extra_flags = ['--alpn', PROTOCOL_LIST, '--npn'] + client_options.extra_flags = ["--alpn", PROTOCOL_LIST, "--npn"] server_options = copy.copy(options) server_options.mode = Provider.ServerMode # Flags to turn on NPN for OpenSSL server - server_options.extra_flags = ['-nextprotoneg', server_list] + server_options.extra_flags = ["-nextprotoneg", server_list] server = managed_process(provider, server_options, timeout=5) s2n_client = managed_process(S2N, client_options, timeout=5) @@ -67,11 +75,20 @@ def s2n_client_npn_handshake(managed_process, cipher, curve, certificate, protoc @pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name) @pytest.mark.parametrize("protocol", TLS_PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) -def test_s2n_client_npn(managed_process, cipher, curve, certificate, protocol, provider): - s2n_client, server = s2n_client_npn_handshake(managed_process, cipher, curve, certificate, protocol, provider, - server_list=PROTOCOL_LIST) +def test_s2n_client_npn( + managed_process, cipher, curve, certificate, protocol, provider +): + s2n_client, server = s2n_client_npn_handshake( + managed_process, + cipher, + curve, + certificate, + protocol, + provider, + server_list=PROTOCOL_LIST, + ) - expected_protocol = 'http/1.1' + expected_protocol = "http/1.1" for results in server.get_results(): results.assert_success() @@ -94,11 +111,20 @@ def test_s2n_client_npn(managed_process, cipher, curve, certificate, protocol, p @pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name) @pytest.mark.parametrize("protocol", TLS_PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) -def test_s2n_client_npn_server_preference(managed_process, cipher, curve, certificate, protocol, provider): - s2n_client, server = s2n_client_npn_handshake(managed_process, cipher, curve, certificate, protocol, provider, - server_list=PROTOCOL_LIST_ALT_ORDER) +def test_s2n_client_npn_server_preference( + managed_process, cipher, curve, certificate, protocol, provider +): + s2n_client, server = s2n_client_npn_handshake( + managed_process, + cipher, + curve, + certificate, + protocol, + provider, + server_list=PROTOCOL_LIST_ALT_ORDER, + ) - expected_protocol = 'h2' + expected_protocol = "h2" for results in server.get_results(): results.assert_success() @@ -121,11 +147,20 @@ def test_s2n_client_npn_server_preference(managed_process, cipher, curve, certif @pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name) @pytest.mark.parametrize("protocol", TLS_PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) -def test_s2n_client_npn_no_overlap(managed_process, cipher, curve, certificate, protocol, provider): - s2n_client, server = s2n_client_npn_handshake(managed_process, cipher, curve, certificate, protocol, provider, - server_list=PROTOCOL_LIST_NO_OVERLAP) +def test_s2n_client_npn_no_overlap( + managed_process, cipher, curve, certificate, protocol, provider +): + s2n_client, server = s2n_client_npn_handshake( + managed_process, + cipher, + curve, + certificate, + protocol, + provider, + server_list=PROTOCOL_LIST_NO_OVERLAP, + ) - expected_protocol = 'http/1.1' + expected_protocol = "http/1.1" for results in server.get_results(): results.assert_success() @@ -137,7 +172,9 @@ def test_s2n_client_npn_no_overlap(managed_process, cipher, curve, certificate, assert to_bytes(S2N_APPLICATION_MARKER + expected_protocol) in results.stdout -def s2n_server_npn_handshake(managed_process, cipher, curve, certificate, protocol, provider, server_list): +def s2n_server_npn_handshake( + managed_process, cipher, curve, certificate, protocol, provider, server_list +): options = ProviderOptions( port=next(available_ports), cipher=cipher, @@ -151,12 +188,12 @@ def s2n_server_npn_handshake(managed_process, cipher, curve, certificate, protoc client_options = copy.copy(options) client_options.mode = Provider.ClientMode # Flags to turn on NPN for OpenSSL client - client_options.extra_flags = ['-nextprotoneg', PROTOCOL_LIST] + client_options.extra_flags = ["-nextprotoneg", PROTOCOL_LIST] server_options = copy.copy(options) server_options.mode = Provider.ServerMode # Flags to turn on NPN for s2nd. - server_options.extra_flags = ['--alpn', server_list, '--npn'] + server_options.extra_flags = ["--alpn", server_list, "--npn"] s2n_server = managed_process(S2N, server_options, timeout=5) client = managed_process(provider, client_options, timeout=5) @@ -175,14 +212,23 @@ def s2n_server_npn_handshake(managed_process, cipher, curve, certificate, protoc @pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name) @pytest.mark.parametrize("protocol", TLS_PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) -def test_s2n_server_npn(managed_process, cipher, curve, certificate, protocol, provider): +def test_s2n_server_npn( + managed_process, cipher, curve, certificate, protocol, provider +): # We only send one protocol on the s2n server # due to the fact that it re-purposes the alpn list(which only sends one protocol) # to work for the NPN list. - client, s2n_server = s2n_server_npn_handshake(managed_process, cipher, curve, certificate, protocol, provider, - server_list='http/1.1') + client, s2n_server = s2n_server_npn_handshake( + managed_process, + cipher, + curve, + certificate, + protocol, + provider, + server_list="http/1.1", + ) - expected_protocol = 'http/1.1' + expected_protocol = "http/1.1" for results in s2n_server.get_results(): results.assert_success() @@ -206,11 +252,20 @@ def test_s2n_server_npn(managed_process, cipher, curve, certificate, protocol, p @pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name) @pytest.mark.parametrize("protocol", TLS_PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) -def test_s2n_server_npn_no_overlap(managed_process, cipher, curve, certificate, protocol, provider): - client, s2n_server = s2n_server_npn_handshake(managed_process, cipher, curve, certificate, protocol, provider, - server_list=PROTOCOL_LIST_NO_OVERLAP) +def test_s2n_server_npn_no_overlap( + managed_process, cipher, curve, certificate, protocol, provider +): + client, s2n_server = s2n_server_npn_handshake( + managed_process, + cipher, + curve, + certificate, + protocol, + provider, + server_list=PROTOCOL_LIST_NO_OVERLAP, + ) - expected_protocol = 'http/1.1' + expected_protocol = "http/1.1" for results in s2n_server.get_results(): results.assert_success() @@ -219,4 +274,7 @@ def test_s2n_server_npn_no_overlap(managed_process, cipher, curve, certificate, for results in client.get_results(): results.assert_success() - assert to_bytes(OPENSSL_CLIENT_NPN_NO_OVERLAP_MARKER + expected_protocol) in results.stdout + assert ( + to_bytes(OPENSSL_CLIENT_NPN_NO_OVERLAP_MARKER + expected_protocol) + in results.stdout + ) diff --git a/tests/integrationv2/test_ocsp.py b/tests/integrationv2/test_ocsp.py index 71af0e59d15..ed1253d2a9d 100644 --- a/tests/integrationv2/test_ocsp.py +++ b/tests/integrationv2/test_ocsp.py @@ -21,7 +21,9 @@ @pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name) @pytest.mark.parametrize("protocol", PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("certificate", OCSP_CERTS, ids=get_parameter_name) -def test_s2n_client_ocsp_response(managed_process, cipher, provider, other_provider, curve, protocol, certificate): +def test_s2n_client_ocsp_response( + managed_process, cipher, provider, other_provider, curve, protocol, certificate +): if "boringssl" in get_flag(S2N_PROVIDER_VERSION): pytest.skip("s2n-tls client with boringssl does not support ocsp") @@ -36,7 +38,7 @@ def test_s2n_client_ocsp_response(managed_process, cipher, provider, other_provi protocol=protocol, insecure=True, data_to_send=random_bytes, - enable_client_ocsp=True + enable_client_ocsp=True, ) server_options = ProviderOptions( @@ -48,8 +50,8 @@ def test_s2n_client_ocsp_response(managed_process, cipher, provider, other_provi key=certificate.key, cert=certificate.cert, ocsp_response={ - "RSA": TEST_OCSP_DIRECTORY + "ocsp_response.der", - "EC": TEST_OCSP_DIRECTORY + "ocsp_ecdsa_response.der" + "RSA": TEST_OCSP_DIRECTORY + "ocsp_response.der", + "EC": TEST_OCSP_DIRECTORY + "ocsp_ecdsa_response.der", }.get(certificate.algorithm), ) @@ -59,10 +61,7 @@ def test_s2n_client_ocsp_response(managed_process, cipher, provider, other_provi kill_marker = random_bytes server = managed_process( - provider, - server_options, - timeout=30, - kill_marker=kill_marker + provider, server_options, timeout=30, kill_marker=kill_marker ) client = managed_process(S2N, client_options, timeout=30) @@ -73,7 +72,10 @@ def test_s2n_client_ocsp_response(managed_process, cipher, provider, other_provi for server_results in server.get_results(): server_results.assert_success() # Avoid debugging information that sometimes gets inserted after the first character. - assert random_bytes[1:] in server_results.stdout or random_bytes[1:] in server_results.stderr + assert ( + random_bytes[1:] in server_results.stdout + or random_bytes[1:] in server_results.stderr + ) @pytest.mark.uncollect_if(func=invalid_test_parameters) @@ -83,7 +85,9 @@ def test_s2n_client_ocsp_response(managed_process, cipher, provider, other_provi @pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name) @pytest.mark.parametrize("protocol", PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("certificate", OCSP_CERTS, ids=get_parameter_name) -def test_s2n_server_ocsp_response(managed_process, cipher, provider, other_provider, curve, protocol, certificate): +def test_s2n_server_ocsp_response( + managed_process, cipher, provider, other_provider, curve, protocol, certificate +): port = next(available_ports) random_bytes = data_bytes(128) @@ -95,7 +99,7 @@ def test_s2n_server_ocsp_response(managed_process, cipher, provider, other_provi protocol=protocol, insecure=True, data_to_send=random_bytes, - enable_client_ocsp=True + enable_client_ocsp=True, ) server_options = ProviderOptions( @@ -108,8 +112,8 @@ def test_s2n_server_ocsp_response(managed_process, cipher, provider, other_provi key=certificate.key, cert=certificate.cert, ocsp_response={ - "RSA": TEST_OCSP_DIRECTORY + "ocsp_response.der", - "EC": TEST_OCSP_DIRECTORY + "ocsp_ecdsa_response.der" + "RSA": TEST_OCSP_DIRECTORY + "ocsp_response.der", + "EC": TEST_OCSP_DIRECTORY + "ocsp_ecdsa_response.der", }.get(certificate.algorithm), ) @@ -120,21 +124,27 @@ def test_s2n_server_ocsp_response(managed_process, cipher, provider, other_provi kill_marker = b"Sent: " server = managed_process(S2N, server_options, timeout=90) - client = managed_process(provider, client_options, - timeout=90, kill_marker=kill_marker) + client = managed_process( + provider, client_options, timeout=90, kill_marker=kill_marker + ) for client_results in client.get_results(): client_results.assert_success() - assert any([ - { - GnuTLS: b"OCSP Response Information:\n\tResponse Status: Successful", - OpenSSL: b"OCSP Response Status: successful" - }.get(provider) in stream for stream in client_results.output_streams() - ]) + assert any( + [ + { + GnuTLS: b"OCSP Response Information:\n\tResponse Status: Successful", + OpenSSL: b"OCSP Response Status: successful", + }.get(provider) + in stream + for stream in client_results.output_streams() + ] + ) for server_results in server.get_results(): server_results.assert_success() # Avoid debugging information that sometimes gets inserted after the first character. assert any( - [random_bytes[1:] in stream for stream in server_results.output_streams()]) + [random_bytes[1:] in stream for stream in server_results.output_streams()] + ) diff --git a/tests/integrationv2/test_pq_handshake.py b/tests/integrationv2/test_pq_handshake.py index 73f218e1cfc..b3684b1dbe6 100644 --- a/tests/integrationv2/test_pq_handshake.py +++ b/tests/integrationv2/test_pq_handshake.py @@ -4,7 +4,15 @@ import os from configuration import available_ports -from common import Ciphers, Curves, ProviderOptions, Protocols, KemGroups, Certificates, pq_enabled +from common import ( + Ciphers, + Curves, + ProviderOptions, + Protocols, + KemGroups, + Certificates, + pq_enabled, +) from fixtures import managed_process # lgtm [py/unused-import] from providers import Provider, S2N, OpenSSL, BoringSSL from utils import invalid_test_parameters, get_parameter_name, to_bytes @@ -28,33 +36,42 @@ EXPECTED_RESULTS = { # The tuple keys have the form: # (client_{cipher, kem_group}, server_{cipher, kem_group}): {"cipher": {expected_cipher}, "kem_group": {expected_kem_group}} - (Ciphers.PQ_TLS_1_0_2023_01, Ciphers.PQ_TLS_1_0_2023_01): - {"cipher": "TLS_AES_256_GCM_SHA384", - "kem_group": "_kyber-512-r3"}, - (KemGroups.P384_KYBER768R3, Ciphers.PQ_TLS_1_3_2023_06_01): - {"cipher": "AES256_GCM_SHA384", - "kem_group": "secp384r1_kyber-768-r3"}, - (KemGroups.P521_KYBER1024R3, Ciphers.PQ_TLS_1_3_2023_06_01): - {"cipher": "AES256_GCM_SHA384", - "kem_group": "secp521r1_kyber-1024-r3"}, - (Ciphers.PQ_TLS_1_3_2023_06_01, KemGroups.X25519Kyber768Draft00): - {"cipher": "TLS_AES_256_GCM_SHA384", - "kem_group": "X25519Kyber768Draft00"}, - (Ciphers.PQ_TLS_1_3_2023_06_01, KemGroups.SecP256r1Kyber768Draft00): - {"cipher": "TLS_AES_256_GCM_SHA384", - "kem_group": "SecP256r1Kyber768Draft00"}, - (Ciphers.PQ_TLS_1_3_2023_06_01, Ciphers.PQ_TLS_1_3_2023_06_01): - {"cipher": "TLS_AES_256_GCM_SHA384", - "kem_group": "SecP256r1Kyber768Draft00"}, - (Ciphers.PQ_TLS_1_3_2023_06_01, Ciphers.KMS_TLS_1_0_2018_10): - {"cipher": "ECDHE-RSA-AES256-GCM-SHA384", - "kem_group": None}, - (Ciphers.KMS_TLS_1_0_2018_10, Ciphers.PQ_TLS_1_3_2023_06_01): - {"cipher": "ECDHE-RSA-AES128-GCM-SHA256", - "kem_group": None}, - (Ciphers.KMS_TLS_1_0_2018_10, Ciphers.KMS_TLS_1_0_2018_10): - {"cipher": "ECDHE-RSA-AES256-GCM-SHA384", - "kem_group": None}, + (Ciphers.PQ_TLS_1_0_2023_01, Ciphers.PQ_TLS_1_0_2023_01): { + "cipher": "TLS_AES_256_GCM_SHA384", + "kem_group": "_kyber-512-r3", + }, + (KemGroups.P384_KYBER768R3, Ciphers.PQ_TLS_1_3_2023_06_01): { + "cipher": "AES256_GCM_SHA384", + "kem_group": "secp384r1_kyber-768-r3", + }, + (KemGroups.P521_KYBER1024R3, Ciphers.PQ_TLS_1_3_2023_06_01): { + "cipher": "AES256_GCM_SHA384", + "kem_group": "secp521r1_kyber-1024-r3", + }, + (Ciphers.PQ_TLS_1_3_2023_06_01, KemGroups.X25519Kyber768Draft00): { + "cipher": "TLS_AES_256_GCM_SHA384", + "kem_group": "X25519Kyber768Draft00", + }, + (Ciphers.PQ_TLS_1_3_2023_06_01, KemGroups.SecP256r1Kyber768Draft00): { + "cipher": "TLS_AES_256_GCM_SHA384", + "kem_group": "SecP256r1Kyber768Draft00", + }, + (Ciphers.PQ_TLS_1_3_2023_06_01, Ciphers.PQ_TLS_1_3_2023_06_01): { + "cipher": "TLS_AES_256_GCM_SHA384", + "kem_group": "SecP256r1Kyber768Draft00", + }, + (Ciphers.PQ_TLS_1_3_2023_06_01, Ciphers.KMS_TLS_1_0_2018_10): { + "cipher": "ECDHE-RSA-AES256-GCM-SHA384", + "kem_group": None, + }, + (Ciphers.KMS_TLS_1_0_2018_10, Ciphers.PQ_TLS_1_3_2023_06_01): { + "cipher": "ECDHE-RSA-AES128-GCM-SHA256", + "kem_group": None, + }, + (Ciphers.KMS_TLS_1_0_2018_10, Ciphers.KMS_TLS_1_0_2018_10): { + "cipher": "ECDHE-RSA-AES256-GCM-SHA384", + "kem_group": None, + }, } """ @@ -73,7 +90,9 @@ def invalid_pq_handshake_test_parameters(*args, **kwargs): # `or` is correct: invalid_test_parameters() returns True if the parameters are invalid; # we want to return True here if either of the sets of parameters are invalid. - return invalid_test_parameters(*args, **client_cipher_kwargs) or invalid_test_parameters(*args, **server_cipher_kwargs) + return invalid_test_parameters( + *args, **client_cipher_kwargs + ) or invalid_test_parameters(*args, **server_cipher_kwargs) def get_oqs_openssl_override_env_vars(): @@ -88,14 +107,16 @@ def get_oqs_openssl_override_env_vars(): def assert_s2n_negotiation_parameters(s2n_results, expected_result): if expected_result is not None: - assert to_bytes( - ("Cipher negotiated: " + expected_result['cipher'])) in s2n_results.stdout - if expected_result['kem_group']: + assert ( + to_bytes(("Cipher negotiated: " + expected_result["cipher"])) + in s2n_results.stdout + ) + if expected_result["kem_group"]: # Purposefully leave off the "KEM Group: " prefix in order to perform partial matches # without specifying the curve. - assert to_bytes(expected_result['kem_group']) in s2n_results.stdout + assert to_bytes(expected_result["kem_group"]) in s2n_results.stdout assert to_bytes(PQ_ENABLED_FLAG) in s2n_results.stdout - if not expected_result['kem_group']: + if not expected_result["kem_group"]: assert to_bytes(PQ_ENABLED_FLAG) not in s2n_results.stdout assert to_bytes("Curve:") in s2n_results.stdout @@ -103,8 +124,8 @@ def assert_s2n_negotiation_parameters(s2n_results, expected_result): def assert_awslc_negotiation_parameters(awslc_results, expected_result): assert expected_result is not None assert awslc_results.exit_code is 0 - assert to_bytes(("group: " + expected_result['kem_group'])) in awslc_results.stderr - assert to_bytes(("Cipher: " + expected_result['cipher'])) in awslc_results.stderr + assert to_bytes(("group: " + expected_result["kem_group"])) in awslc_results.stderr + assert to_bytes(("Cipher: " + expected_result["cipher"])) in awslc_results.stderr def test_nothing(): @@ -117,14 +138,25 @@ def test_nothing(): @pytest.mark.uncollect_if(func=invalid_pq_handshake_test_parameters) -@pytest.mark.parametrize("protocol", [Protocols.TLS12, Protocols.TLS13], ids=get_parameter_name) -@pytest.mark.parametrize("certificate", [Certificates.RSA_4096_SHA512], ids=get_parameter_name) +@pytest.mark.parametrize( + "protocol", [Protocols.TLS12, Protocols.TLS13], ids=get_parameter_name +) +@pytest.mark.parametrize( + "certificate", [Certificates.RSA_4096_SHA512], ids=get_parameter_name +) @pytest.mark.parametrize("client_cipher", CIPHERS, ids=get_parameter_name) @pytest.mark.parametrize("server_cipher", CIPHERS, ids=get_parameter_name) @pytest.mark.parametrize("provider", [S2N], ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) -def test_s2nc_to_s2nd_pq_handshake(managed_process, protocol, certificate, client_cipher, server_cipher, provider, - other_provider): +def test_s2nc_to_s2nd_pq_handshake( + managed_process, + protocol, + certificate, + client_cipher, + server_cipher, + provider, + other_provider, +): port = next(available_ports) client_options = ProviderOptions( @@ -132,7 +164,8 @@ def test_s2nc_to_s2nd_pq_handshake(managed_process, protocol, certificate, clien port=port, insecure=True, cipher=client_cipher, - protocol=protocol) + protocol=protocol, + ) server_options = ProviderOptions( mode=Provider.ServerMode, @@ -140,19 +173,19 @@ def test_s2nc_to_s2nd_pq_handshake(managed_process, protocol, certificate, clien protocol=protocol, cipher=server_cipher, cert=certificate.cert, - key=certificate.key) + key=certificate.key, + ) server = managed_process(S2N, server_options, timeout=5) client = managed_process(S2N, client_options, timeout=5) if pq_enabled(): - expected_result = EXPECTED_RESULTS.get( - (client_cipher, server_cipher), None) + expected_result = EXPECTED_RESULTS.get((client_cipher, server_cipher), None) else: # If PQ is not enabled in s2n, we expect classic handshakes to be negotiated. # Leave the expected cipher blank, as there are multiple possibilities - the # important thing is that kem and kem_group are NONE. - expected_result = {"cipher": "", "kem_group": None} + expected_result = {"cipher": "", "kem_group": None} # Client and server are both s2n; can make meaningful assertions about negotiation for both for results in client.get_results(): @@ -164,18 +197,29 @@ def test_s2nc_to_s2nd_pq_handshake(managed_process, protocol, certificate, clien assert_s2n_negotiation_parameters(results, expected_result) -@pytest.mark.parametrize("s2n_client_policy", [Ciphers.PQ_TLS_1_3_2023_06_01], ids=get_parameter_name) -@pytest.mark.parametrize("awslc_server_group", [KemGroups.SecP256r1Kyber768Draft00, KemGroups.X25519Kyber768Draft00], ids=get_parameter_name) -def test_s2nc_to_awslc_pq_handshake(managed_process, s2n_client_policy, awslc_server_group): - +@pytest.mark.parametrize( + "s2n_client_policy", [Ciphers.PQ_TLS_1_3_2023_06_01], ids=get_parameter_name +) +@pytest.mark.parametrize( + "awslc_server_group", + [KemGroups.SecP256r1Kyber768Draft00, KemGroups.X25519Kyber768Draft00], + ids=get_parameter_name, +) +def test_s2nc_to_awslc_pq_handshake( + managed_process, s2n_client_policy, awslc_server_group +): if not pq_enabled(): pytest.skip("PQ not enabled") if "awslc" not in get_flag(S2N_PROVIDER_VERSION): - pytest.skip("s2n must be compiled with awslc libcrypto in order to test PQ TLS compatibility") + pytest.skip( + "s2n must be compiled with awslc libcrypto in order to test PQ TLS compatibility" + ) if "fips" in get_flag(S2N_PROVIDER_VERSION): - pytest.skip("No FIPS validated version of AWS-LC has support for negotiating Hybrid PQ TLS yet") + pytest.skip( + "No FIPS validated version of AWS-LC has support for negotiating Hybrid PQ TLS yet" + ) port = next(available_ports) @@ -184,17 +228,21 @@ def test_s2nc_to_awslc_pq_handshake(managed_process, s2n_client_policy, awslc_se port=port, insecure=True, cipher=s2n_client_policy, - protocol=Protocols.TLS13) + protocol=Protocols.TLS13, + ) awslc_server_options = ProviderOptions( mode=Provider.ServerMode, port=port, protocol=Protocols.TLS13, - curve=Curves.from_name(awslc_server_group.oqs_name)) + curve=Curves.from_name(awslc_server_group.oqs_name), + ) awslc_server = managed_process(BoringSSL, awslc_server_options, timeout=5) s2n_client = managed_process(S2N, s2nc_client_options, timeout=5) - expected_result = EXPECTED_RESULTS.get((s2n_client_policy, awslc_server_group), None) + expected_result = EXPECTED_RESULTS.get( + (s2n_client_policy, awslc_server_group), None + ) awslc_result = next(awslc_server.get_results()) assert_awslc_negotiation_parameters(awslc_result, expected_result) @@ -203,18 +251,29 @@ def test_s2nc_to_awslc_pq_handshake(managed_process, s2n_client_policy, awslc_se assert_s2n_negotiation_parameters(s2nd_result, expected_result) -@pytest.mark.parametrize("s2n_server_policy", [Ciphers.PQ_TLS_1_3_2023_06_01], ids=get_parameter_name) -@pytest.mark.parametrize("awslc_client_group", [KemGroups.SecP256r1Kyber768Draft00, KemGroups.X25519Kyber768Draft00], ids=get_parameter_name) -def test_s2nd_to_awslc_pq_handshake(managed_process, s2n_server_policy, awslc_client_group): - +@pytest.mark.parametrize( + "s2n_server_policy", [Ciphers.PQ_TLS_1_3_2023_06_01], ids=get_parameter_name +) +@pytest.mark.parametrize( + "awslc_client_group", + [KemGroups.SecP256r1Kyber768Draft00, KemGroups.X25519Kyber768Draft00], + ids=get_parameter_name, +) +def test_s2nd_to_awslc_pq_handshake( + managed_process, s2n_server_policy, awslc_client_group +): if not pq_enabled(): pytest.skip("PQ not enabled") if "awslc" not in get_flag(S2N_PROVIDER_VERSION): - pytest.skip("s2n must be compiled with awslc libcrypto in order to test PQ TLS compatibility") + pytest.skip( + "s2n must be compiled with awslc libcrypto in order to test PQ TLS compatibility" + ) if "fips" in get_flag(S2N_PROVIDER_VERSION): - pytest.skip("No FIPS validated version of AWS-LC has support for negotiating Hybrid PQ TLS yet") + pytest.skip( + "No FIPS validated version of AWS-LC has support for negotiating Hybrid PQ TLS yet" + ) port = next(available_ports) @@ -223,17 +282,21 @@ def test_s2nd_to_awslc_pq_handshake(managed_process, s2n_server_policy, awslc_cl port=port, insecure=True, cipher=s2n_server_policy, - protocol=Protocols.TLS13) + protocol=Protocols.TLS13, + ) awslc_client_options = ProviderOptions( mode=Provider.ClientMode, port=port, protocol=Protocols.TLS13, - curve=Curves.from_name(awslc_client_group.oqs_name)) + curve=Curves.from_name(awslc_client_group.oqs_name), + ) s2nd_server = managed_process(S2N, s2nd_server_options, timeout=5) awslc_client = managed_process(BoringSSL, awslc_client_options, timeout=5) - expected_result = EXPECTED_RESULTS.get((s2n_server_policy, awslc_client_group), None) + expected_result = EXPECTED_RESULTS.get( + (s2n_server_policy, awslc_client_group), None + ) awslc_result = next(awslc_client.get_results()) assert_awslc_negotiation_parameters(awslc_result, expected_result) @@ -244,7 +307,9 @@ def test_s2nd_to_awslc_pq_handshake(managed_process, s2n_server_policy, awslc_cl @pytest.mark.uncollect_if(func=invalid_test_parameters) @pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name) -@pytest.mark.parametrize("cipher", [Ciphers.PQ_TLS_1_3_2023_06_01], ids=get_parameter_name) +@pytest.mark.parametrize( + "cipher", [Ciphers.PQ_TLS_1_3_2023_06_01], ids=get_parameter_name +) @pytest.mark.parametrize("kem_group", KEM_GROUPS, ids=get_parameter_name) def test_s2nc_to_oqs_openssl_pq_handshake(managed_process, protocol, cipher, kem_group): # If PQ is not enabled in s2n, there is no reason to test against oqs_openssl @@ -258,7 +323,8 @@ def test_s2nc_to_oqs_openssl_pq_handshake(managed_process, protocol, cipher, kem port=port, insecure=True, cipher=cipher, - protocol=protocol) + protocol=protocol, + ) server_options = ProviderOptions( mode=Provider.ServerMode, @@ -267,7 +333,8 @@ def test_s2nc_to_oqs_openssl_pq_handshake(managed_process, protocol, cipher, kem cert=Certificates.RSA_4096_SHA512.cert, key=Certificates.RSA_4096_SHA512.key, env_overrides=get_oqs_openssl_override_env_vars(), - extra_flags=['-groups', kem_group.oqs_name]) + extra_flags=["-groups", kem_group.oqs_name], + ) server = managed_process(OpenSSL, server_options, timeout=5) client = managed_process(S2N, client_options, timeout=5) @@ -286,7 +353,9 @@ def test_s2nc_to_oqs_openssl_pq_handshake(managed_process, protocol, cipher, kem @pytest.mark.uncollect_if(func=invalid_test_parameters) @pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name) -@pytest.mark.parametrize("cipher", [Ciphers.PQ_TLS_1_3_2023_06_01], ids=get_parameter_name) +@pytest.mark.parametrize( + "cipher", [Ciphers.PQ_TLS_1_3_2023_06_01], ids=get_parameter_name +) @pytest.mark.parametrize("kem_group", KEM_GROUPS, ids=get_parameter_name) def test_oqs_openssl_to_s2nd_pq_handshake(managed_process, protocol, cipher, kem_group): # If PQ is not enabled in s2n, there is no reason to test against oqs_openssl @@ -300,7 +369,8 @@ def test_oqs_openssl_to_s2nd_pq_handshake(managed_process, protocol, cipher, kem port=port, protocol=protocol, env_overrides=get_oqs_openssl_override_env_vars(), - extra_flags=['-groups', kem_group.oqs_name]) + extra_flags=["-groups", kem_group.oqs_name], + ) server_options = ProviderOptions( mode=Provider.ServerMode, @@ -308,7 +378,8 @@ def test_oqs_openssl_to_s2nd_pq_handshake(managed_process, protocol, cipher, kem protocol=protocol, cipher=cipher, cert=Certificates.RSA_4096_SHA512.cert, - key=Certificates.RSA_4096_SHA512.key) + key=Certificates.RSA_4096_SHA512.key, + ) server = managed_process(S2N, server_options, timeout=5) client = managed_process(OpenSSL, client_options, timeout=5) diff --git a/tests/integrationv2/test_record_padding.py b/tests/integrationv2/test_record_padding.py index 05e4162137a..0bc78974ebf 100644 --- a/tests/integrationv2/test_record_padding.py +++ b/tests/integrationv2/test_record_padding.py @@ -5,21 +5,27 @@ import pytest import re -from configuration import available_ports, TLS13_CIPHERS, ALL_TEST_CURVES, MINIMAL_TEST_CERTS +from configuration import ( + available_ports, + TLS13_CIPHERS, + ALL_TEST_CURVES, + MINIMAL_TEST_CERTS, +) from common import ProviderOptions, Protocols, data_bytes from fixtures import managed_process # lgtm [py/unused-import] from providers import Provider, S2N, OpenSSL -from utils import invalid_test_parameters, get_parameter_name, get_expected_s2n_version, to_bytes +from utils import ( + invalid_test_parameters, + get_parameter_name, + get_expected_s2n_version, + to_bytes, +) PADDING_SIZE_SMALL = 250 PADDING_SIZE_MEDIUM = 1000 PADDING_SIZE_MAX = 1 << 14 -PADDING_SIZES = [ - PADDING_SIZE_SMALL, - PADDING_SIZE_MEDIUM, - PADDING_SIZE_MAX -] +PADDING_SIZES = [PADDING_SIZE_SMALL, PADDING_SIZE_MEDIUM, PADDING_SIZE_MAX] # arbitrarily large payload size PAYLOAD_SIZE = 1024 @@ -36,7 +42,7 @@ def strip_string_of_bytes(s: str) -> str: def get_payload_size_from_openssl_trace(record_size_bytes: str) -> int: # record_size_bytes is in the form XX XX where X is a hex digit - size_in_hex = record_size_bytes.replace(' ', '') + size_in_hex = record_size_bytes.replace(" ", "") size = int(size_in_hex, 16) # record includes 16 bytes of aead tag return size - 16 @@ -45,11 +51,9 @@ def get_payload_size_from_openssl_trace(record_size_bytes: str) -> int: def assert_openssl_records_are_padded_correctly(openssl_output: str, padding_size: int): number_of_app_data_records = 0 - records_written = re.findall( - OPENSSL_RECORD_WRITTEN_PATTERN, openssl_output) + records_written = re.findall(OPENSSL_RECORD_WRITTEN_PATTERN, openssl_output) for record_prefix in records_written: - app_data_header = re.search( - OPENSSL_APP_DATA_HEADER_PATTERN, record_prefix) + app_data_header = re.search(OPENSSL_APP_DATA_HEADER_PATTERN, record_prefix) if app_data_header: size_bytes = app_data_header.group(RECORD_SIZE_GROUP) size = get_payload_size_from_openssl_trace(size_bytes) @@ -81,8 +85,9 @@ def test_nothing(): @pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name) @pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name) @pytest.mark.parametrize("padding_size", PADDING_SIZES, ids=get_parameter_name) -def test_s2n_server_handles_padded_records(managed_process, cipher, provider, curve, protocol, certificate, - padding_size): +def test_s2n_server_handles_padded_records( + managed_process, cipher, provider, curve, protocol, certificate, padding_size +): port = next(available_ports) random_bytes = data_bytes(PAYLOAD_SIZE) @@ -95,7 +100,7 @@ def test_s2n_server_handles_padded_records(managed_process, cipher, provider, cu data_to_send=random_bytes, insecure=True, protocol=protocol, - extra_flags=['-record_padding', padding_size] + extra_flags=["-record_padding", padding_size], ) server_options = copy.copy(client_options) @@ -110,7 +115,8 @@ def test_s2n_server_handles_padded_records(managed_process, cipher, provider, cu for client_results in openssl.get_results(): client_results.assert_success() assert_openssl_records_are_padded_correctly( - str(client_results.stdout), padding_size) + str(client_results.stdout), padding_size + ) expected_version = get_expected_s2n_version(protocol, provider) @@ -119,14 +125,20 @@ def test_s2n_server_handles_padded_records(managed_process, cipher, provider, cu # verify that the payload was correctly received by the server assert random_bytes in server_results.stdout # verify that the version was correctly negotiated - assert to_bytes("Actual protocol version: {}".format( - expected_version)) in server_results.stdout + assert ( + to_bytes("Actual protocol version: {}".format(expected_version)) + in server_results.stdout + ) # verify that the cipher was correctly negotiated - assert to_bytes("Cipher negotiated: {}".format( - cipher.name)) in server_results.stdout + assert ( + to_bytes("Cipher negotiated: {}".format(cipher.name)) + in server_results.stdout + ) -@pytest.mark.flaky(reruns=5, reruns_delay=2, condition=platform.machine().startswith("aarch")) +@pytest.mark.flaky( + reruns=5, reruns_delay=2, condition=platform.machine().startswith("aarch") +) @pytest.mark.uncollect_if(func=invalid_test_parameters) @pytest.mark.parametrize("cipher", TLS13_CIPHERS, ids=get_parameter_name) @pytest.mark.parametrize("provider", [OpenSSL]) @@ -135,8 +147,9 @@ def test_s2n_server_handles_padded_records(managed_process, cipher, provider, cu @pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name) @pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name) @pytest.mark.parametrize("padding_size", PADDING_SIZES, ids=get_parameter_name) -def test_s2n_client_handles_padded_records(managed_process, cipher, provider, curve, protocol, certificate, - padding_size): +def test_s2n_client_handles_padded_records( + managed_process, cipher, provider, curve, protocol, certificate, padding_size +): port = next(available_ports) client_random_bytes = data_bytes(PAYLOAD_SIZE) @@ -152,7 +165,7 @@ def test_s2n_client_handles_padded_records(managed_process, cipher, provider, cu insecure=True, protocol=protocol, data_to_send=server_random_bytes, - extra_flags=['-record_padding', padding_size] + extra_flags=["-record_padding", padding_size], ) client_options = copy.copy(server_options) @@ -161,24 +174,41 @@ def test_s2n_client_handles_padded_records(managed_process, cipher, provider, cu client_options.data_to_send = client_random_bytes # openssl will send its response after it has received s2nc's record - openssl = managed_process(provider, server_options, - timeout=5, send_marker=strip_string_of_bytes(str(client_random_bytes))) + openssl = managed_process( + provider, + server_options, + timeout=5, + send_marker=strip_string_of_bytes(str(client_random_bytes)), + ) # s2nc will wait until it has received the server's response before closing - s2nc = managed_process(S2N, client_options, timeout=5, - close_marker=strip_string_of_bytes(str(server_random_bytes))) + s2nc = managed_process( + S2N, + client_options, + timeout=5, + close_marker=strip_string_of_bytes(str(server_random_bytes)), + expect_stderr=True, + ) expected_version = get_expected_s2n_version(protocol, provider) for client_results in s2nc.get_results(): + # Aware of I/O issues causing this testcase to sometimes fail with a non zero exit status + # https://github.com/aws/s2n-tls/issues/5130 + client_results.expect_nonzero_exit = True client_results.assert_success() # assert that the client has received server's application payload assert server_random_bytes in client_results.stdout - assert to_bytes("Actual protocol version: {}".format( - expected_version)) in client_results.stdout - assert to_bytes("Cipher negotiated: {}".format( - cipher.name)) in client_results.stdout + assert ( + to_bytes("Actual protocol version: {}".format(expected_version)) + in client_results.stdout + ) + assert ( + to_bytes("Cipher negotiated: {}".format(cipher.name)) + in client_results.stdout + ) for server_results in openssl.get_results(): server_results.assert_success() assert_openssl_records_are_padded_correctly( - str(server_results.stdout), padding_size) + str(server_results.stdout), padding_size + ) diff --git a/tests/integrationv2/test_renegotiate.py b/tests/integrationv2/test_renegotiate.py index 05afdb391de..b75bc54023c 100644 --- a/tests/integrationv2/test_renegotiate.py +++ b/tests/integrationv2/test_renegotiate.py @@ -4,7 +4,13 @@ import pytest import random -from configuration import available_ports, ALL_TEST_CIPHERS, ALL_TEST_CURVES, MINIMAL_TEST_CERTS, PROTOCOLS +from configuration import ( + available_ports, + ALL_TEST_CIPHERS, + ALL_TEST_CURVES, + MINIMAL_TEST_CERTS, + PROTOCOLS, +) from common import ProviderOptions, Protocols from fixtures import managed_process # lgtm [py/unused-import] from providers import Provider, S2N, OpenSSL @@ -20,7 +26,7 @@ S2N_RENEG_ACCEPT = "accept" S2N_RENEG_REJECT = "reject" S2N_RENEG_WAIT = "wait" -OPENSSL_RENEG_CTRL_CMD = 'r\n' +OPENSSL_RENEG_CTRL_CMD = "r\n" # Output indicating renegotiation state @@ -44,17 +50,19 @@ def renegotiate_was_started(s2n_results): def renegotiate_was_successful(s2n_results): - return renegotiate_was_started(s2n_results) and \ - to_bytes(S2N_RENEG_SUCCESS_MARKER) in s2n_results.stdout + return ( + renegotiate_was_started(s2n_results) + and to_bytes(S2N_RENEG_SUCCESS_MARKER) in s2n_results.stdout + ) # Basic conversion methods def to_bytes(val): - return str(val).encode('utf-8') + return str(val).encode("utf-8") def to_marker(val): - return bytes(val).decode('utf-8') + return bytes(val).decode("utf-8") """ @@ -101,7 +109,7 @@ def data_to_send(messages, mode): # but our framework is not good at handling non-ASCII characters due # to inconsistent use of bytes vs decode and str vs encode. # As a workaround, just prepend a throwaway non-ASCII utf-8 character. - data_bytes.append(bytes([0xc2, 0xbb]) + to_bytes(message.data_str)) + data_bytes.append(bytes([0xC2, 0xBB]) + to_bytes(message.data_str)) # We assume that the client will close the connection. # Give the server one last message to send without a corresponding send_marker. # The message will never be sent, but waiting to send it will prevent the server @@ -112,7 +120,11 @@ def data_to_send(messages, mode): @staticmethod def expected_output(messages, mode): - return [to_bytes(message.data_str) for message in messages if not message.ctrl and message.mode is not mode] + return [ + to_bytes(message.data_str) + for message in messages + if not message.ctrl and message.mode is not mode + ] @staticmethod def send_markers(messages, mode): @@ -126,27 +138,31 @@ def send_markers(messages, mode): # Assume that the first sender is s2n send_markers.append(S2N.get_send_marker()) else: - previous = messages[i-1] - assert (previous.mode is not mode) + previous = messages[i - 1] + assert previous.mode is not mode send_markers.append(previous.data_str) return send_markers @staticmethod def close_marker(messages): # Assume that the last sender is the server - assert (messages[-1].mode is Provider.ServerMode) + assert messages[-1].mode is Provider.ServerMode output = Msg.expected_output(messages, Provider.ClientMode) return to_marker(output[-1]) @staticmethod def debug(messages): - print(f'client data to send: {Msg.data_to_send(messages, Provider.ClientMode)}') - print(f'server data to send: {Msg.data_to_send(messages, Provider.ServerMode)}') - print(f'client send markers: {Msg.send_markers(messages, Provider.ClientMode)}') - print(f'server send markers: {Msg.send_markers(messages, Provider.ServerMode)}') - print(f'client close_marker: {Msg.close_marker(messages)}') - print(f'client expected output: {Msg.expected_output(messages, Provider.ClientMode)}') - print(f'server expected output: {Msg.expected_output(messages, Provider.ServerMode)}') + print(f"client data to send: {Msg.data_to_send(messages, Provider.ClientMode)}") + print(f"server data to send: {Msg.data_to_send(messages, Provider.ServerMode)}") + print(f"client send markers: {Msg.send_markers(messages, Provider.ClientMode)}") + print(f"server send markers: {Msg.send_markers(messages, Provider.ServerMode)}") + print(f"client close_marker: {Msg.close_marker(messages)}") + print( + f"client expected output: {Msg.expected_output(messages, Provider.ClientMode)}" + ) + print( + f"server expected output: {Msg.expected_output(messages, Provider.ServerMode)}" + ) # The order of messages that will trigger renegotiation @@ -166,7 +182,16 @@ def debug(messages): ] -def basic_reneg_test(managed_process, cipher, curve, certificate, protocol, provider, messages=RENEG_MESSAGES, reneg_option=None): +def basic_reneg_test( + managed_process, + cipher, + curve, + certificate, + protocol, + provider, + messages=RENEG_MESSAGES, + reneg_option=None, +): options = ProviderOptions( port=next(available_ports), cipher=cipher, @@ -188,16 +213,20 @@ def basic_reneg_test(managed_process, cipher, curve, certificate, protocol, prov server_options.mode = Provider.ServerMode server_options.data_to_send = Msg.data_to_send(messages, Provider.ServerMode) - server = managed_process(provider, server_options, - send_marker=Msg.send_markers(messages, Provider.ServerMode), - timeout=8 - ) + server = managed_process( + provider, + server_options, + send_marker=Msg.send_markers(messages, Provider.ServerMode), + timeout=8, + ) - s2n_client = managed_process(S2N, client_options, - send_marker=Msg.send_markers(messages, Provider.ClientMode), - close_marker=Msg.close_marker(messages), - timeout=8 - ) + s2n_client = managed_process( + S2N, + client_options, + send_marker=Msg.send_markers(messages, Provider.ClientMode), + close_marker=Msg.close_marker(messages), + timeout=8, + ) return (s2n_client, server) @@ -215,8 +244,12 @@ def basic_reneg_test(managed_process, cipher, curve, certificate, protocol, prov @pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name) @pytest.mark.parametrize("protocol", TEST_PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) -def test_s2n_client_ignores_openssl_hello_request(managed_process, cipher, curve, certificate, protocol, provider): - (s2n_client, server) = basic_reneg_test(managed_process, cipher, curve, certificate, protocol, provider) +def test_s2n_client_ignores_openssl_hello_request( + managed_process, cipher, curve, certificate, protocol, provider +): + (s2n_client, server) = basic_reneg_test( + managed_process, cipher, curve, certificate, protocol, provider + ) for results in server.get_results(): results.assert_success() @@ -243,9 +276,18 @@ def test_s2n_client_ignores_openssl_hello_request(managed_process, cipher, curve @pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name) @pytest.mark.parametrize("protocol", TEST_PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) -def test_s2n_client_rejects_openssl_hello_request(managed_process, cipher, curve, certificate, protocol, provider): - (s2n_client, server) = basic_reneg_test(managed_process, cipher, curve, certificate, protocol, provider, - reneg_option=S2N_RENEG_REJECT) +def test_s2n_client_rejects_openssl_hello_request( + managed_process, cipher, curve, certificate, protocol, provider +): + (s2n_client, server) = basic_reneg_test( + managed_process, + cipher, + curve, + certificate, + protocol, + provider, + reneg_option=S2N_RENEG_REJECT, + ) for results in server.get_results(): assert renegotiate_was_requested(results) @@ -268,9 +310,18 @@ def test_s2n_client_rejects_openssl_hello_request(managed_process, cipher, curve @pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name) @pytest.mark.parametrize("protocol", TEST_PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) -def test_s2n_client_renegotiate_with_openssl(managed_process, cipher, curve, certificate, protocol, provider): - (s2n_client, server) = basic_reneg_test(managed_process, cipher, curve, certificate, protocol, provider, - reneg_option=S2N_RENEG_ACCEPT) +def test_s2n_client_renegotiate_with_openssl( + managed_process, cipher, curve, certificate, protocol, provider +): + (s2n_client, server) = basic_reneg_test( + managed_process, + cipher, + curve, + certificate, + protocol, + provider, + reneg_option=S2N_RENEG_ACCEPT, + ) for results in server.get_results(): results.assert_success() @@ -301,19 +352,29 @@ def test_s2n_client_renegotiate_with_openssl(managed_process, cipher, curve, cer @pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name) @pytest.mark.parametrize("protocol", TEST_PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) -def test_s2n_client_renegotiate_with_client_auth_with_openssl(managed_process, cipher, curve, certificate, protocol, provider): +def test_s2n_client_renegotiate_with_client_auth_with_openssl( + managed_process, cipher, curve, certificate, protocol, provider +): # We want to use the same messages to test renegotiation, # but with 'R' instead of 'r' to trigger the Openssl renegotiate request. messages = copy.deepcopy(RENEG_MESSAGES) for m in messages: if m.ctrl: - m.data_str = 'R\n' + m.data_str = "R\n" client_auth_marker = "|CLIENT_AUTH" no_client_cert_marker = "|NO_CLIENT_CERT" - (s2n_client, server) = basic_reneg_test(managed_process, cipher, curve, certificate, protocol, provider, - messages=messages, reneg_option=S2N_RENEG_WAIT) + (s2n_client, server) = basic_reneg_test( + managed_process, + cipher, + curve, + certificate, + protocol, + provider, + messages=messages, + reneg_option=S2N_RENEG_WAIT, + ) for results in server.get_results(): results.assert_success() @@ -351,10 +412,19 @@ def test_s2n_client_renegotiate_with_client_auth_with_openssl(managed_process, c @pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name) @pytest.mark.parametrize("protocol", TEST_PROTOCOLS, ids=get_parameter_name) @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) -def test_s2n_client_renegotiate_with_app_data_with_openssl(managed_process, cipher, curve, certificate, protocol, provider): +def test_s2n_client_renegotiate_with_app_data_with_openssl( + managed_process, cipher, curve, certificate, protocol, provider +): first_server_app_data = Msg.expected_output(RENEG_MESSAGES, Provider.ClientMode)[0] - (s2n_client, server) = basic_reneg_test(managed_process, cipher, curve, certificate, protocol, provider, - reneg_option=S2N_RENEG_WAIT) + (s2n_client, server) = basic_reneg_test( + managed_process, + cipher, + curve, + certificate, + protocol, + provider, + reneg_option=S2N_RENEG_WAIT, + ) for results in server.get_results(): results.assert_success() diff --git a/tests/integrationv2/test_renegotiate_apache.py b/tests/integrationv2/test_renegotiate_apache.py index 2a23640ed6d..2407fc4fd3c 100644 --- a/tests/integrationv2/test_renegotiate_apache.py +++ b/tests/integrationv2/test_renegotiate_apache.py @@ -29,7 +29,9 @@ def create_get_request(route): @pytest.mark.uncollect_if(func=invalid_test_parameters) @pytest.mark.parametrize("protocol", TEST_PROTOCOLS, ids=get_parameter_name) -@pytest.mark.parametrize("endpoint", [CHANGE_CIPHER_SUITE_ENDPOINT, MUTUAL_AUTH_ENDPOINT]) +@pytest.mark.parametrize( + "endpoint", [CHANGE_CIPHER_SUITE_ENDPOINT, MUTUAL_AUTH_ENDPOINT] +) def test_apache_endpoints_fail_with_no_reneg(managed_process, protocol, endpoint): options = ProviderOptions( mode=Provider.ClientMode, @@ -40,7 +42,7 @@ def test_apache_endpoints_fail_with_no_reneg(managed_process, protocol, endpoint trust_store=APACHE_SERVER_CERT, cert=APACHE_CLIENT_CERT, key=APACHE_CLIENT_KEY, - use_client_auth=True + use_client_auth=True, ) with tempfile.NamedTemporaryFile("w+") as http_request_file: @@ -48,13 +50,17 @@ def test_apache_endpoints_fail_with_no_reneg(managed_process, protocol, endpoint http_request_file.flush() options.extra_flags = ["--send-file", http_request_file.name] - s2n_client = managed_process(S2N, options, timeout=20, close_marker="You don't have permission") + s2n_client = managed_process( + S2N, options, timeout=20, close_marker="You don't have permission" + ) for results in s2n_client.get_results(): results.assert_success() assert b"403 Forbidden" in results.stdout - assert b"You don't have permission to access this resource." in results.stdout + assert ( + b"You don't have permission to access this resource." in results.stdout + ) @pytest.mark.uncollect_if(func=invalid_test_parameters) @@ -99,7 +105,7 @@ def test_mutual_auth_endpoint(managed_process, curve, protocol): trust_store=APACHE_SERVER_CERT, cert=APACHE_CLIENT_CERT, key=APACHE_CLIENT_KEY, - use_client_auth=True + use_client_auth=True, ) options.extra_flags = [S2N_RENEG_OPTION, S2N_RENEG_ACCEPT] diff --git a/tests/integrationv2/test_serialization.py b/tests/integrationv2/test_serialization.py index 3605b0cd08d..8b3f7aa5f73 100644 --- a/tests/integrationv2/test_serialization.py +++ b/tests/integrationv2/test_serialization.py @@ -11,8 +11,8 @@ from providers import Provider, S2N from utils import invalid_test_parameters, get_parameter_name, to_bytes -SERVER_STATE_FILE = 'server_state' -CLIENT_STATE_FILE = 'client_state' +SERVER_STATE_FILE = "server_state" +CLIENT_STATE_FILE = "client_state" SERVER_DATA = f"Some random data from the server:" + random_str(10) CLIENT_DATA = f"Some random data from the client:" + random_str(10) @@ -41,10 +41,20 @@ class Mode(Enum): @pytest.mark.uncollect_if(func=invalid_test_parameters) -@pytest.mark.parametrize("protocol", [Protocols.TLS13, Protocols.TLS12], ids=get_parameter_name) -@pytest.mark.parametrize("mainline_role", [MainlineRole.Serialize, MainlineRole.Deserialize], ids=get_parameter_name) -@pytest.mark.parametrize("version_change", [Mode.Server, Mode.Client], ids=get_parameter_name) -def test_server_serialization_backwards_compat(managed_process, tmp_path, protocol, mainline_role, version_change): +@pytest.mark.parametrize( + "protocol", [Protocols.TLS13, Protocols.TLS12], ids=get_parameter_name +) +@pytest.mark.parametrize( + "mainline_role", + [MainlineRole.Serialize, MainlineRole.Deserialize], + ids=get_parameter_name, +) +@pytest.mark.parametrize( + "version_change", [Mode.Server, Mode.Client], ids=get_parameter_name +) +def test_server_serialization_backwards_compat( + managed_process, tmp_path, protocol, mainline_role, version_change +): server_state_file = str(tmp_path / SERVER_STATE_FILE) client_state_file = str(tmp_path / CLIENT_STATE_FILE) assert not os.path.exists(server_state_file) @@ -58,11 +68,11 @@ def test_server_serialization_backwards_compat(managed_process, tmp_path, protoc client_options = copy.copy(options) client_options.mode = Provider.ClientMode - client_options.extra_flags = ['--serialize-out', client_state_file] + client_options.extra_flags = ["--serialize-out", client_state_file] server_options = copy.copy(options) server_options.mode = Provider.ServerMode - server_options.extra_flags = ['--serialize-out', server_state_file] + server_options.extra_flags = ["--serialize-out", server_state_file] if mainline_role is MainlineRole.Serialize: if version_change == Mode.Server: @@ -70,23 +80,28 @@ def test_server_serialization_backwards_compat(managed_process, tmp_path, protoc else: client_options.use_mainline_version = True - server = managed_process( - S2N, server_options, send_marker=S2N.get_send_marker()) + server = managed_process(S2N, server_options, send_marker=S2N.get_send_marker()) client = managed_process(S2N, client_options, send_marker=S2N.get_send_marker()) for results in client.get_results(): results.assert_success() - assert to_bytes("Actual protocol version: {}".format(protocol.value)) in results.stdout + assert ( + to_bytes("Actual protocol version: {}".format(protocol.value)) + in results.stdout + ) for results in server.get_results(): results.assert_success() - assert to_bytes("Actual protocol version: {}".format(protocol.value)) in results.stdout + assert ( + to_bytes("Actual protocol version: {}".format(protocol.value)) + in results.stdout + ) assert os.path.exists(server_state_file) assert os.path.exists(client_state_file) - client_options.extra_flags = ['--deserialize-in', client_state_file] - server_options.extra_flags = ['--deserialize-in', server_state_file] + client_options.extra_flags = ["--deserialize-in", client_state_file] + server_options.extra_flags = ["--deserialize-in", server_state_file] if mainline_role is MainlineRole.Deserialize: if version_change == Mode.Server: server_options.use_mainline_version = True @@ -97,7 +112,12 @@ def test_server_serialization_backwards_compat(managed_process, tmp_path, protoc client_options.data_to_send = CLIENT_DATA.encode() server = managed_process(S2N, server_options, send_marker=CLIENT_DATA) - client = managed_process(S2N, client_options, send_marker="Connected to localhost", close_marker=SERVER_DATA) + client = managed_process( + S2N, + client_options, + send_marker="Connected to localhost", + close_marker=SERVER_DATA, + ) for results in server.get_results(): results.assert_success() diff --git a/tests/integrationv2/test_session_resumption.py b/tests/integrationv2/test_session_resumption.py index 89f332fc284..11341dc1d9e 100644 --- a/tests/integrationv2/test_session_resumption.py +++ b/tests/integrationv2/test_session_resumption.py @@ -5,23 +5,45 @@ import platform import pytest -from configuration import available_ports, ALL_TEST_CIPHERS, ALL_TEST_CURVES, ALL_TEST_CERTS, PROTOCOLS, TLS13_CIPHERS +from configuration import ( + available_ports, + ALL_TEST_CIPHERS, + ALL_TEST_CURVES, + ALL_TEST_CERTS, + PROTOCOLS, + TLS13_CIPHERS, +) from common import ProviderOptions, Protocols, data_bytes from fixtures import managed_process # lgtm [py/unused-import] from providers import Provider, S2N, OpenSSL -from utils import invalid_test_parameters, get_parameter_name, get_expected_s2n_version, to_bytes +from utils import ( + invalid_test_parameters, + get_parameter_name, + get_expected_s2n_version, + to_bytes, +) @pytest.mark.uncollect_if(func=invalid_test_parameters) @pytest.mark.parametrize("cipher", ALL_TEST_CIPHERS, ids=get_parameter_name) @pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name) @pytest.mark.parametrize("certificate", ALL_TEST_CERTS, ids=get_parameter_name) -@pytest.mark.parametrize("protocol", [p for p in PROTOCOLS if p != Protocols.TLS13], ids=get_parameter_name) +@pytest.mark.parametrize( + "protocol", [p for p in PROTOCOLS if p != Protocols.TLS13], ids=get_parameter_name +) @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) @pytest.mark.parametrize("use_ticket", [True, False]) -def test_session_resumption_s2n_server(managed_process, cipher, curve, certificate, protocol, provider, other_provider, - use_ticket): +def test_session_resumption_s2n_server( + managed_process, + cipher, + curve, + certificate, + protocol, + provider, + other_provider, + use_ticket, +): port = next(available_ports) client_options = ProviderOptions( @@ -31,12 +53,13 @@ def test_session_resumption_s2n_server(managed_process, cipher, curve, certifica curve=curve, insecure=True, reconnect=True, - protocol=protocol) + protocol=protocol, + ) server_options = copy.copy(client_options) server_options.reconnects_before_exit = 6 server_options.mode = Provider.ServerMode - server_options.use_session_ticket = use_ticket, + server_options.use_session_ticket = (use_ticket,) server_options.key = certificate.key server_options.cert = certificate.cert @@ -55,20 +78,34 @@ def test_session_resumption_s2n_server(managed_process, cipher, curve, certifica # S2N should indicate the procotol version in a successful connection. for results in server.get_results(): results.assert_success() - assert results.stdout.count( - to_bytes("Actual protocol version: {}".format(expected_version))) == 6 + assert ( + results.stdout.count( + to_bytes("Actual protocol version: {}".format(expected_version)) + ) + == 6 + ) @pytest.mark.uncollect_if(func=invalid_test_parameters) @pytest.mark.parametrize("cipher", ALL_TEST_CIPHERS, ids=get_parameter_name) @pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name) @pytest.mark.parametrize("certificate", ALL_TEST_CERTS, ids=get_parameter_name) -@pytest.mark.parametrize("protocol", [p for p in PROTOCOLS if p != Protocols.TLS13], ids=get_parameter_name) +@pytest.mark.parametrize( + "protocol", [p for p in PROTOCOLS if p != Protocols.TLS13], ids=get_parameter_name +) @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) @pytest.mark.parametrize("use_ticket", [True, False]) -def test_session_resumption_s2n_client(managed_process, cipher, curve, protocol, provider, other_provider, certificate, - use_ticket): +def test_session_resumption_s2n_client( + managed_process, + cipher, + curve, + protocol, + provider, + other_provider, + certificate, + use_ticket, +): port = next(available_ports) client_options = ProviderOptions( @@ -79,7 +116,8 @@ def test_session_resumption_s2n_client(managed_process, cipher, curve, protocol, insecure=True, reconnect=True, use_session_ticket=use_ticket, - protocol=protocol) + protocol=protocol, + ) server_options = copy.copy(client_options) server_options.reconnects_before_exit = 6 @@ -96,8 +134,12 @@ def test_session_resumption_s2n_client(managed_process, cipher, curve, protocol, expected_version = get_expected_s2n_version(protocol, OpenSSL) for results in client.get_results(): results.assert_success() - assert results.stdout.count( - to_bytes("Actual protocol version: {}".format(expected_version))) == 6 + assert ( + results.stdout.count( + to_bytes("Actual protocol version: {}".format(expected_version)) + ) + == 6 + ) for results in server.get_results(): results.assert_success() @@ -111,12 +153,20 @@ def test_session_resumption_s2n_client(managed_process, cipher, curve, protocol, @pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name) @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) -def test_tls13_session_resumption_s2n_server(managed_process, tmp_path, cipher, curve, certificate, protocol, provider, - other_provider): +def test_tls13_session_resumption_s2n_server( + managed_process, + tmp_path, + cipher, + curve, + certificate, + protocol, + provider, + other_provider, +): port = str(next(available_ports)) # Use temp directory to store session tickets - p = tmp_path / 'ticket.pem' + p = tmp_path / "ticket.pem" path_to_ticket = str(p) close_marker_bytes = data_bytes(10) @@ -128,8 +178,9 @@ def test_tls13_session_resumption_s2n_server(managed_process, tmp_path, cipher, curve=curve, insecure=True, reconnect=False, - extra_flags=['-sess_out', path_to_ticket], - protocol=protocol) + extra_flags=["-sess_out", path_to_ticket], + protocol=protocol, + ) server_options = copy.copy(client_options) server_options.mode = Provider.ServerMode @@ -140,47 +191,55 @@ def test_tls13_session_resumption_s2n_server(managed_process, tmp_path, cipher, server_options.data_to_send = close_marker_bytes server = managed_process( - S2N, server_options, timeout=5, send_marker=S2N.get_send_marker()) - client = managed_process(provider, client_options, - timeout=5, close_marker=str(close_marker_bytes)) + S2N, server_options, timeout=5, send_marker=S2N.get_send_marker() + ) + client = managed_process( + provider, client_options, timeout=5, close_marker=str(close_marker_bytes) + ) # The client should have received a session ticket for results in client.get_results(): results.assert_success() - assert b'Post-Handshake New Session Ticket arrived:' in results.stdout + assert b"Post-Handshake New Session Ticket arrived:" in results.stdout for results in server.get_results(): results.assert_success() # The first connection is a full handshake - assert b'Resumed session' not in results.stdout + assert b"Resumed session" not in results.stdout # Client inputs received session ticket to resume a session assert os.path.exists(path_to_ticket) - client_options.extra_flags = ['-sess_in', path_to_ticket] + client_options.extra_flags = ["-sess_in", path_to_ticket] port = str(next(available_ports)) client_options.port = port server_options.port = port server = managed_process( - S2N, server_options, timeout=5, send_marker=S2N.get_send_marker()) - client = managed_process(provider, client_options, - timeout=5, close_marker=str(close_marker_bytes)) + S2N, server_options, timeout=5, send_marker=S2N.get_send_marker() + ) + client = managed_process( + provider, client_options, timeout=5, close_marker=str(close_marker_bytes) + ) s2n_version = get_expected_s2n_version(protocol, provider) # Client has not read server certificate message as this is a resumed session for results in client.get_results(): results.assert_success() - assert to_bytes( - "SSL_connect:SSLv3/TLS read server certificate") not in results.stderr + assert ( + to_bytes("SSL_connect:SSLv3/TLS read server certificate") + not in results.stderr + ) # The server should indicate a session has been resumed for results in server.get_results(): results.assert_success() - assert b'Resumed session' in results.stdout - assert to_bytes("Actual protocol version: {}".format( - s2n_version)) in results.stdout + assert b"Resumed session" in results.stdout + assert ( + to_bytes("Actual protocol version: {}".format(s2n_version)) + in results.stdout + ) @pytest.mark.uncollect_if(func=invalid_test_parameters) @@ -190,8 +249,9 @@ def test_tls13_session_resumption_s2n_server(managed_process, tmp_path, cipher, @pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name) @pytest.mark.parametrize("provider", [OpenSSL, S2N], ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) -def test_tls13_session_resumption_s2n_client(managed_process, cipher, curve, certificate, protocol, provider, - other_provider): +def test_tls13_session_resumption_s2n_client( + managed_process, cipher, curve, certificate, protocol, provider, other_provider +): port = str(next(available_ports)) # The reconnect option for s2nc allows the client to reconnect automatically @@ -208,13 +268,16 @@ def test_tls13_session_resumption_s2n_client(managed_process, cipher, curve, cer insecure=True, use_session_ticket=True, reconnect=True, - protocol=protocol) + protocol=protocol, + ) server_options = copy.copy(client_options) server_options.mode = Provider.ServerMode server_options.key = certificate.key server_options.cert = certificate.cert - server_options.reconnects_before_exit = num_resumed_connections + num_full_connections + server_options.reconnects_before_exit = ( + num_resumed_connections + num_full_connections + ) server = managed_process(provider, server_options, timeout=5) client = managed_process(S2N, client_options, timeout=5) @@ -224,29 +287,37 @@ def test_tls13_session_resumption_s2n_client(managed_process, cipher, curve, cer # s2nc indicates the number of resumed connections in its output for results in client.get_results(): results.assert_success() - assert results.stdout.count( - b'Resumed session') == num_resumed_connections - assert to_bytes("Actual protocol version: {}".format( - s2n_version)) in results.stdout + assert results.stdout.count(b"Resumed session") == num_resumed_connections + assert ( + to_bytes("Actual protocol version: {}".format(s2n_version)) + in results.stdout + ) - server_accepts_str = str( - num_resumed_connections + num_full_connections) + " server accepts that finished" + server_accepts_str = ( + str(num_resumed_connections + num_full_connections) + + " server accepts that finished" + ) for results in server.get_results(): results.assert_success() if provider is S2N: - assert results.stdout.count( - b'Resumed session') == num_resumed_connections - assert to_bytes("Actual protocol version: {}".format( - s2n_version)) in results.stdout + assert results.stdout.count(b"Resumed session") == num_resumed_connections + assert ( + to_bytes("Actual protocol version: {}".format(s2n_version)) + in results.stdout + ) else: assert to_bytes(server_accepts_str) in results.stdout # s_server only writes one certificate message in all of the connections - assert results.stderr.count( - b'SSL_accept:SSLv3/TLS write certificate') == num_full_connections + assert ( + results.stderr.count(b"SSL_accept:SSLv3/TLS write certificate") + == num_full_connections + ) -@pytest.mark.flaky(reruns=7, reruns_delay=2, condition=platform.machine().startswith("aarch")) +@pytest.mark.flaky( + reruns=7, reruns_delay=2, condition=platform.machine().startswith("aarch") +) @pytest.mark.uncollect_if(func=invalid_test_parameters) @pytest.mark.parametrize("cipher", TLS13_CIPHERS, ids=get_parameter_name) @pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name) @@ -254,12 +325,20 @@ def test_tls13_session_resumption_s2n_client(managed_process, cipher, curve, cer @pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name) @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) -def test_s2nd_falls_back_to_full_connection(managed_process, tmp_path, cipher, curve, certificate, protocol, provider, - other_provider): +def test_s2nd_falls_back_to_full_connection( + managed_process, + tmp_path, + cipher, + curve, + certificate, + protocol, + provider, + other_provider, +): port = str(next(available_ports)) # Use temp directory to store session tickets - p = tmp_path / 'ticket.pem' + p = tmp_path / "ticket.pem" path_to_ticket = str(p) """ @@ -275,9 +354,10 @@ def test_s2nd_falls_back_to_full_connection(managed_process, tmp_path, cipher, c curve=curve, insecure=True, reconnect=False, - extra_flags=['-sess_out', path_to_ticket], + extra_flags=["-sess_out", path_to_ticket], data_to_send=data_bytes(4069), - protocol=protocol) + protocol=protocol, + ) server_options = copy.copy(client_options) server_options.mode = Provider.ServerMode @@ -291,16 +371,16 @@ def test_s2nd_falls_back_to_full_connection(managed_process, tmp_path, cipher, c # The client should have received a session ticket for results in client.get_results(): results.assert_success() - assert b'Post-Handshake New Session Ticket arrived:' in results.stdout + assert b"Post-Handshake New Session Ticket arrived:" in results.stdout for results in server.get_results(): results.assert_success() # Server should have sent certificate message as this is a full connection - assert b'SSL_accept:SSLv3/TLS write certificate' in results.stderr + assert b"SSL_accept:SSLv3/TLS write certificate" in results.stderr # Client inputs received session ticket to resume a session assert os.path.exists(path_to_ticket) - client_options.extra_flags = ['-sess_in', path_to_ticket] + client_options.extra_flags = ["-sess_in", path_to_ticket] port = str(next(available_ports)) client_options.port = port @@ -315,25 +395,32 @@ def test_s2nd_falls_back_to_full_connection(managed_process, tmp_path, cipher, c # Client has read server certificate because this is a full connection for results in client.get_results(): results.assert_success() - assert to_bytes( - "SSL_connect:SSLv3/TLS read server certificate") in results.stderr + assert ( + to_bytes("SSL_connect:SSLv3/TLS read server certificate") in results.stderr + ) # The server should indicate a session has not been resumed for results in server.get_results(): results.assert_success() - assert b'Resumed session' not in results.stdout - assert to_bytes("Actual protocol version: {}".format( - s2n_version)) in results.stdout + assert b"Resumed session" not in results.stdout + assert ( + to_bytes("Actual protocol version: {}".format(s2n_version)) + in results.stdout + ) @pytest.mark.uncollect_if(func=invalid_test_parameters) @pytest.mark.parametrize("cipher", ALL_TEST_CIPHERS, ids=get_parameter_name) @pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name) @pytest.mark.parametrize("certificate", ALL_TEST_CERTS, ids=get_parameter_name) -@pytest.mark.parametrize("protocol", [p for p in PROTOCOLS if p < Protocols.TLS13], ids=get_parameter_name) +@pytest.mark.parametrize( + "protocol", [p for p in PROTOCOLS if p < Protocols.TLS13], ids=get_parameter_name +) @pytest.mark.parametrize("provider", [OpenSSL, S2N], ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) -def test_session_resumption_s2n_client_tls13_server_not_tls13(managed_process, cipher, curve, protocol, provider, other_provider, certificate): +def test_session_resumption_s2n_client_tls13_server_not_tls13( + managed_process, cipher, curve, protocol, provider, other_provider, certificate +): port = next(available_ports) # This test verifies that an S2N client that supports TLS1.3 can resume sessions @@ -353,7 +440,8 @@ def test_session_resumption_s2n_client_tls13_server_not_tls13(managed_process, c insecure=True, reconnect=True, use_session_ticket=True, - protocol=Protocols.TLS13) + protocol=Protocols.TLS13, + ) server_options = ProviderOptions( mode=Provider.ServerMode, @@ -366,7 +454,8 @@ def test_session_resumption_s2n_client_tls13_server_not_tls13(managed_process, c protocol=protocol, reconnects_before_exit=num_resumed_connections + num_full_connections, key=certificate.key, - cert=certificate.cert) + cert=certificate.cert, + ) # Passing the type of client and server as a parameter will # allow us to use a fixture to enumerate all possibilities. @@ -374,22 +463,26 @@ def test_session_resumption_s2n_client_tls13_server_not_tls13(managed_process, c client = managed_process(S2N, client_options, timeout=5) expected_version = get_expected_s2n_version(protocol, provider) - server_accepts_str = str( - num_resumed_connections + num_full_connections) + " server accepts that finished" + server_accepts_str = ( + str(num_resumed_connections + num_full_connections) + + " server accepts that finished" + ) for results in client.get_results(): results.assert_success() - assert results.stdout.count( - b'Resumed session') == num_resumed_connections - assert to_bytes("Actual protocol version: {}".format( - expected_version)) in results.stdout + assert results.stdout.count(b"Resumed session") == num_resumed_connections + assert ( + to_bytes("Actual protocol version: {}".format(expected_version)) + in results.stdout + ) for results in server.get_results(): results.assert_success() if provider is S2N: - assert results.stdout.count( - b'Resumed session') == num_resumed_connections - assert to_bytes("Actual protocol version: {}".format( - expected_version)) in results.stdout + assert results.stdout.count(b"Resumed session") == num_resumed_connections + assert ( + to_bytes("Actual protocol version: {}".format(expected_version)) + in results.stdout + ) else: assert to_bytes(server_accepts_str) in results.stdout diff --git a/tests/integrationv2/test_signature_algorithms.py b/tests/integrationv2/test_signature_algorithms.py index 2b52e686e4b..38cde001c8a 100644 --- a/tests/integrationv2/test_signature_algorithms.py +++ b/tests/integrationv2/test_signature_algorithms.py @@ -7,7 +7,12 @@ from common import ProviderOptions, Protocols, Signatures, data_bytes from fixtures import managed_process # lgtm [py/unused-import] from providers import Provider, S2N, OpenSSL, GnuTLS -from utils import invalid_test_parameters, get_parameter_name, get_expected_s2n_version, to_bytes +from utils import ( + invalid_test_parameters, + get_parameter_name, + get_expected_s2n_version, + to_bytes, +) all_sigs = [ Signatures.RSA_SHA1, @@ -33,7 +38,7 @@ def expected_signature(protocol, signature): # ECDSA by default hashes with SHA-1. # # This is inferred from extended version of TLS1.1 rfc- https://www.rfc-editor.org/rfc/rfc4492#section-5.10 - if signature.sig_type == 'ECDSA': + if signature.sig_type == "ECDSA": signature = Signatures.ECDSA_SHA1 else: signature = Signatures.RSA_MD5_SHA1 @@ -41,16 +46,19 @@ def expected_signature(protocol, signature): def signature_marker(mode, signature): - return to_bytes("{mode} signature negotiated: {type}+{digest}" - .format(mode=mode.title(), type=signature.sig_type, digest=signature.sig_digest)) + return to_bytes( + "{mode} signature negotiated: {type}+{digest}".format( + mode=mode.title(), type=signature.sig_type, digest=signature.sig_digest + ) + ) def skip_ciphers(*args, **kwargs): - provider = kwargs.get('provider') - cert = kwargs.get('certificate') - cipher = kwargs.get('cipher') - protocol = kwargs.get('protocol') - sigalg = kwargs.get('signature') + provider = kwargs.get("provider") + cert = kwargs.get("certificate") + cipher = kwargs.get("cipher") + protocol = kwargs.get("protocol") + sigalg = kwargs.get("signature") if not provider.supports_signature(sigalg): return True @@ -74,12 +82,28 @@ def skip_ciphers(*args, **kwargs): @pytest.mark.parametrize("cipher", ALL_TEST_CIPHERS, ids=get_parameter_name) @pytest.mark.parametrize("provider", [OpenSSL, GnuTLS]) @pytest.mark.parametrize("other_provider", [S2N]) -@pytest.mark.parametrize("protocol", [Protocols.TLS13, Protocols.TLS12, Protocols.TLS11], ids=get_parameter_name) +@pytest.mark.parametrize( + "protocol", + [Protocols.TLS13, Protocols.TLS12, Protocols.TLS11], + ids=get_parameter_name, +) @pytest.mark.parametrize("certificate", ALL_TEST_CERTS, ids=get_parameter_name) @pytest.mark.parametrize("signature", all_sigs, ids=get_parameter_name) -@pytest.mark.parametrize("client_auth", [True, False], ids=lambda val: "client-auth" if val else "no-client-auth") -def test_s2n_server_signature_algorithms(managed_process, cipher, provider, other_provider, protocol, certificate, - signature, client_auth): +@pytest.mark.parametrize( + "client_auth", + [True, False], + ids=lambda val: "client-auth" if val else "no-client-auth", +) +def test_s2n_server_signature_algorithms( + managed_process, + cipher, + provider, + other_provider, + protocol, + certificate, + signature, + client_auth, +): port = next(available_ports) random_bytes = data_bytes(64) @@ -93,7 +117,7 @@ def test_s2n_server_signature_algorithms(managed_process, cipher, provider, othe key=certificate.key, cert=certificate.cert, signature_algorithm=signature, - protocol=protocol + protocol=protocol, ) if provider == GnuTLS: @@ -117,12 +141,22 @@ def test_s2n_server_signature_algorithms(managed_process, cipher, provider, othe for results in server.get_results(): results.assert_success() - assert to_bytes("Actual protocol version: {}".format( - expected_version)) in results.stdout - assert signature_marker(Provider.ServerMode, - expected_signature(protocol, signature)) in results.stdout - assert (signature_marker(Provider.ClientMode, - expected_signature(protocol, signature)) in results.stdout) == client_auth + assert ( + to_bytes("Actual protocol version: {}".format(expected_version)) + in results.stdout + ) + assert ( + signature_marker( + Provider.ServerMode, expected_signature(protocol, signature) + ) + in results.stdout + ) + assert ( + signature_marker( + Provider.ClientMode, expected_signature(protocol, signature) + ) + in results.stdout + ) == client_auth assert random_bytes in results.stdout @@ -130,12 +164,28 @@ def test_s2n_server_signature_algorithms(managed_process, cipher, provider, othe @pytest.mark.parametrize("cipher", ALL_TEST_CIPHERS, ids=get_parameter_name) @pytest.mark.parametrize("provider", [OpenSSL, GnuTLS]) @pytest.mark.parametrize("other_provider", [S2N]) -@pytest.mark.parametrize("protocol", [Protocols.TLS13, Protocols.TLS12, Protocols.TLS11], ids=get_parameter_name) +@pytest.mark.parametrize( + "protocol", + [Protocols.TLS13, Protocols.TLS12, Protocols.TLS11], + ids=get_parameter_name, +) @pytest.mark.parametrize("certificate", ALL_TEST_CERTS, ids=get_parameter_name) @pytest.mark.parametrize("signature", all_sigs, ids=get_parameter_name) -@pytest.mark.parametrize("client_auth", [True, False], ids=lambda val: "client-auth" if val else "no-client-auth") -def test_s2n_client_signature_algorithms(managed_process, cipher, provider, other_provider, protocol, certificate, - signature, client_auth): +@pytest.mark.parametrize( + "client_auth", + [True, False], + ids=lambda val: "client-auth" if val else "no-client-auth", +) +def test_s2n_client_signature_algorithms( + managed_process, + cipher, + provider, + other_provider, + protocol, + certificate, + signature, + client_auth, +): port = next(available_ports) random_bytes = data_bytes(64) @@ -148,7 +198,8 @@ def test_s2n_client_signature_algorithms(managed_process, cipher, provider, othe use_client_auth=client_auth, key=certificate.key, cert=certificate.cert, - protocol=protocol) + protocol=protocol, + ) server_options = copy.copy(client_options) server_options.data_to_send = None @@ -162,14 +213,14 @@ def test_s2n_client_signature_algorithms(managed_process, cipher, provider, othe if provider == GnuTLS: kill_marker = random_bytes - server = managed_process(provider, server_options, - timeout=5, kill_marker=kill_marker) + server = managed_process( + provider, server_options, timeout=5, kill_marker=kill_marker + ) client = managed_process(S2N, client_options, timeout=5) for results in server.get_results(): results.assert_success() - assert any( - [random_bytes in stream for stream in results.output_streams()]) + assert any([random_bytes in stream for stream in results.output_streams()]) expected_version = get_expected_s2n_version(protocol, provider) @@ -181,14 +232,24 @@ def test_s2n_client_signature_algorithms(managed_process, cipher, provider, othe # # This mostly has to be inferred from the RFCs, but this blog post is a pretty good summary # of the situation: https://timtaubert.de/blog/2016/07/the-evolution-of-signatures-in-tls/ - server_sigalg_used = not cipher.iana_standard_name.startswith( - "TLS_RSA_WITH_") + server_sigalg_used = not cipher.iana_standard_name.startswith("TLS_RSA_WITH_") for results in client.get_results(): results.assert_success() - assert to_bytes("Actual protocol version: {}".format( - expected_version)) in results.stdout - assert signature_marker( - Provider.ServerMode, expected_signature(protocol, signature)) in results.stdout or not server_sigalg_used - assert (signature_marker(Provider.ClientMode, expected_signature(protocol, signature)) - in results.stdout) == client_auth + assert ( + to_bytes("Actual protocol version: {}".format(expected_version)) + in results.stdout + ) + assert ( + signature_marker( + Provider.ServerMode, expected_signature(protocol, signature) + ) + in results.stdout + or not server_sigalg_used + ) + assert ( + signature_marker( + Provider.ClientMode, expected_signature(protocol, signature) + ) + in results.stdout + ) == client_auth diff --git a/tests/integrationv2/test_sni_match.py b/tests/integrationv2/test_sni_match.py index 4abb5b75dcc..0ef27bad89e 100644 --- a/tests/integrationv2/test_sni_match.py +++ b/tests/integrationv2/test_sni_match.py @@ -6,7 +6,12 @@ from common import ProviderOptions, Protocols from fixtures import managed_process # lgtm [py/unused-import] from providers import Provider, S2N, OpenSSL -from utils import invalid_test_parameters, get_parameter_name, get_expected_s2n_version, to_bytes +from utils import ( + invalid_test_parameters, + get_parameter_name, + get_expected_s2n_version, + to_bytes, +) def filter_cipher_list(*args, **kwargs): @@ -17,11 +22,12 @@ def filter_cipher_list(*args, **kwargs): This function handles that unique grouping. """ - protocol = kwargs.get('protocol') - cert_test_case = kwargs.get('cert_test_case') + protocol = kwargs.get("protocol") + cert_test_case = kwargs.get("cert_test_case") lowest_protocol_cipher = min( - cert_test_case.client_ciphers, key=lambda x: x.min_version) + cert_test_case.client_ciphers, key=lambda x: x.min_version + ) if protocol < lowest_protocol_cipher.min_version: return True @@ -31,7 +37,9 @@ def filter_cipher_list(*args, **kwargs): @pytest.mark.uncollect_if(func=filter_cipher_list) @pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) -@pytest.mark.parametrize("protocol", [Protocols.TLS13, Protocols.TLS12], ids=get_parameter_name) +@pytest.mark.parametrize( + "protocol", [Protocols.TLS13, Protocols.TLS12], ids=get_parameter_name +) @pytest.mark.parametrize("cert_test_case", MULTI_CERT_TEST_CASES) def test_sni_match(managed_process, provider, other_provider, protocol, cert_test_case): port = next(available_ports) @@ -43,20 +51,18 @@ def test_sni_match(managed_process, provider, other_provider, protocol, cert_tes verify_hostname=True, server_name=cert_test_case.client_sni, cipher=cert_test_case.client_ciphers, - protocol=protocol) + protocol=protocol, + ) server_options = ProviderOptions( - mode=Provider.ServerMode, - port=port, - extra_flags=[], - protocol=protocol) + mode=Provider.ServerMode, port=port, extra_flags=[], protocol=protocol + ) # Setup the certificate chain for S2ND based on the multicert test case - cert_key_list = [(cert[0], cert[1]) - for cert in cert_test_case.server_certs] + cert_key_list = [(cert[0], cert[1]) for cert in cert_test_case.server_certs] for cert_key_path in cert_key_list: - server_options.extra_flags.extend(['--cert', cert_key_path[0]]) - server_options.extra_flags.extend(['--key', cert_key_path[1]]) + server_options.extra_flags.extend(["--cert", cert_key_path[0]]) + server_options.extra_flags.extend(["--key", cert_key_path[1]]) server = managed_process(S2N, server_options, timeout=5) client = managed_process(provider, client_options, timeout=5) @@ -68,8 +74,12 @@ def test_sni_match(managed_process, provider, other_provider, protocol, cert_tes for results in server.get_results(): results.assert_success() - assert to_bytes("Actual protocol version: {}".format( - expected_version)) in results.stdout + assert ( + to_bytes("Actual protocol version: {}".format(expected_version)) + in results.stdout + ) if cert_test_case.client_sni is not None: - assert to_bytes("Server name: {}".format( - cert_test_case.client_sni)) in results.stdout + assert ( + to_bytes("Server name: {}".format(cert_test_case.client_sni)) + in results.stdout + ) diff --git a/tests/integrationv2/test_sslyze.py b/tests/integrationv2/test_sslyze.py index 23091092da2..4a1f968e40f 100644 --- a/tests/integrationv2/test_sslyze.py +++ b/tests/integrationv2/test_sslyze.py @@ -19,7 +19,7 @@ Protocols.TLS10, Protocols.TLS11, Protocols.TLS12, - Protocols.TLS13 + Protocols.TLS13, ] SSLYZE_SCANS_TO_TEST = [ @@ -30,11 +30,14 @@ sslyze.ScanCommand.TLS_FALLBACK_SCSV, sslyze.ScanCommand.HEARTBLEED, sslyze.ScanCommand.OPENSSL_CCS_INJECTION, - sslyze.ScanCommand.SESSION_RENEGOTIATION + sslyze.ScanCommand.SESSION_RENEGOTIATION, ] CERTS_TO_TEST = [ - cert for cert in ALL_TEST_CERTS if cert.name not in { + cert + for cert in ALL_TEST_CERTS + if cert.name + not in { "RSA_PSS_2048_SHA256" # SSLyze errors when given an RSA PSS cert } ] @@ -44,7 +47,7 @@ Protocols.TLS10.value: sslyze.ScanCommand.TLS_1_0_CIPHER_SUITES, Protocols.TLS11.value: sslyze.ScanCommand.TLS_1_1_CIPHER_SUITES, Protocols.TLS12.value: sslyze.ScanCommand.TLS_1_2_CIPHER_SUITES, - Protocols.TLS13.value: sslyze.ScanCommand.TLS_1_3_CIPHER_SUITES + Protocols.TLS13.value: sslyze.ScanCommand.TLS_1_3_CIPHER_SUITES, } @@ -66,7 +69,8 @@ def assert_scan_success(self): assert self.scan_result.is_tls_version_supported is True rejected_ciphers = [ - cipher for rejected_cipher in self.scan_result.rejected_cipher_suites + cipher + for rejected_cipher in self.scan_result.rejected_cipher_suites if (cipher := Ciphers.from_iana(rejected_cipher.cipher_suite.name)) ] @@ -77,7 +81,7 @@ def assert_scan_success(self): protocol=self.protocol, provider=S2N, certificate=self.certificate, - cipher=cipher + cipher=cipher, ) @@ -86,13 +90,16 @@ def assert_scan_success(self): assert self.scan_result.supports_ecdh_key_exchange is True rejected_curves = [ - curve for rejected_curve in self.scan_result.rejected_curves - if (curve := { - "X25519": Curves.X25519, - "prime256v1": Curves.P256, - "prime384v1": Curves.P384, - "prime521v1": Curves.P521 - }.get(rejected_curve.name)) + curve + for rejected_curve in self.scan_result.rejected_curves + if ( + curve := { + "X25519": Curves.X25519, + "prime256v1": Curves.P256, + "prime384v1": Curves.P384, + "prime521v1": Curves.P521, + }.get(rejected_curve.name) + ) ] for curve in rejected_curves: @@ -102,16 +109,22 @@ def assert_scan_success(self): protocol=self.protocol, provider=S2N, certificate=self.certificate, - curve=curve + curve=curve, ) class RobotVerifier(ScanVerifier): def assert_scan_success(self): if self.protocol == Protocols.TLS13: - assert self.scan_result.robot_result == sslyze.RobotScanResultEnum.NOT_VULNERABLE_RSA_NOT_SUPPORTED + assert ( + self.scan_result.robot_result + == sslyze.RobotScanResultEnum.NOT_VULNERABLE_RSA_NOT_SUPPORTED + ) else: - assert self.scan_result.robot_result == sslyze.RobotScanResultEnum.NOT_VULNERABLE_NO_ORACLE + assert ( + self.scan_result.robot_result + == sslyze.RobotScanResultEnum.NOT_VULNERABLE_NO_ORACLE + ) class SessionResumptionVerifier(ScanVerifier): @@ -119,7 +132,10 @@ def assert_scan_success(self): if self.protocol == Protocols.TLS13: pass # SSLyze does not support session resumption scans for tls 1.3 else: - assert self.scan_result.tls_ticket_resumption_result == sslyze.TlsResumptionSupportEnum.FULLY_SUPPORTED + assert ( + self.scan_result.tls_ticket_resumption_result + == sslyze.TlsResumptionSupportEnum.FULLY_SUPPORTED + ) class CrimeVerifier(ScanVerifier): @@ -160,16 +176,16 @@ def validate_scan_result(scan_attempt, protocol, certificate=None): scan_result = scan_attempt.result verifier_cls = { - sslyze.CipherSuitesScanResult: CipherSuitesVerifier, - sslyze.SupportedEllipticCurvesScanResult: EllipticCurveVerifier, - sslyze.RobotScanResult: RobotVerifier, - sslyze.SessionResumptionSupportScanResult: SessionResumptionVerifier, - sslyze.CompressionScanResult: CrimeVerifier, - sslyze.EarlyDataScanResult: EarlyDataVerifier, - sslyze.FallbackScsvScanResult: DowngradePreventionVerifier, - sslyze.HeartbleedScanResult: HeartbleedVerifier, - sslyze.OpenSslCcsInjectionScanResult: CCSInjectionVerifier, - sslyze.SessionRenegotiationScanResult: InsecureRenegotiationVerifier + sslyze.CipherSuitesScanResult: CipherSuitesVerifier, + sslyze.SupportedEllipticCurvesScanResult: EllipticCurveVerifier, + sslyze.RobotScanResult: RobotVerifier, + sslyze.SessionResumptionSupportScanResult: SessionResumptionVerifier, + sslyze.CompressionScanResult: CrimeVerifier, + sslyze.EarlyDataScanResult: EarlyDataVerifier, + sslyze.FallbackScsvScanResult: DowngradePreventionVerifier, + sslyze.HeartbleedScanResult: HeartbleedVerifier, + sslyze.OpenSslCcsInjectionScanResult: CCSInjectionVerifier, + sslyze.SessionRenegotiationScanResult: InsecureRenegotiationVerifier, }.get(type(scan_result)) assert verifier_cls is not None, f"unexpected scan: {scan_attempt}" @@ -181,12 +197,15 @@ def validate_scan_result(scan_attempt, protocol, certificate=None): def get_scan_attempts(scan_results): # scan_results (sslyze.AllScanCommandsAttempts) is an object containing parameters mapped to scan attempts. convert # this to a list containing just scan attempts, and then filter out tests that were not scheduled. - scan_attribute_names = [attr_name for attr_name in dir( - scan_results) if not attr_name.startswith("__")] - scan_attempts = [getattr(scan_results, attr_name) - for attr_name in scan_attribute_names] + scan_attribute_names = [ + attr_name for attr_name in dir(scan_results) if not attr_name.startswith("__") + ] scan_attempts = [ - scan_attempt for scan_attempt in scan_attempts + getattr(scan_results, attr_name) for attr_name in scan_attribute_names + ] + scan_attempts = [ + scan_attempt + for scan_attempt in scan_attempts if scan_attempt.status != sslyze.ScanCommandAttemptStatusEnum.NOT_SCHEDULED ] return scan_attempts @@ -196,19 +215,23 @@ def assert_scan_result_completed(scan_result): def get_connectivity_error_str(tb): return "\n".join(tb.stack.format()) - assert scan_result.connectivity_status == sslyze.ServerConnectivityStatusEnum.COMPLETED, \ + assert ( + scan_result.connectivity_status == sslyze.ServerConnectivityStatusEnum.COMPLETED + ), ( f"sslyze could not connect to server: {get_connectivity_error_str(scan_result.connectivity_error_trace)}" + ) def assert_scan_attempt_completed(scan_attempt): - assert scan_attempt.status == sslyze.ScanCommandAttemptStatusEnum.COMPLETED, \ + assert scan_attempt.status == sslyze.ScanCommandAttemptStatusEnum.COMPLETED, ( f"scan attempt ({scan_attempt}) failed: {scan_attempt.status}" + ) def run_sslyze_scan(host, port, scans): scan_request = sslyze.ServerScanRequest( server_location=sslyze.ServerNetworkLocation(hostname=host, port=port), - scan_commands=scans + scan_commands=scans, ) scanner = sslyze.Scanner(per_server_concurrent_connections_limit=1) scanner.queue_scans([scan_request]) @@ -224,7 +247,7 @@ def invalid_sslyze_scan_parameters(*args, **kwargs): if "fips" in get_flag(S2N_PROVIDER_VERSION) and protocol != Protocols.TLS13: if scan_command in [ sslyze.ScanCommand.TLS_COMPRESSION, - sslyze.ScanCommand.SESSION_RENEGOTIATION + sslyze.ScanCommand.SESSION_RENEGOTIATION, ]: return True # BUG_IN_SSLYZE error for session resumption scan with openssl 1.0.2 fips @@ -247,25 +270,29 @@ def test_sslyze_scans(managed_process, protocol, scan_command, provider): host=HOST, port=port, protocol=protocol, - extra_flags=["--parallelize"] + extra_flags=["--parallelize"], ) # Test 1.3 exclusively if protocol == Protocols.TLS13: server_options.cipher = Cipher( - "test_all_tls13", Protocols.TLS13, False, False, s2n=True) + "test_all_tls13", Protocols.TLS13, False, False, s2n=True + ) if scan_command == sslyze.ScanCommand.SESSION_RESUMPTION: - server_options.reconnect = True, - server_options.use_session_ticket = True, + server_options.reconnect = (True,) + server_options.use_session_ticket = (True,) if scan_command == sslyze.ScanCommand.TLS_1_3_EARLY_DATA: server_options.insecure = True server_options.use_session_ticket = True - server_options.extra_flags.extend([ - "--max-early-data", "65535", - "--https-server" # Early data scan sends http requests - ]) + server_options.extra_flags.extend( + [ + "--max-early-data", + "65535", + "--https-server", # Early data scan sends http requests + ] + ) server = managed_process(S2N, server_options, timeout=30) @@ -306,7 +333,7 @@ def invalid_certificate_scans_parameters(*args, **kwargs): if "RSA" in certificate.name and protocol in [ Protocols.SSLv3, Protocols.TLS10, - Protocols.TLS11 + Protocols.TLS11, ]: return True elif certificate_scan == CertificateScan.ELLIPTIC_CURVE_SCAN: @@ -325,11 +352,14 @@ def invalid_certificate_scans_parameters(*args, **kwargs): @pytest.mark.parametrize("protocol", PROTOCOLS_TO_TEST, ids=get_parameter_name) @pytest.mark.parametrize("certificate", CERTS_TO_TEST, ids=get_parameter_name) @pytest.mark.parametrize("provider", [S2N], ids=get_parameter_name) -@pytest.mark.parametrize("certificate_scan", [ - CertificateScan.CIPHER_SUITE_SCAN, - CertificateScan.ELLIPTIC_CURVE_SCAN -], ids=lambda certificate_scan: certificate_scan.name) -def test_sslyze_certificate_scans(managed_process, protocol, certificate, provider, certificate_scan): +@pytest.mark.parametrize( + "certificate_scan", + [CertificateScan.CIPHER_SUITE_SCAN, CertificateScan.ELLIPTIC_CURVE_SCAN], + ids=lambda certificate_scan: certificate_scan.name, +) +def test_sslyze_certificate_scans( + managed_process, protocol, certificate, provider, certificate_scan +): port = next(available_ports) server_options = ProviderOptions( @@ -340,13 +370,13 @@ def test_sslyze_certificate_scans(managed_process, protocol, certificate, provid key=certificate.key, cert=certificate.cert, insecure=True, - extra_flags=["--parallelize"] + extra_flags=["--parallelize"], ) server = managed_process(S2N, server_options, timeout=30) scan = { CertificateScan.CIPHER_SUITE_SCAN: CIPHER_SUITE_SCANS.get(protocol.value), - CertificateScan.ELLIPTIC_CURVE_SCAN: sslyze.ScanCommand.ELLIPTIC_CURVES + CertificateScan.ELLIPTIC_CURVE_SCAN: sslyze.ScanCommand.ELLIPTIC_CURVES, }.get(certificate_scan) scan_attempt_results = run_sslyze_scan(HOST, port, [scan]) diff --git a/tests/integrationv2/test_version_negotiation.py b/tests/integrationv2/test_version_negotiation.py index c19ecfd95f3..e1b0f7d484b 100644 --- a/tests/integrationv2/test_version_negotiation.py +++ b/tests/integrationv2/test_version_negotiation.py @@ -3,12 +3,23 @@ import copy import pytest -from configuration import available_ports, ALL_TEST_CIPHERS, ALL_TEST_CURVES, ALL_TEST_CERTS +from configuration import ( + available_ports, + ALL_TEST_CIPHERS, + ALL_TEST_CURVES, + ALL_TEST_CERTS, +) from common import ProviderOptions, Protocols, data_bytes from fixtures import managed_process # lgtm [py/unused-import] from providers import Provider, S2N, OpenSSL, GnuTLS -from utils import invalid_test_parameters, get_parameter_name, get_expected_s2n_version, get_expected_openssl_version, \ - to_bytes, get_expected_gnutls_version +from utils import ( + invalid_test_parameters, + get_parameter_name, + get_expected_s2n_version, + get_expected_openssl_version, + to_bytes, + get_expected_gnutls_version, +) def test_nothing(): @@ -22,10 +33,7 @@ def test_nothing(): def invalid_version_negotiation_test_parameters(*args, **kwargs): # Since s2nd/s2nc will always be using TLS 1.3, make sure the libcrypto is compatible - if invalid_test_parameters(**{ - "provider": S2N, - "protocol": Protocols.TLS13 - }): + if invalid_test_parameters(**{"provider": S2N, "protocol": Protocols.TLS13}): return True return invalid_test_parameters(*args, **kwargs) @@ -35,10 +43,16 @@ def invalid_version_negotiation_test_parameters(*args, **kwargs): @pytest.mark.parametrize("cipher", ALL_TEST_CIPHERS, ids=get_parameter_name) @pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name) @pytest.mark.parametrize("certificate", ALL_TEST_CERTS, ids=get_parameter_name) -@pytest.mark.parametrize("protocol", [Protocols.TLS12, Protocols.TLS11, Protocols.TLS10], ids=get_parameter_name) +@pytest.mark.parametrize( + "protocol", + [Protocols.TLS12, Protocols.TLS11, Protocols.TLS10], + ids=get_parameter_name, +) @pytest.mark.parametrize("provider", [S2N, OpenSSL, GnuTLS], ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) -def test_s2nc_tls13_negotiates_tls12(managed_process, cipher, curve, certificate, protocol, provider, other_provider): +def test_s2nc_tls13_negotiates_tls12( + managed_process, cipher, curve, certificate, protocol, provider, other_provider +): port = next(available_ports) random_bytes = data_bytes(24) @@ -49,7 +63,7 @@ def test_s2nc_tls13_negotiates_tls12(managed_process, cipher, curve, certificate curve=curve, data_to_send=random_bytes, insecure=True, - protocol=Protocols.TLS13 + protocol=Protocols.TLS13, ) server_options = copy.copy(client_options) @@ -63,8 +77,9 @@ def test_s2nc_tls13_negotiates_tls12(managed_process, cipher, curve, certificate if provider == GnuTLS: kill_marker = random_bytes - server = managed_process(provider, server_options, - timeout=5, kill_marker=kill_marker) + server = managed_process( + provider, server_options, timeout=5, kill_marker=kill_marker + ) client = managed_process(S2N, client_options, timeout=5) client_version = get_expected_s2n_version(Protocols.TLS13, provider) @@ -72,10 +87,14 @@ def test_s2nc_tls13_negotiates_tls12(managed_process, cipher, curve, certificate for results in client.get_results(): results.assert_success() - assert to_bytes("Client protocol version: {}".format( - client_version)) in results.stdout - assert to_bytes("Actual protocol version: {}".format( - actual_version)) in results.stdout + assert ( + to_bytes("Client protocol version: {}".format(client_version)) + in results.stdout + ) + assert ( + to_bytes("Actual protocol version: {}".format(actual_version)) + in results.stdout + ) for results in server.get_results(): results.assert_success() @@ -83,25 +102,32 @@ def test_s2nc_tls13_negotiates_tls12(managed_process, cipher, curve, certificate # whether the S2N client was able to negotiate a lower TLS version. if provider is S2N: # The client sends a TLS 1.3 client hello so a client protocol version of TLS 1.3 should always be expected. - assert to_bytes("Client protocol version: {}".format( - Protocols.TLS13.value)) in results.stdout - assert to_bytes("Actual protocol version: {}".format( - actual_version)) in results.stdout + assert ( + to_bytes("Client protocol version: {}".format(Protocols.TLS13.value)) + in results.stdout + ) + assert ( + to_bytes("Actual protocol version: {}".format(actual_version)) + in results.stdout + ) - assert any([ - random_bytes[1:] in stream - for stream in results.output_streams() - ]) + assert any([random_bytes[1:] in stream for stream in results.output_streams()]) @pytest.mark.uncollect_if(func=invalid_version_negotiation_test_parameters) @pytest.mark.parametrize("cipher", ALL_TEST_CIPHERS, ids=get_parameter_name) @pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name) @pytest.mark.parametrize("certificate", ALL_TEST_CERTS, ids=get_parameter_name) -@pytest.mark.parametrize("protocol", [Protocols.TLS12, Protocols.TLS11, Protocols.TLS10], ids=get_parameter_name) +@pytest.mark.parametrize( + "protocol", + [Protocols.TLS12, Protocols.TLS11, Protocols.TLS10], + ids=get_parameter_name, +) @pytest.mark.parametrize("provider", [S2N, OpenSSL, GnuTLS], ids=get_parameter_name) @pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name) -def test_s2nd_tls13_negotiates_tls12(managed_process, cipher, curve, certificate, protocol, provider, other_provider): +def test_s2nd_tls13_negotiates_tls12( + managed_process, cipher, curve, certificate, protocol, provider, other_provider +): port = next(available_ports) random_bytes = data_bytes(24) @@ -112,7 +138,7 @@ def test_s2nd_tls13_negotiates_tls12(managed_process, cipher, curve, certificate curve=curve, data_to_send=random_bytes, insecure=True, - protocol=protocol + protocol=protocol, ) server_options = copy.copy(client_options) @@ -134,16 +160,19 @@ def test_s2nd_tls13_negotiates_tls12(managed_process, cipher, curve, certificate results.assert_success() if provider is S2N: # The client will get the server version from the SERVER HELLO, which will be the negotiated version - assert to_bytes("Server protocol version: {}".format( - actual_version)) in results.stdout - assert to_bytes("Actual protocol version: {}".format( - actual_version)) in results.stdout + assert ( + to_bytes("Server protocol version: {}".format(actual_version)) + in results.stdout + ) + assert ( + to_bytes("Actual protocol version: {}".format(actual_version)) + in results.stdout + ) elif provider is OpenSSL: # This check cares about other providers because we want to know that they did negotiate the version # that our S2N server intended to negotiate. openssl_version = get_expected_openssl_version(protocol) - assert to_bytes("Protocol : {}".format( - openssl_version)) in results.stdout + assert to_bytes("Protocol : {}".format(openssl_version)) in results.stdout elif provider is GnuTLS: gnutls_version = get_expected_gnutls_version(protocol) assert to_bytes(f"Version: {gnutls_version}") in results.stdout diff --git a/tests/integrationv2/test_well_known_endpoints.py b/tests/integrationv2/test_well_known_endpoints.py index 6d5d7e9d5a3..23c091737d1 100644 --- a/tests/integrationv2/test_well_known_endpoints.py +++ b/tests/integrationv2/test_well_known_endpoints.py @@ -1,13 +1,14 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 + def test_well_known_endpoints(): - ''' - This is a stub test, which allows the existing CI to continue passing while + """ + This is a stub test, which allows the existing CI to continue passing while https://github.com/aws/s2n-tls/pull/4884 is merged in. Once the PR is merged, the Codebuild spec for NixIntegV2Batch will be updated to remove the "well_known_endpoints" argument (manual process) and then this test can be fully removed (PR). - ''' + """ assert 1 == 1 diff --git a/tests/integrationv2/utils.py b/tests/integrationv2/utils.py index d1a342ff9bd..beda803a3f1 100644 --- a/tests/integrationv2/utils.py +++ b/tests/integrationv2/utils.py @@ -6,7 +6,7 @@ def to_bytes(val): - return bytes(str(val).encode('utf-8')) + return bytes(str(val).encode("utf-8")) def to_string(val: bytes): @@ -22,7 +22,7 @@ def get_expected_s2n_version(protocol, provider): protocol is less than tls12. """ if provider == S2N and protocol != Protocols.TLS13: - version = '33' + version = "33" else: version = protocol.value @@ -34,7 +34,7 @@ def get_expected_openssl_version(protocol): Protocols.TLS10.value: "TLSv1", Protocols.TLS11.value: "TLSv1.1", Protocols.TLS12.value: "TLSv1.2", - Protocols.TLS13.value: "TLSv1.3" + Protocols.TLS13.value: "TLSv1.3", }.get(protocol.value) @@ -43,7 +43,7 @@ def get_expected_gnutls_version(protocol): Protocols.TLS10.value: "TLS1.0", Protocols.TLS11.value: "TLS1.1", Protocols.TLS12.value: "TLS1.2", - Protocols.TLS13.value: "TLS1.3" + Protocols.TLS13.value: "TLS1.3", }.get(protocol.value) @@ -59,14 +59,14 @@ def invalid_test_parameters(*args, **kwargs): This function returns True or False, indicating whether a test should be "deselected" based on the arguments. """ - protocol = kwargs.get('protocol') - provider = kwargs.get('provider') - other_provider = kwargs.get('other_provider') - certificate = kwargs.get('certificate') - client_certificate = kwargs.get('client_certificate') - cipher = kwargs.get('cipher') - curve = kwargs.get('curve') - signature = kwargs.get('signature') + protocol = kwargs.get("protocol") + provider = kwargs.get("provider") + other_provider = kwargs.get("other_provider") + certificate = kwargs.get("certificate") + client_certificate = kwargs.get("client_certificate") + cipher = kwargs.get("cipher") + curve = kwargs.get("curve") + signature = kwargs.get("signature") providers = [provider_ for provider_ in [provider, other_provider] if provider_] # Always consider S2N @@ -76,9 +76,9 @@ def invalid_test_parameters(*args, **kwargs): # Older versions do not support RSA-PSS-PSS certificates if protocol and protocol < Protocols.TLS12: - if client_certificate and client_certificate.algorithm == 'RSAPSS': + if client_certificate and client_certificate.algorithm == "RSAPSS": return True - if certificate and certificate.algorithm == 'RSAPSS': + if certificate and certificate.algorithm == "RSAPSS": return True for provider_ in providers: diff --git a/tests/pcap/Cargo.toml b/tests/pcap/Cargo.toml index 04f87f0928b..f91a7d5f9a4 100644 --- a/tests/pcap/Cargo.toml +++ b/tests/pcap/Cargo.toml @@ -14,12 +14,12 @@ bytes = "1.7.1" hex = "0.4.3" reqwest = { version = "0.12.7", features = ["blocking"] } semver = "1.0.23" -rtshark = "2.9.0" +rtshark = "3.1.0" [dependencies] anyhow = "1.0.86" hex = "0.4.3" -rtshark = "2.9.0" +rtshark = "3.1.0" [dev-dependencies] # We want to test against the latest, local version of s2n diff --git a/tests/unit/s2n_cert_validation_callback_test.c b/tests/unit/s2n_cert_validation_callback_test.c index 01615f53500..be4392098f0 100644 --- a/tests/unit/s2n_cert_validation_callback_test.c +++ b/tests/unit/s2n_cert_validation_callback_test.c @@ -23,6 +23,7 @@ struct s2n_cert_validation_data { unsigned return_success : 1; int invoked_count; + struct s2n_cert_validation_info *info; }; static int s2n_test_cert_validation_callback(struct s2n_connection *conn, struct s2n_cert_validation_info *info, void *ctx) @@ -30,6 +31,8 @@ static int s2n_test_cert_validation_callback(struct s2n_connection *conn, struct struct s2n_cert_validation_data *data = (struct s2n_cert_validation_data *) ctx; data->invoked_count += 1; + /* Pass the `s2n_cert_validation_info` struct to application-defined `ctx` */ + data->info = info; int ret = S2N_FAILURE; if (data->return_success) { @@ -187,16 +190,6 @@ int main(int argc, char *argv[]) .data = { .call_accept_or_reject = true, .accept = false, .return_success = false }, .expected_error = S2N_ERR_CANCELLED }, - { - .data = { .call_accept_or_reject = false, .return_success = false }, - .expected_error = S2N_ERR_CANCELLED - }, - - /* Error if accept or reject wasn't called from the callback */ - { - .data = { .call_accept_or_reject = false, .return_success = true }, - .expected_error = S2N_ERR_INVALID_STATE - }, }; /* clang-format on */ @@ -444,6 +437,139 @@ int main(int argc, char *argv[]) EXPECT_EQUAL(data.invoked_count, 1); } + + /* For async cases, accept or reject API will be called outside of the validation callback. + * Iterate over both TLS 1.3 and 1.2 policies to ensure the stuffer reset logic works in all cases. + */ + struct s2n_cert_validation_data async_test_cases[] = { + { .call_accept_or_reject = false, .accept = true, .return_success = true }, + { .call_accept_or_reject = false, .accept = false, .return_success = true }, + }; + const char *versions[] = { "20240501", "20170210" }; + + /* Async callback is invoked on the client after receiving the server's certificate */ + for (int test_case_idx = 0; test_case_idx < s2n_array_len(async_test_cases); test_case_idx++) { + for (int version_idx = 0; version_idx < s2n_array_len(versions); version_idx++) { + DEFER_CLEANUP(struct s2n_config *config = s2n_config_new(), s2n_config_ptr_free); + EXPECT_NOT_NULL(config); + EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain_and_key)); + EXPECT_SUCCESS(s2n_config_set_verification_ca_location(config, S2N_DEFAULT_TEST_CERT_CHAIN, NULL)); + EXPECT_SUCCESS(s2n_config_set_cipher_preferences(config, versions[version_idx])); + + struct s2n_cert_validation_data data = async_test_cases[test_case_idx]; + EXPECT_SUCCESS(s2n_config_set_cert_validation_cb(config, s2n_test_cert_validation_callback_self_talk, &data)); + + DEFER_CLEANUP(struct s2n_connection *server_conn = s2n_connection_new(S2N_SERVER), s2n_connection_ptr_free); + EXPECT_NOT_NULL(server_conn); + EXPECT_SUCCESS(s2n_connection_set_config(server_conn, config)); + + DEFER_CLEANUP(struct s2n_connection *client_conn = s2n_connection_new(S2N_CLIENT), s2n_connection_ptr_free); + EXPECT_NOT_NULL(client_conn); + EXPECT_SUCCESS(s2n_connection_set_config(client_conn, config)); + EXPECT_SUCCESS(s2n_connection_set_blinding(client_conn, S2N_SELF_SERVICE_BLINDING)); + EXPECT_SUCCESS(s2n_set_server_name(client_conn, "localhost")); + + DEFER_CLEANUP(struct s2n_test_io_pair io_pair = { 0 }, s2n_io_pair_close); + EXPECT_SUCCESS(s2n_io_pair_init_non_blocking(&io_pair)); + EXPECT_SUCCESS(s2n_connection_set_io_pair(client_conn, &io_pair)); + EXPECT_SUCCESS(s2n_connection_set_io_pair(server_conn, &io_pair)); + + for (int i = 0; i < 3; i++) { + EXPECT_FAILURE_WITH_ERRNO(s2n_negotiate_test_server_and_client(server_conn, client_conn), + S2N_ERR_ASYNC_BLOCKED); + EXPECT_EQUAL(data.invoked_count, 1); + } + + /* Ensure that the server's certificate chain can be retrieved after `S2N_ERR_ASYNC_BLOCKED` */ + DEFER_CLEANUP(struct s2n_cert_chain_and_key *peer_cert_chain = s2n_cert_chain_and_key_new(), + s2n_cert_chain_and_key_ptr_free); + EXPECT_NOT_NULL(peer_cert_chain); + EXPECT_SUCCESS(s2n_connection_get_peer_cert_chain(client_conn, peer_cert_chain)); + /* Ensure the certificate chain is non-empty */ + uint32_t peer_cert_chain_len = 0; + EXPECT_SUCCESS(s2n_cert_chain_get_length(peer_cert_chain, &peer_cert_chain_len)); + EXPECT_TRUE(peer_cert_chain_len > 0); + + struct s2n_cert_validation_info *info = data.info; + EXPECT_NOT_NULL(info); + + if (async_test_cases[test_case_idx].accept) { + EXPECT_SUCCESS(s2n_cert_validation_accept(info)); + EXPECT_SUCCESS(s2n_negotiate_test_server_and_client(server_conn, client_conn)); + } else { + EXPECT_SUCCESS(s2n_cert_validation_reject(info)); + EXPECT_FAILURE_WITH_ERRNO(s2n_negotiate_test_server_and_client(server_conn, client_conn), + S2N_ERR_CERT_REJECTED); + } + + EXPECT_EQUAL(data.invoked_count, 1); + } + } + + /* Async callback is invoked on the server after receiving the client's certificate */ + for (int test_case_idx = 0; test_case_idx < s2n_array_len(async_test_cases); test_case_idx++) { + for (int version_idx = 0; version_idx < s2n_array_len(versions); version_idx++) { + DEFER_CLEANUP(struct s2n_config *server_config = s2n_config_new(), s2n_config_ptr_free); + EXPECT_NOT_NULL(server_config); + EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(server_config, chain_and_key)); + EXPECT_SUCCESS(s2n_config_set_verification_ca_location(server_config, S2N_DEFAULT_TEST_CERT_CHAIN, NULL)); + EXPECT_SUCCESS(s2n_config_set_cipher_preferences(server_config, versions[version_idx])); + EXPECT_SUCCESS(s2n_config_set_client_auth_type(server_config, S2N_CERT_AUTH_REQUIRED)); + + struct s2n_cert_validation_data data = async_test_cases[test_case_idx]; + EXPECT_SUCCESS(s2n_config_set_cert_validation_cb(server_config, + s2n_test_cert_validation_callback_self_talk_server, &data)); + + DEFER_CLEANUP(struct s2n_connection *server_conn = s2n_connection_new(S2N_SERVER), s2n_connection_ptr_free); + EXPECT_NOT_NULL(server_conn); + EXPECT_SUCCESS(s2n_connection_set_config(server_conn, server_config)); + EXPECT_SUCCESS(s2n_connection_set_blinding(server_conn, S2N_SELF_SERVICE_BLINDING)); + + DEFER_CLEANUP(struct s2n_config *client_config = s2n_config_new(), s2n_config_ptr_free); + EXPECT_NOT_NULL(client_config); + EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(client_config, chain_and_key)); + EXPECT_SUCCESS(s2n_config_set_verification_ca_location(client_config, S2N_DEFAULT_TEST_CERT_CHAIN, NULL)); + EXPECT_SUCCESS(s2n_config_set_cipher_preferences(client_config, versions[version_idx])); + EXPECT_SUCCESS(s2n_config_set_client_auth_type(client_config, S2N_CERT_AUTH_OPTIONAL)); + + DEFER_CLEANUP(struct s2n_connection *client_conn = s2n_connection_new(S2N_CLIENT), s2n_connection_ptr_free); + EXPECT_NOT_NULL(client_conn); + EXPECT_SUCCESS(s2n_connection_set_config(client_conn, client_config)); + EXPECT_SUCCESS(s2n_set_server_name(client_conn, "localhost")); + + DEFER_CLEANUP(struct s2n_test_io_pair io_pair = { 0 }, s2n_io_pair_close); + EXPECT_SUCCESS(s2n_io_pair_init_non_blocking(&io_pair)); + EXPECT_SUCCESS(s2n_connection_set_io_pair(client_conn, &io_pair)); + EXPECT_SUCCESS(s2n_connection_set_io_pair(server_conn, &io_pair)); + + for (int i = 0; i < 3; i++) { + EXPECT_FAILURE_WITH_ERRNO(s2n_negotiate_test_server_and_client(server_conn, client_conn), + S2N_ERR_ASYNC_BLOCKED); + EXPECT_EQUAL(data.invoked_count, 1); + } + + /* Ensure that the client's certificate chain can be retrieved after `S2N_ERR_ASYNC_BLOCKED` */ + uint8_t *der_cert_chain = 0; + uint32_t cert_chain_len = 0; + EXPECT_SUCCESS(s2n_connection_get_client_cert_chain(server_conn, &der_cert_chain, &cert_chain_len)); + /* Ensure the certificate chain is non-empty */ + EXPECT_TRUE(cert_chain_len > 0); + + struct s2n_cert_validation_info *info = data.info; + EXPECT_NOT_NULL(info); + + if (async_test_cases[test_case_idx].accept) { + EXPECT_SUCCESS(s2n_cert_validation_accept(info)); + EXPECT_SUCCESS(s2n_negotiate_test_server_and_client(server_conn, client_conn)); + } else { + EXPECT_SUCCESS(s2n_cert_validation_reject(info)); + EXPECT_FAILURE_WITH_ERRNO(s2n_negotiate_test_server_and_client(server_conn, client_conn), + S2N_ERR_CERT_REJECTED); + } + + EXPECT_EQUAL(data.invoked_count, 1); + } + } } END_TEST(); diff --git a/tests/unit/s2n_ecdsa_test.c b/tests/unit/s2n_ecdsa_test.c index cd21b9c3d54..614b0f62b66 100644 --- a/tests/unit/s2n_ecdsa_test.c +++ b/tests/unit/s2n_ecdsa_test.c @@ -55,17 +55,6 @@ int main(int argc, char **argv) char *cert_chain_pem = NULL; char *private_key_pem = NULL; - const int supported_hash_algorithms[] = { - S2N_HASH_NONE, - S2N_HASH_MD5, - S2N_HASH_SHA1, - S2N_HASH_SHA224, - S2N_HASH_SHA256, - S2N_HASH_SHA384, - S2N_HASH_SHA512, - S2N_HASH_MD5_SHA1 - }; - BEGIN_TEST(); EXPECT_SUCCESS(s2n_disable_tls13_in_test()); @@ -156,13 +145,25 @@ int main(int argc, char **argv) EXPECT_SUCCESS(s2n_hash_new(&hash_one)); EXPECT_SUCCESS(s2n_hash_new(&hash_two)); - for (size_t i = 0; i < s2n_array_len(supported_hash_algorithms); i++) { - int hash_alg = supported_hash_algorithms[i]; - - if (!s2n_hash_is_available(hash_alg) || hash_alg == S2N_HASH_NONE) { - /* Skip hash algorithms that are not available. */ + /* Determining all possible valid combinations of hash algorithms and + * signature algorithms is actually surprisingly complicated. + * + * For example: awslc-fips will fail for MD5+ECDSA. However, that is not + * a real problem because there is no valid signature scheme that uses both + * MD5 and ECDSA. + * + * To avoid enumerating all the exceptions, just use the actual supported + * signature scheme list as the source of truth. + */ + const struct s2n_signature_preferences *all_sig_schemes = + security_policy_test_all.signature_preferences; + + for (size_t i = 0; i < all_sig_schemes->count; i++) { + const struct s2n_signature_scheme *scheme = all_sig_schemes->signature_schemes[i]; + if (scheme->sig_alg != S2N_SIGNATURE_ECDSA) { continue; } + const s2n_hash_algorithm hash_alg = scheme->hash_alg; EXPECT_SUCCESS(s2n_hash_init(&hash_one, hash_alg)); EXPECT_SUCCESS(s2n_hash_init(&hash_two, hash_alg)); @@ -173,25 +174,8 @@ int main(int argc, char **argv) /* Reset signature size when we compute a new signature */ signature.size = maximum_signature_length; - /* Not all hash algorithms are supported for EVP ECDSA signing. - * See s2n_evp_signing_validate_hash_alg. - */ - bool hash_is_md5 = (hash_alg == S2N_HASH_MD5 || hash_alg == S2N_HASH_MD5_SHA1); - bool hash_is_supported = !(hash_is_md5 && s2n_is_in_fips_mode()); - - int sign_result = s2n_pkey_sign(&priv_key, S2N_SIGNATURE_ECDSA, &hash_one, &signature); - if (hash_is_supported) { - EXPECT_SUCCESS(sign_result); - } else { - EXPECT_FAILURE_WITH_ERRNO(sign_result, S2N_ERR_HASH_INVALID_ALGORITHM); - } - - int verify_result = s2n_pkey_verify(&pub_key, S2N_SIGNATURE_ECDSA, &hash_two, &signature); - if (hash_is_supported) { - EXPECT_SUCCESS(verify_result); - } else { - EXPECT_FAILURE_WITH_ERRNO(verify_result, S2N_ERR_HASH_INVALID_ALGORITHM); - } + EXPECT_SUCCESS(s2n_pkey_sign(&priv_key, S2N_SIGNATURE_ECDSA, &hash_one, &signature)); + EXPECT_SUCCESS(s2n_pkey_verify(&pub_key, S2N_SIGNATURE_ECDSA, &hash_two, &signature)); EXPECT_SUCCESS(s2n_hash_reset(&hash_one)); EXPECT_SUCCESS(s2n_hash_reset(&hash_two)); diff --git a/tests/unit/s2n_evp_signing_test.c b/tests/unit/s2n_evp_signing_test.c index 8869d7a00f6..2275cefd149 100644 --- a/tests/unit/s2n_evp_signing_test.c +++ b/tests/unit/s2n_evp_signing_test.c @@ -36,12 +36,6 @@ const uint8_t input_data[INPUT_DATA_SIZE] = "hello hash"; -static bool s2n_hash_alg_is_supported(s2n_signature_algorithm sig_alg, s2n_hash_algorithm hash_alg) -{ - return (hash_alg != S2N_HASH_NONE) && (hash_alg != S2N_HASH_MD5) - && (hash_alg != S2N_HASH_MD5_SHA1 || sig_alg == S2N_SIGNATURE_RSA); -} - static S2N_RESULT s2n_test_hash_init(struct s2n_hash_state *hash_state, s2n_hash_algorithm hash_alg) { RESULT_GUARD_POSIX(s2n_hash_init(hash_state, hash_alg)); @@ -69,74 +63,139 @@ static S2N_RESULT s2n_test_evp_sign(s2n_signature_algorithm sig_alg, s2n_hash_al } static S2N_RESULT s2n_test_evp_verify(s2n_signature_algorithm sig_alg, s2n_hash_algorithm hash_alg, - struct s2n_pkey *public_key, - struct s2n_blob *evp_signature, struct s2n_blob *expected_signature) + struct s2n_pkey *public_key, struct s2n_blob *expected_signature) { DEFER_CLEANUP(struct s2n_hash_state hash_state = { 0 }, s2n_hash_free); RESULT_GUARD_POSIX(s2n_hash_new(&hash_state)); /* Verify that the EVP methods can verify their own signature */ RESULT_GUARD(s2n_test_hash_init(&hash_state, hash_alg)); - RESULT_GUARD_POSIX(s2n_evp_verify(public_key, sig_alg, &hash_state, evp_signature)); + RESULT_GUARD_POSIX(s2n_evp_verify(public_key, sig_alg, &hash_state, expected_signature)); /* Verify that using the pkey directly can verify own signature */ RESULT_GUARD(s2n_test_hash_init(&hash_state, hash_alg)); - RESULT_GUARD_POSIX(s2n_pkey_verify(public_key, sig_alg, &hash_state, evp_signature)); - - /* Verify that the EVP methods can verify the known good signature */ - RESULT_GUARD(s2n_test_hash_init(&hash_state, hash_alg)); - RESULT_GUARD_POSIX(s2n_evp_verify(public_key, sig_alg, &hash_state, expected_signature)); + RESULT_GUARD_POSIX(s2n_pkey_verify(public_key, sig_alg, &hash_state, expected_signature)); return S2N_RESULT_OK; } +static bool s2n_test_legacy_signing_supported() +{ + return !s2n_libcrypto_is_openssl_fips(); +} + int main(int argc, char **argv) { BEGIN_TEST(); - /* Sanity check that we're enabling evp signing properly. - * awslc-fips is known to require evp signing. - */ - if (s2n_is_in_fips_mode() && s2n_libcrypto_is_awslc()) { - EXPECT_TRUE(s2n_evp_signing_supported()); - } - - if (!s2n_evp_signing_supported()) { - END_TEST(); - } - - DEFER_CLEANUP(struct s2n_hash_state hash_state = { 0 }, s2n_hash_free); - EXPECT_SUCCESS(s2n_hash_new(&hash_state)); - struct s2n_cert_chain_and_key *rsa_cert_chain = NULL; EXPECT_SUCCESS(s2n_test_cert_chain_and_key_new(&rsa_cert_chain, S2N_RSA_2048_PKCS1_CERT_CHAIN, S2N_RSA_2048_PKCS1_KEY)); - /* Test that unsupported hash algs are treated as invalid. - * Later tests will ignore unsupported algs, so ensure they are actually invalid. */ - { - /* This pkey should never actually be needed -- any pkey will do */ - struct s2n_pkey *pkey = rsa_cert_chain->private_key; - - for (s2n_signature_algorithm sig_alg = 0; sig_alg <= UINT8_MAX; sig_alg++) { - for (s2n_hash_algorithm hash_alg = 0; hash_alg < S2N_HASH_ALGS_COUNT; hash_alg++) { - if (s2n_hash_alg_is_supported(sig_alg, hash_alg)) { - continue; - } - - s2n_stack_blob(evp_signature, OUTPUT_DATA_SIZE, OUTPUT_DATA_SIZE); - EXPECT_ERROR_WITH_ERRNO(s2n_test_evp_sign(sig_alg, hash_alg, pkey, &evp_signature), - S2N_ERR_HASH_INVALID_ALGORITHM); - EXPECT_ERROR_WITH_ERRNO(s2n_test_evp_verify(sig_alg, hash_alg, pkey, &evp_signature, &evp_signature), - S2N_ERR_HASH_INVALID_ALGORITHM); - } - } - }; + /* Determining all possible valid combinations of hash algorithms and + * signature algorithms is actually surprisingly complicated. + * + * For example: awslc-fips will fail for MD5+ECDSA. However, that is not + * a real problem because there is no valid signature scheme that uses both + * MD5 and ECDSA. + * + * To avoid enumerating all the exceptions, just use the actual supported + * signature scheme list as the source of truth. + */ + const struct s2n_signature_preferences *all_sig_schemes = + security_policy_test_all.signature_preferences; /* EVP signing must match RSA signing */ { s2n_signature_algorithm sig_alg = S2N_SIGNATURE_RSA; + const char *valid_signatures[S2N_HASH_ALGS_COUNT] = { + [S2N_HASH_MD5_SHA1] = + "59 5b 8b 75 95 16 21 ae a1 7c 63 84 8f 9e 86 fd f9 79 e1 " + "d4 8b d7 01 91 37 43 86 75 d0 20 ce 64 31 4e 31 d1 dc dd 7e " + "a4 f3 86 36 8f d8 36 ef 27 7c 6b 09 8c e2 8b 35 79 c4 c0 a9 " + "f5 c1 ae 22 fd a8 23 83 af 52 4a 61 fe f5 14 c7 7c 7d fe af " + "11 bf 6b 6a 4a 3a e0 cb 63 24 13 2d 8d c3 6b b3 51 bb 1b f4 " + "d3 1b b2 67 bd 4c ad 7e b5 eb eb 52 fd 42 78 41 ef 40 c5 ba " + "42 1a 72 5a 45 3a 92 fd ff 43 6a 38 dd b9 de 13 32 34 4d 58 " + "77 53 56 96 bd 93 98 87 de 3d 6e 53 2e ff ea 01 71 97 dd a6 " + "62 02 9d 9c 58 45 af ec 72 ba 6f ff 60 75 25 6c 50 0e 1c bc " + "c6 c9 73 33 d4 b4 05 f6 2d 1d cd 95 29 95 a0 4f 8a cc 18 e6 " + "4d c9 53 e7 bd 60 2d 54 f0 c7 1f 25 b6 1b c6 b7 8c 4d 72 c4 " + "bb 1f 99 84 97 bd ac f0 80 a3 e3 88 67 11 a2 00 c6 2c 62 76 " + "2e 2f 37 86 d8 90 17 2d 2c d0 34 6e ca 4f 9f d5 59", + [S2N_HASH_SHA1] = + "3c e9 9f a7 7a 37 c3 72 ab 08 03 c9 aa 3a 7b 41 12 a1 07 10 " + "35 9a 57 f6 60 c5 79 a8 2f ad cc 62 b6 13 f3 fe a3 1e 94 b1 " + "c9 11 d1 50 24 15 66 7c 02 9d 17 f2 99 84 3e 61 bd 56 9e 09 " + "6a e1 18 fa cd 78 8a 00 d9 9a 28 95 1f ee e1 01 89 6f c4 2b " + "44 06 0c 2f 0e f5 ba dc 55 3c 7b d4 10 20 74 1b 1d f6 e0 ba " + "29 7b c0 7c 9f ab 1a 79 aa 58 d6 01 2e 9c b0 5b 97 4e c7 45 " + "76 b6 45 dc 36 7d da 8b 5e 8a bf 2c 51 d0 23 1d f4 a9 12 11 " + "0e e3 e1 0d 2e 5f 92 19 10 48 54 18 d1 4e 61 ec e6 47 60 13 " + "65 eb 84 cf d8 b9 4b 99 37 99 ef 83 58 6c e7 fd c0 fc a2 35 " + "99 0f 26 48 24 5e 0a 21 42 e1 77 a7 50 a2 ec ae d8 2f f1 18 " + "44 31 b4 5a a7 c7 93 1e 60 e7 2e 8b 9a 22 4a ee d4 0e 8d eb " + "da 36 01 ae e3 1d 52 3f 33 fb 84 b8 f8 a4 1b 75 c5 ce 51 9f " + "d8 2b 56 e0 32 98 be c4 f3 24 f2 7a fa c6 72 21", + [S2N_HASH_SHA224] = + "06 f6 e6 82 f2 79 98 a6 9a e0 5f 20 ad c7 eb 9f 41 0d 18 10 " + "86 9e d1 7f e4 b1 7d 39 e0 9f 05 4e 7a ce 7c c1 ba 29 c4 f4 " + "f0 e8 89 44 91 3f 65 8e 57 84 27 8e 88 9f 14 ee 04 fd 73 47 " + "40 03 fe 53 a6 c7 cd e0 db 27 9f 12 36 47 fc e7 7c 3f b9 f2 " + "f7 55 15 93 02 f9 5d a0 10 c7 13 cb d9 98 5c 22 d0 63 c7 5f " + "c0 8b 1a ac ec 2d 5d 2c 3e db 41 34 31 f3 0b c1 29 bc 83 a4 " + "27 37 61 17 5c 15 01 43 68 8f 3d 6e 23 76 f4 f1 a4 44 ce 5e " + "fc 61 88 85 5e d9 0e 2f 80 7e 56 ac 62 aa 2a a9 aa 46 8f da " + "ee f4 fe 1a 28 e8 78 25 fb b5 83 22 c9 d0 dd 28 f7 93 02 e5 " + "93 31 db 0f 9b 17 ae e2 a7 72 56 c8 53 ee 3a 80 c2 7c 15 3c " + "59 66 d5 c4 e3 99 9f cb f2 16 67 ac 9a 3a 03 b8 17 ce 77 12 " + "28 8a fd 21 ca 4c bf 06 b1 73 8e 6d 51 1c a3 d5 ec 82 66 ef " + "62 f3 9f 4a 22 c4 22 ed 13 a2 6d b8 96 5c b8 73", + [S2N_HASH_SHA256] = + "05 25 f7 42 ee 12 e9 ca 45 05 7c 96 32 03 a5 50 04 46 06 a2 " + "a5 57 d3 69 00 4a bc c2 21 a3 e9 2c 11 56 97 16 92 54 ba f3 " + "3b a6 67 ae 7f e6 89 74 be e7 16 43 3c 66 a3 51 93 96 c6 13 " + "af 8a 46 fe a9 f5 00 d7 de d5 02 76 2a f5 80 52 1f 6f 4f d6 " + "b9 a7 ab 62 66 57 51 5c 77 6e 46 03 e2 ef c6 dc f2 da f2 fc " + "8c 2a 80 ec 3b 9a ac 64 2e 34 49 cd ac 3f bc a4 82 84 6e 6d " + "49 cd 94 1b e3 ad be 96 15 27 89 a5 8f f7 35 16 7f e5 71 fe " + "b7 4a 45 4d ca 44 c7 bc ed 91 9a c4 0f bf 75 53 22 51 df 84 " + "7f 7e 71 b8 ef 4f 1f e5 cb 19 a3 87 4f 32 8d e7 06 a9 3f 81 " + "b9 ff 3c 14 07 9a b6 cb fa 02 d8 51 16 9f 4c 2d 03 ac d5 c1 " + "7e 73 5a a4 c6 b9 d1 7d a2 1a 17 9c c4 c1 7c a2 77 18 e5 2b " + "41 9e ab e6 e9 46 03 6f 44 95 11 8f 5e 51 d6 0a f4 e6 04 30 " + "89 18 9f 16 25 91 1d 74 64 c4 23 5d b5 fc f9 47", + [S2N_HASH_SHA384] = + "36 15 7c 11 a3 02 67 6d 40 8d 0f 7a c5 7e 2d 41 52 e6 16 " + "f2 4a 6b 60 a8 a7 0c 91 dc 5d a5 ed b4 98 98 24 be 05 d6 49 " + "aa 05 4f ba 54 5b 8d 21 e2 1f c5 1c 7e 99 52 f7 c9 19 fe a7 " + "e0 62 61 57 67 05 fa 15 1b a3 45 72 01 e0 0b e4 1f 69 1a 05 05" + "69 af f3 8a 4b 30 37 76 24 25 fd 55 c8 87 7b bc b7 b5 37 21 " + "dc f2 15 76 7e 68 11 ae 38 ce 2d e4 75 36 4a e1 f4 55 13 90 " + "70 8b db 1d 94 83 3b 88 83 48 bb 5e 0c 2e 23 f9 00 ed 59 a4 " + "c7 54 9e a0 0d 7d d3 72 7c e2 26 5c f8 34 34 eb 6a 85 f2 a3 " + "9a 47 8d c0 20 60 49 05 bb b7 6b 8b 52 f2 bc 35 11 da 97 f3 " + "4c 2d 93 29 ae 63 96 16 38 bc 8b a8 ba e7 d1 74 08 14 db f3 " + "51 a0 6f 87 4d 20 02 a7 db 0e 73 6a c1 55 75 26 61 34 5c 03 " + "f8 0c c0 a3 b6 ca 76 a7 68 61 84 53 58 f9 cf 11 67 29 04 8f " + "7b 24 a5 91 4b c2 b2 b2 21 81 f6 48 33 18 0c 7a", + [S2N_HASH_SHA512] = + "53 73 1c 33 80 1a 25 76 4c 0f 91 d6 7e 41 58 03 c9 71 56 " + "ef 54 06 19 05 37 99 57 10 63 91 a3 5b 83 85 dd 65 09 42 af " + "b7 51 45 83 e2 b9 ca 23 4c 92 eb 85 35 d0 23 c1 02 62 c5 46 " + "24 95 75 be 3c 1d e4 6c 45 87 a9 7a f3 c7 32 81 09 22 b2 c3 " + "43 d0 02 22 04 93 08 89 de 07 0a bb d2 68 25 06 6a 95 13 07 2d " + "74 4a 2c 37 a8 0d 74 e3 b5 b8 e2 8e ad 4d 7a 94 11 c7 4b 90 " + "0c 66 ec 4b 21 cd 2b b7 ae 68 32 01 0b 4c 93 6a 8f 7f a4 e6 " + "d1 7b a4 48 ef 6a 5e 29 c9 2b 20 51 6b 39 22 17 15 40 ef 7e " + "49 87 75 77 92 ed 4e af ae 92 b0 e5 10 47 ea b1 e9 8d 05 23 " + "dc 99 f1 b8 94 22 96 f4 02 6e 9a 35 57 8e 85 08 ee 03 7c 5e " + "df 2c 3f 49 22 bd 04 50 ff e9 48 eb 96 7a ee 80 51 e2 ab 94 " + "6d c8 73 73 3b 5e 65 f7 c7 49 de a8 3b 91 e1 5f 25 63 13 e0 " + "e9 51 79 99 54 0d 1a 1f 91 d3 41 e1 a3 b3 05 05", + }; + DEFER_CLEANUP(struct s2n_pkey public_key_parsed = { 0 }, s2n_pkey_free); EXPECT_OK(s2n_setup_public_key(&public_key_parsed, rsa_cert_chain)); @@ -145,26 +204,41 @@ int main(int argc, char **argv) EXPECT_PKEY_USES_EVP_SIGNING(private_key); EXPECT_PKEY_USES_EVP_SIGNING(public_key); - for (s2n_hash_algorithm hash_alg = 0; hash_alg < S2N_HASH_ALGS_COUNT; hash_alg++) { - if (!s2n_hash_alg_is_supported(sig_alg, hash_alg)) { + for (size_t i = 0; i < all_sig_schemes->count; i++) { + const struct s2n_signature_scheme *scheme = all_sig_schemes->signature_schemes[i]; + if (scheme->sig_alg != sig_alg) { continue; } + const s2n_hash_algorithm hash_alg = scheme->hash_alg; - /* Calculate the signature using EVP methods */ + /* Test that EVP can sign and verify */ s2n_stack_blob(evp_signature, OUTPUT_DATA_SIZE, OUTPUT_DATA_SIZE); EXPECT_OK(s2n_test_evp_sign(sig_alg, hash_alg, private_key, &evp_signature)); - - /* Calculate the signature using RSA methods */ - s2n_stack_blob(rsa_signature, OUTPUT_DATA_SIZE, OUTPUT_DATA_SIZE); - EXPECT_OK(s2n_test_hash_init(&hash_state, hash_alg)); - EXPECT_SUCCESS(s2n_rsa_pkcs1v15_sign(private_key, &hash_state, &rsa_signature)); - - /* Verify that the EVP methods can verify both signatures */ - EXPECT_OK(s2n_test_evp_verify(sig_alg, hash_alg, public_key, &evp_signature, &rsa_signature)); - - /* Verify that the RSA methods can verify the EVP signature */ - EXPECT_OK(s2n_test_hash_init(&hash_state, hash_alg)); - EXPECT_SUCCESS(s2n_rsa_pkcs1v15_verify(public_key, &hash_state, &evp_signature)); + EXPECT_OK(s2n_test_evp_verify(sig_alg, hash_alg, public_key, &evp_signature)); + + /* Test known value matches sign: RSA PKCS1 is deterministic */ + S2N_BLOB_FROM_HEX(known_value, valid_signatures[hash_alg]); + EXPECT_EQUAL(known_value.size, evp_signature.size); + EXPECT_BYTEARRAY_EQUAL(known_value.data, evp_signature.data, evp_signature.size); + /* Test verifying known value */ + EXPECT_OK(s2n_test_evp_verify(sig_alg, hash_alg, public_key, &known_value)); + + /* Verify using legacy methods */ + if (s2n_test_legacy_signing_supported()) { + DEFER_CLEANUP(struct s2n_hash_state hash_state = { 0 }, s2n_hash_free); + EXPECT_SUCCESS(s2n_hash_new(&hash_state)); + + s2n_stack_blob(rsa_signature, OUTPUT_DATA_SIZE, OUTPUT_DATA_SIZE); + EXPECT_OK(s2n_test_hash_init(&hash_state, hash_alg)); + EXPECT_SUCCESS(s2n_rsa_pkcs1v15_sign(private_key, &hash_state, &rsa_signature)); + + /* EVP verifies legacy signature */ + EXPECT_OK(s2n_test_evp_verify(sig_alg, hash_alg, public_key, &rsa_signature)); + + /* legacy verifies EVP signature */ + EXPECT_OK(s2n_test_hash_init(&hash_state, hash_alg)); + EXPECT_SUCCESS(s2n_rsa_pkcs1v15_verify(public_key, &hash_state, &evp_signature)); + } } }; @@ -172,6 +246,44 @@ int main(int argc, char **argv) { s2n_signature_algorithm sig_alg = S2N_SIGNATURE_ECDSA; + const char *valid_signatures[S2N_HASH_ALGS_COUNT] = { + [S2N_HASH_SHA1] = + "30 65 02 30 2f d6 a0 48 6b 17 9a e9 d6 c3 ad 16 db a4 04 " + "27 d3 c8 84 63 67 2b 07 b3 df 98 d7 f2 88 58 d1 9a 45 0d e7 " + "f7 f6 c8 ef 83 76 a2 23 24 60 3e a3 81 02 31 00 a4 9f f2 d2 " + "34 d8 96 40 02 73 3b 08 91 17 76 67 6f ce d4 00 83 87 1e 4e " + "9e 88 a9 3a 1b f0 06 f6 39 f8 ac a5 0d da 27 b8 89 bd be a6 " + "58 ce a5 b6", + [S2N_HASH_SHA224] = + "30 65 02 31 00 ac 40 4c e0 c6 96 d2 00 c3 a0 d2 d0 21 6a 87 " + "75 60 8a 95 47 e5 81 d3 9e d0 ba 1a be 57 49 15 1a df 8f c7 " + "be 21 84 49 4b a6 1c 22 cb 89 e3 57 14 02 30 52 c4 ea bf c1 " + "05 d9 a8 76 73 70 8c 2a d2 de 68 df 73 80 5d 89 13 ff c4 b9 " + "4e eb fc fc cf 4c 2e 9c 90 d8 85 ad 6c bb 13 86 63 04 ff 58 " + "d0 1a 34", + [S2N_HASH_SHA256] = + "30 66 02 31 00 eb 65 34 a1 7e de 30 11 fd a7 8f ba 41 5f " + "3b 72 88 23 ae fa 41 14 05 3c ee ef d7 2c fa 4f 51 0d 66 63 " + "4f b2 a4 34 6c 1b 28 69 96 eb b5 5f 13 1b 02 31 00 88 7b ed " + "90 f6 ab d7 4b b8 60 ef 60 50 19 2e 65 f8 e9 20 a8 23 10 ac " + "45 81 37 fb 8b 0c f2 10 d1 18 d1 46 62 15 06 06 8c bb a7 6b " + "e5 29 d2 26 d4", + [S2N_HASH_SHA384] = + "30 64 02 30 76 f2 dc 15 27 47 b5 d2 12 6e 97 ca 48 27 89 " + "13 f4 ea 34 1b 6c cd e7 ef 8a 56 15 0a 87 7d 55 d7 74 08 61 " + "78 04 1c 27 6d 55 81 32 90 9d 31 8f 35 02 30 46 c6 88 8a 2f " + "b1 d9 a1 db cd 52 d3 fc c2 e4 cd 62 ec 42 28 e5 e3 58 9c b0 " + "02 cd e5 60 39 53 7c 86 e6 17 ad 03 16 50 75 cc a1 22 61 04 " + "a0 30 19", + [S2N_HASH_SHA512] = + "30 65 02 30 7c 40 b3 ba a7 4c 0b 81 02 97 0c ff 3e 66 53 69 " + "86 83 e0 83 a0 14 f8 77 d1 1b 61 32 3e a2 c7 04 d3 cd b2 8c " + "92 b5 3c 01 a9 21 c3 8b 8d e2 e3 f6 02 31 00 c0 ea c3 b3 65 " + "ed ed fb cc 94 bb e7 db 44 93 e4 59 88 f2 d0 2c 8b 1e a7 70 " + "fe cf 12 dd 84 3d 70 79 05 8c 53 de a6 94 e0 e6 fa ef 35 75 " + "d8 11 11", + }; + struct s2n_cert_chain_and_key *ecdsa_cert_chain = NULL; EXPECT_SUCCESS(s2n_test_cert_chain_and_key_new(&ecdsa_cert_chain, S2N_ECDSA_P384_PKCS1_CERT_CHAIN, S2N_ECDSA_P384_PKCS1_KEY)); @@ -183,26 +295,38 @@ int main(int argc, char **argv) EXPECT_PKEY_USES_EVP_SIGNING(private_key); EXPECT_PKEY_USES_EVP_SIGNING(public_key); - for (s2n_hash_algorithm hash_alg = 0; hash_alg < S2N_HASH_ALGS_COUNT; hash_alg++) { - if (!s2n_hash_alg_is_supported(sig_alg, hash_alg)) { + for (size_t i = 0; i < all_sig_schemes->count; i++) { + const struct s2n_signature_scheme *scheme = all_sig_schemes->signature_schemes[i]; + if (scheme->sig_alg != sig_alg) { continue; } + const s2n_hash_algorithm hash_alg = scheme->hash_alg; - /* Calculate the signature using EVP methods */ + /* Test that EVP can sign and verify */ s2n_stack_blob(evp_signature, OUTPUT_DATA_SIZE, OUTPUT_DATA_SIZE); EXPECT_OK(s2n_test_evp_sign(sig_alg, hash_alg, private_key, &evp_signature)); + EXPECT_OK(s2n_test_evp_verify(sig_alg, hash_alg, public_key, &evp_signature)); + + /* Test verifying known value */ + S2N_BLOB_FROM_HEX(known_value, valid_signatures[hash_alg]); + EXPECT_OK(s2n_test_evp_verify(sig_alg, hash_alg, public_key, &known_value)); + + /* Verify using legacy methods */ + if (s2n_test_legacy_signing_supported()) { + DEFER_CLEANUP(struct s2n_hash_state hash_state = { 0 }, s2n_hash_free); + EXPECT_SUCCESS(s2n_hash_new(&hash_state)); - /* Calculate the signature using ECDSA methods */ - s2n_stack_blob(ecdsa_signature, OUTPUT_DATA_SIZE, OUTPUT_DATA_SIZE); - EXPECT_OK(s2n_test_hash_init(&hash_state, hash_alg)); - EXPECT_SUCCESS(s2n_ecdsa_sign(private_key, sig_alg, &hash_state, &ecdsa_signature)); + s2n_stack_blob(ecdsa_signature, OUTPUT_DATA_SIZE, OUTPUT_DATA_SIZE); + EXPECT_OK(s2n_test_hash_init(&hash_state, hash_alg)); + EXPECT_SUCCESS(s2n_ecdsa_sign(private_key, sig_alg, &hash_state, &ecdsa_signature)); - /* Verify that the EVP methods can verify both signatures */ - EXPECT_OK(s2n_test_evp_verify(sig_alg, hash_alg, public_key, &evp_signature, &ecdsa_signature)); + /* EVP verifies legacy signature */ + EXPECT_OK(s2n_test_evp_verify(sig_alg, hash_alg, public_key, &ecdsa_signature)); - /* Verify that the ECDSA methods can verify the EVP signature */ - EXPECT_OK(s2n_test_hash_init(&hash_state, hash_alg)); - EXPECT_SUCCESS(s2n_ecdsa_verify(public_key, sig_alg, &hash_state, &evp_signature)); + /* legacy verifies EVP signature */ + EXPECT_OK(s2n_test_hash_init(&hash_state, hash_alg)); + EXPECT_SUCCESS(s2n_ecdsa_verify(public_key, sig_alg, &hash_state, &evp_signature)); + } } EXPECT_SUCCESS(s2n_cert_chain_and_key_free(ecdsa_cert_chain)); @@ -212,6 +336,51 @@ int main(int argc, char **argv) if (s2n_is_rsa_pss_signing_supported()) { s2n_signature_algorithm sig_alg = S2N_SIGNATURE_RSA_PSS_RSAE; + const char *valid_signatures[S2N_HASH_ALGS_COUNT] = { + [S2N_HASH_SHA256] = + "37 bd 68 e1 57 96 67 a4 c9 b8 cf 02 e6 f4 96 ab 65 9c bd " + "e2 33 2c f4 9b 66 03 f3 a6 2e ba 09 30 3b d0 d4 cf 5f 03 43 " + "50 56 55 b6 6c f2 f2 c4 e3 9a ea 9c 0d e4 8c 37 10 fe b9 1f " + "95 3f c8 fb 19 1d f4 bc 85 56 de 4e 1f 3f ff 21 f3 84 67 99 " + "4d 6f 21 74 6e f6 7d e0 40 3e e2 2a ad 76 c0 99 2d 22 35 2e " + "cc 18 c3 5a ef 39 8f 0e 86 17 55 5b fa a8 92 28 e4 16 28 0d " + "3b 84 2b 73 34 d3 97 b8 b0 ce 49 00 8d dd 36 4c 28 52 12 e4 " + "43 00 9b 42 f9 75 e3 79 65 ef 8b 42 d4 0d 22 78 58 76 b4 23 " + "4a c1 a2 8f 00 cb fd 82 71 71 f1 69 b2 1f c2 17 8b b0 06 06 " + "4b 19 a4 46 d5 54 88 6b 2d ce 69 79 cf 2f 81 59 ac d2 9a b7 " + "6b 7b 20 0e a2 9f 39 6d 8a dd 75 5a ef 5a 8f 2a 1c ac 0c 60 " + "d5 20 47 39 9d 79 83 cf 37 19 f5 56 62 02 09 ab 72 9c 0f 1e " + "ca 77 e6 c2 38 a4 b8 34 96 0f 2e bd e0 31 71 9d c5", + [S2N_HASH_SHA384] = + "a3 40 0b e9 8f 93 77 50 d5 6d f2 34 7d 92 cf e2 e8 a1 6c " + "36 4d a1 70 92 de 4f 3e 2a 6f 25 e6 ae 47 3c ec f1 d2 10 20 " + "d2 e5 78 43 40 75 b9 2c 7f 0c 2e 95 26 e0 9d b4 e5 c8 d4 d4 " + "c0 b2 a9 a0 4b 83 a5 45 b2 f3 62 aa bb 17 b5 b1 ac c1 19 db " + "22 a0 49 86 3c 77 ae 13 5f eb b9 f2 2f 4e 57 4e 0f 1d 2a d9 " + "d3 d0 39 ac 61 fe f4 b9 85 20 ed 4c ff 34 f1 67 cd 21 60 a1 " + "fc 9c c2 b0 ec d2 43 38 7b 06 aa d9 e3 81 a3 73 88 6e c0 72 " + "e3 a4 6e 41 79 c2 b0 54 5a 42 fb c7 00 1e 4c a8 3e a3 41 17 " + "a6 67 b3 e0 dd 2f f1 2d dc 42 46 c7 74 47 15 7a 9b ad b6 b0 " + "cf d6 1e b3 14 4a b6 2b ab ad 9e db 86 6c 6f 37 c7 62 59 52 " + "bc 4f 2f 30 a3 41 17 c6 85 64 db d7 06 31 4f dc 7f 33 3a 3a " + "3e 4e 23 37 89 53 8d f1 fe 46 d6 cc 80 f4 ed c8 87 24 60 a7 " + "a5 92 77 67 3c 0b f7 fa 56 e1 ad f7 c5 82 9f 83 25", + [S2N_HASH_SHA512] = + "95 63 f0 49 3e 93 f7 8c 76 f0 bf 0a 87 4d 2a 8b f7 45 b1 " + "c1 41 a4 d9 5f f1 43 cb 10 bc af 55 44 7d 61 78 75 f9 6a 98 " + "10 ef 3c ae f9 e0 f3 ce 5c 51 79 70 3e a9 cd 86 fc c8 a2 73 " + "21 60 f4 37 73 20 b7 a7 24 e3 ec 49 d9 e0 bd 20 7f d0 36 3c " + "dd 1f 36 a7 56 ee bf c9 c8 16 17 ef 07 48 ad b2 f1 dd 8d 65 " + "19 ec c4 b0 4d 94 80 9c 2e cc a6 a5 36 23 ed 1f 69 29 0e d9 " + "1b 72 ec 73 9d 5d 9b ec a5 c7 ec 24 86 ca 5f bc 70 92 b1 c3 " + "00 2d 15 4b 74 bb aa f9 c9 ca 60 77 2f 3a 59 b6 89 44 32 5c " + "8d bd 02 ed a1 b9 80 a7 17 bb b2 cc 89 a2 60 74 f0 20 d7 4d " + "a9 92 33 90 2c 7c ab ec f6 a3 38 22 32 e5 83 b6 09 14 b5 b4 " + "3b 23 25 92 33 16 5e 40 8b b2 97 89 e9 82 d6 10 0b 2c b7 f0 " + "81 81 c4 00 b3 38 84 bc 39 00 e2 6d 38 f0 e7 1b 66 ad 62 06 " + "1b 76 62 18 3c 2a d9 b6 a8 fd af b4 1f a4 92 e9 24", + }; + DEFER_CLEANUP(struct s2n_pkey public_key_parsed = { 0 }, s2n_pkey_free); EXPECT_OK(s2n_setup_public_key(&public_key_parsed, rsa_cert_chain)); @@ -220,26 +389,38 @@ int main(int argc, char **argv) EXPECT_PKEY_USES_EVP_SIGNING(private_key); EXPECT_PKEY_USES_EVP_SIGNING(public_key); - for (s2n_hash_algorithm hash_alg = 0; hash_alg < S2N_HASH_ALGS_COUNT; hash_alg++) { - if (!s2n_hash_alg_is_supported(sig_alg, hash_alg)) { + for (size_t i = 0; i < all_sig_schemes->count; i++) { + const struct s2n_signature_scheme *scheme = all_sig_schemes->signature_schemes[i]; + if (scheme->sig_alg != sig_alg) { continue; } + const s2n_hash_algorithm hash_alg = scheme->hash_alg; - /* Calculate the signature using EVP methods */ + /* Test that EVP can sign and verify */ s2n_stack_blob(evp_signature, OUTPUT_DATA_SIZE, OUTPUT_DATA_SIZE); EXPECT_OK(s2n_test_evp_sign(sig_alg, hash_alg, private_key, &evp_signature)); + EXPECT_OK(s2n_test_evp_verify(sig_alg, hash_alg, public_key, &evp_signature)); + + /* Test verifying known value */ + S2N_BLOB_FROM_HEX(known_value, valid_signatures[hash_alg]); + EXPECT_OK(s2n_test_evp_verify(sig_alg, hash_alg, public_key, &known_value)); - /* Calculate the signature using RSA-PSS methods */ - s2n_stack_blob(rsa_pss_signature, OUTPUT_DATA_SIZE, OUTPUT_DATA_SIZE); - EXPECT_OK(s2n_test_hash_init(&hash_state, hash_alg)); - EXPECT_SUCCESS(s2n_rsa_pss_sign(private_key, &hash_state, &rsa_pss_signature)); + /* Verify using legacy methods */ + if (s2n_test_legacy_signing_supported()) { + DEFER_CLEANUP(struct s2n_hash_state hash_state = { 0 }, s2n_hash_free); + EXPECT_SUCCESS(s2n_hash_new(&hash_state)); - /* Verify that the EVP methods can verify both signatures */ - EXPECT_OK(s2n_test_evp_verify(sig_alg, hash_alg, public_key, &evp_signature, &rsa_pss_signature)); + s2n_stack_blob(rsa_pss_signature, OUTPUT_DATA_SIZE, OUTPUT_DATA_SIZE); + EXPECT_OK(s2n_test_hash_init(&hash_state, hash_alg)); + EXPECT_SUCCESS(s2n_rsa_pss_sign(private_key, &hash_state, &rsa_pss_signature)); - /* Verify that the RSA-PSS methods can verify the EVP signature */ - EXPECT_OK(s2n_test_hash_init(&hash_state, hash_alg)); - EXPECT_SUCCESS(s2n_rsa_pss_verify(public_key, &hash_state, &evp_signature)); + /* EVP verifies legacy signature */ + EXPECT_OK(s2n_test_evp_verify(sig_alg, hash_alg, public_key, &rsa_pss_signature)); + + /* legacy verifies EVP signature */ + EXPECT_OK(s2n_test_hash_init(&hash_state, hash_alg)); + EXPECT_SUCCESS(s2n_rsa_pss_verify(public_key, &hash_state, &evp_signature)); + } } } @@ -247,6 +428,51 @@ int main(int argc, char **argv) if (s2n_is_rsa_pss_certs_supported()) { s2n_signature_algorithm sig_alg = S2N_SIGNATURE_RSA_PSS_PSS; + const char *valid_signatures[S2N_HASH_ALGS_COUNT] = { + [S2N_HASH_SHA256] = + "66 0c 25 38 fd a1 bc b8 ca 48 3f 3d b9 3f 55 49 f6 3b 8c " + "62 95 60 74 bd 5d 53 bb 57 64 3a 63 63 04 04 fb e4 cf 15 82 " + "11 13 7c e0 ab 66 b2 c9 44 67 db b3 f2 55 24 32 31 29 d9 f7 " + "d4 be 53 02 75 bd e1 3d 27 6d 45 74 65 9e 20 27 96 ba 09 32 " + "81 8c 0e bb 7f 4b 7f e4 0a 95 22 68 a8 48 8a 8d 32 13 2e c0 " + "12 74 88 0e 48 74 99 c4 7b 6a 0e 62 0a c6 cb 04 87 f2 9b dc 9e " + "d7 e5 28 34 9a 75 bc 55 fa c4 71 20 17 4d 11 31 00 f5 cd 5e " + "13 65 74 b5 e8 5a a2 16 d5 22 84 3c 3f f0 96 2c b4 32 bc 9a " + "9a d0 02 4e e1 ac f3 ad 6b 9e 4f 99 90 19 6a cf 46 9f 04 5c " + "ba 0c f4 4e 06 2d 67 29 f5 88 63 c9 2f 3a 69 4c 36 8e 2c 64 " + "1d e6 b4 97 cb fc e2 c7 ae 6e c7 57 74 c6 ad a8 79 15 2f 5a " + "9d 18 4a 64 e9 5c f2 dc 9c 4b 9f 07 70 9c be e9 7a 20 18 2c " + "4b ca ab 27 47 cb ec 1a b0 88 b7 ea a7 e6 85 68", + [S2N_HASH_SHA384] = + "28 4d cd 9f 75 79 a9 fe 08 77 df 73 98 8e 70 6b 73 6e db " + "d6 eb a0 0e a4 53 31 53 79 7b af 94 eb 1e 6e b8 66 76 b6 34 " + "f4 8c 78 f0 57 d4 3b 48 45 24 e7 55 52 16 89 f9 78 06 25 9c " + "98 0b b3 da 20 20 c8 e2 41 24 fd a2 7f ac 73 0b 04 90 c3 77 " + "65 37 3a b6 73 cd 9b 4b 14 2e f5 53 f9 c1 7d 5d fd 0c 9d 02 " + "96 7f bd 1d 7f fa eb 0c e3 0a 65 29 5c 96 09 2f 11 4c 1d 03 b2 " + "18 6e 7c b8 e3 0d 03 f8 df ad 65 08 83 57 bb 71 5b 2b 98 03 47 " + "fc d2 d7 db 4b e3 9b 2b b4 37 a6 db db 8b 8d 67 ca 1a fe bd " + "f1 d3 f9 53 8f 78 ba 4a e0 55 b4 c6 37 de e5 41 e4 e0 2f 28 " + "83 ce b6 8b 5b 68 9b a3 75 fd 5c 61 ab d3 3c a4 4e 69 89 4a " + "bd 74 84 78 6e 89 00 66 b8 2d 5b 98 ff ce 61 f2 59 80 56 34 " + "aa 66 1f 75 df 10 20 80 4a cb 1c 9b 41 d0 c2 9b a1 9b 68 f0 " + "7c 10 73 0f 81 e7 f6 6a 6e 27 70 5e ff a9 bd", + [S2N_HASH_SHA512] = + "5a 9b 32 6f aa 20 e2 a7 0b ec 7f 00 17 24 04 dc 7f 6f 17 02 " + "db 82 dc 18 7f d8 2c b7 a9 8e 05 ae 84 c6 4a 87 2b 8b 14 f4 " + "54 59 83 4c d4 80 64 5a 54 bb 23 c6 ad 8a f8 70 31 18 96 99 " + "5f a9 49 98 70 55 a6 18 9b 0a d8 03 9b 3e 68 19 72 34 41 c4 " + "bb 99 f1 a3 3d 9e 5d 7e 79 4b 74 a0 72 fc cb 83 5b 16 38 17 " + "e0 0e 57 55 4c d3 3a 9e de 8e d5 5f d5 be 5e 6d 85 91 fa fa " + "44 90 b7 d3 cb b1 65 12 98 e8 6f d2 f3 6c 80 ef 3e dc 2b 42 " + "71 a1 73 55 db 44 7d e5 2f 2b be a8 73 15 72 2b 72 df fb ed " + "c1 39 34 2f bb 9d c9 be 97 25 3c e0 ae e6 af 2c 06 d3 5e e5 " + "65 a9 1c 22 6b 5d fa bb c7 78 af 70 34 e0 f1 80 b8 f4 b1 17 " + "94 f3 ea b6 7c f4 be b8 ec 05 29 f1 d1 e4 f7 91 aa 47 2e f3 " + "b0 0b 61 78 77 37 5f 47 86 7b c7 c8 59 25 a6 e1 91 14 d0 31 " + "b9 cd 6a 52 85 7b 06 01 40 f1 d2 5a d0 6a 3d f7", + }; + struct s2n_cert_chain_and_key *rsa_pss_cert_chain = NULL; EXPECT_SUCCESS(s2n_test_cert_chain_and_key_new(&rsa_pss_cert_chain, S2N_RSA_PSS_2048_SHA256_LEAF_CERT, S2N_RSA_PSS_2048_SHA256_LEAF_KEY)); @@ -258,26 +484,38 @@ int main(int argc, char **argv) EXPECT_PKEY_USES_EVP_SIGNING(private_key); EXPECT_PKEY_USES_EVP_SIGNING(public_key); - for (s2n_hash_algorithm hash_alg = 0; hash_alg < S2N_HASH_ALGS_COUNT; hash_alg++) { - if (!s2n_hash_alg_is_supported(sig_alg, hash_alg)) { + for (size_t i = 0; i < all_sig_schemes->count; i++) { + const struct s2n_signature_scheme *scheme = all_sig_schemes->signature_schemes[i]; + if (scheme->sig_alg != sig_alg) { continue; } + const s2n_hash_algorithm hash_alg = scheme->hash_alg; - /* Calculate the signature using EVP methods */ + /* Test that EVP can sign and verify */ s2n_stack_blob(evp_signature, OUTPUT_DATA_SIZE, OUTPUT_DATA_SIZE); EXPECT_OK(s2n_test_evp_sign(sig_alg, hash_alg, private_key, &evp_signature)); + EXPECT_OK(s2n_test_evp_verify(sig_alg, hash_alg, public_key, &evp_signature)); + + /* Test verifying known value */ + S2N_BLOB_FROM_HEX(known_value, valid_signatures[hash_alg]); + EXPECT_OK(s2n_test_evp_verify(sig_alg, hash_alg, public_key, &known_value)); - /* Calculate the signature using RSA-PSS methods */ - s2n_stack_blob(rsa_pss_signature, OUTPUT_DATA_SIZE, OUTPUT_DATA_SIZE); - EXPECT_OK(s2n_test_hash_init(&hash_state, hash_alg)); - EXPECT_SUCCESS(s2n_rsa_pss_sign(private_key, &hash_state, &rsa_pss_signature)); + /* Verify using legacy methods */ + if (s2n_test_legacy_signing_supported()) { + DEFER_CLEANUP(struct s2n_hash_state hash_state = { 0 }, s2n_hash_free); + EXPECT_SUCCESS(s2n_hash_new(&hash_state)); - /* Verify that the EVP methods can verify both signatures */ - EXPECT_OK(s2n_test_evp_verify(sig_alg, hash_alg, public_key, &evp_signature, &rsa_pss_signature)); + s2n_stack_blob(rsa_pss_signature, OUTPUT_DATA_SIZE, OUTPUT_DATA_SIZE); + EXPECT_OK(s2n_test_hash_init(&hash_state, hash_alg)); + EXPECT_SUCCESS(s2n_rsa_pss_sign(private_key, &hash_state, &rsa_pss_signature)); - /* Verify that the RSA-PSS methods can verify the EVP signature */ - EXPECT_OK(s2n_test_hash_init(&hash_state, hash_alg)); - EXPECT_SUCCESS(s2n_rsa_pss_verify(public_key, &hash_state, &evp_signature)); + /* EVP verifies legacy signature */ + EXPECT_OK(s2n_test_evp_verify(sig_alg, hash_alg, public_key, &rsa_pss_signature)); + + /* legacy verifies EVP signature */ + EXPECT_OK(s2n_test_hash_init(&hash_state, hash_alg)); + EXPECT_SUCCESS(s2n_rsa_pss_verify(public_key, &hash_state, &evp_signature)); + } } EXPECT_SUCCESS(s2n_cert_chain_and_key_free(rsa_pss_cert_chain)); diff --git a/tests/unit/s2n_resume_test.c b/tests/unit/s2n_resume_test.c index 3ce0e528364..0e8ab11c982 100644 --- a/tests/unit/s2n_resume_test.c +++ b/tests/unit/s2n_resume_test.c @@ -1532,7 +1532,7 @@ int main(int argc, char **argv) /* Manually zero out key bytes */ struct s2n_ticket_key *key = NULL; - EXPECT_OK(s2n_set_get(config->ticket_keys, 0, (void **) &key)); + EXPECT_OK(s2n_array_get(config->ticket_keys, 0, (void **) &key)); EXPECT_NOT_NULL(key); POSIX_CHECKED_MEMSET((uint8_t *) key->aes_key, 0, S2N_AES256_KEY_LEN); diff --git a/tests/unit/s2n_security_policies_test.c b/tests/unit/s2n_security_policies_test.c index 7545162efc1..b23d9c37434 100644 --- a/tests/unit/s2n_security_policies_test.c +++ b/tests/unit/s2n_security_policies_test.c @@ -651,7 +651,7 @@ int main(int argc, char **argv) EXPECT_EQUAL(config->security_policy, &security_policy_test_all); EXPECT_EQUAL(config->security_policy->cipher_preferences, &cipher_preferences_test_all); EXPECT_EQUAL(config->security_policy->kem_preferences, &kem_preferences_all); - EXPECT_EQUAL(config->security_policy->signature_preferences, &s2n_signature_preferences_20201021); + EXPECT_EQUAL(config->security_policy->signature_preferences, &s2n_signature_preferences_all); EXPECT_EQUAL(config->security_policy->ecc_preferences, &s2n_ecc_preferences_test_all); EXPECT_SUCCESS(s2n_config_set_cipher_preferences(config, "test_all_tls12")); @@ -755,7 +755,7 @@ int main(int argc, char **argv) EXPECT_EQUAL(security_policy, &security_policy_test_all); EXPECT_EQUAL(security_policy->cipher_preferences, &cipher_preferences_test_all); EXPECT_EQUAL(security_policy->kem_preferences, &kem_preferences_all); - EXPECT_EQUAL(security_policy->signature_preferences, &s2n_signature_preferences_20201021); + EXPECT_EQUAL(security_policy->signature_preferences, &s2n_signature_preferences_all); EXPECT_EQUAL(security_policy->ecc_preferences, &s2n_ecc_preferences_test_all); EXPECT_SUCCESS(s2n_connection_set_cipher_preferences(conn, "test_all_tls12")); diff --git a/tests/unit/s2n_session_ticket_test.c b/tests/unit/s2n_session_ticket_test.c index f0d36a7db59..4638c50339c 100644 --- a/tests/unit/s2n_session_ticket_test.c +++ b/tests/unit/s2n_session_ticket_test.c @@ -577,9 +577,9 @@ int main(int argc, char **argv) EXPECT_TRUE(IS_ISSUING_NEW_SESSION_TICKET(server_conn)); /* Verify that the server has only the unexpired key */ - EXPECT_OK(s2n_set_get(server_config->ticket_keys, 0, (void **) &ticket_key)); + EXPECT_OK(s2n_array_get(server_config->ticket_keys, 0, (void **) &ticket_key)); EXPECT_BYTEARRAY_EQUAL(ticket_key->key_name, ticket_key_name2, s2n_array_len(ticket_key_name2)); - EXPECT_OK(s2n_set_len(server_config->ticket_keys, &ticket_keys_len)); + EXPECT_OK(s2n_array_num_elements(server_config->ticket_keys, &ticket_keys_len)); EXPECT_EQUAL(ticket_keys_len, 1); /* Verify that the client received NST */ @@ -765,22 +765,41 @@ int main(int argc, char **argv) EXPECT_SUCCESS(s2n_config_add_ticket_crypto_key(server_config, ticket_key_name3, s2n_array_len(ticket_key_name3), ticket_key3, s2n_array_len(ticket_key3), 0)); /* Try adding the expired keys */ - EXPECT_EQUAL(s2n_config_add_ticket_crypto_key(server_config, ticket_key_name2, s2n_array_len(ticket_key_name2), ticket_key2, s2n_array_len(ticket_key2), 0), -1); - EXPECT_EQUAL(s2n_config_add_ticket_crypto_key(server_config, ticket_key_name1, s2n_array_len(ticket_key_name1), ticket_key1, s2n_array_len(ticket_key1), 0), -1); + EXPECT_SUCCESS(s2n_config_add_ticket_crypto_key(server_config, ticket_key_name2, s2n_array_len(ticket_key_name2), ticket_key2, s2n_array_len(ticket_key2), 0)); + EXPECT_SUCCESS(s2n_config_add_ticket_crypto_key(server_config, ticket_key_name1, s2n_array_len(ticket_key_name1), ticket_key1, s2n_array_len(ticket_key1), 0)); - /* Verify that the config has only one unexpired key */ - EXPECT_OK(s2n_set_get(server_config->ticket_keys, 0, (void **) &ticket_key)); + /* Verify that the config has three unexpired keys */ + EXPECT_OK(s2n_array_get(server_config->ticket_keys, 0, (void **) &ticket_key)); + /* ticket_key3 should have "rotated" to the first index as other keys expired */ EXPECT_BYTEARRAY_EQUAL(ticket_key->key_name, ticket_key_name3, s2n_array_len(ticket_key_name3)); - EXPECT_OK(s2n_set_len(server_config->ticket_keys, &ticket_keys_len)); - EXPECT_EQUAL(ticket_keys_len, 1); - - /* Verify that the total number of key hashes is three */ - EXPECT_OK(s2n_set_len(server_config->ticket_key_hashes, &ticket_keys_len)); + EXPECT_OK(s2n_array_num_elements(server_config->ticket_keys, &ticket_keys_len)); EXPECT_EQUAL(ticket_keys_len, 3); EXPECT_SUCCESS(s2n_config_free(server_config)); }; + /* Attempting to add more than S2N_MAX_TICKET_KEYS causes failures. */ + { + DEFER_CLEANUP(struct s2n_config *config = s2n_config_new(), s2n_config_ptr_free); + EXPECT_SUCCESS(s2n_config_set_session_tickets_onoff(config, 1)); + EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain_and_key)); + + uint8_t id = 0; + uint8_t ticket_key_buf[32] = { 0 }; + + for (uint8_t i = 0; i < S2N_MAX_TICKET_KEYS; i++) { + id = i; + ticket_key_buf[0] = i; + EXPECT_SUCCESS(s2n_config_add_ticket_crypto_key(config, + &id, sizeof(id), ticket_key_buf, s2n_array_len(ticket_key_buf), 0)); + } + + id = S2N_MAX_TICKET_KEYS; + ticket_key_buf[0] = S2N_MAX_TICKET_KEYS; + EXPECT_FAILURE(s2n_config_add_ticket_crypto_key(config, &id, sizeof(id), + ticket_key_buf, s2n_array_len(ticket_key_buf), 0)); + }; + /* Scenario 1: Client sends empty ST and server has multiple encrypt-decrypt keys to choose from for encrypting NST. */ { EXPECT_NOT_NULL(client_config = s2n_config_new()); @@ -1058,14 +1077,6 @@ int main(int argc, char **argv) EXPECT_BYTEARRAY_EQUAL(serialized_session_state + S2N_TICKET_KEY_NAME_LOCATION, ticket_key_name2, s2n_array_len(ticket_key_name2)); - /* Verify that the keys are stored from oldest to newest */ - EXPECT_OK(s2n_set_get(server_config->ticket_keys, 0, (void **) &ticket_key)); - EXPECT_BYTEARRAY_EQUAL(ticket_key->key_name, ticket_key_name2, s2n_array_len(ticket_key_name2)); - EXPECT_OK(s2n_set_get(server_config->ticket_keys, 1, (void **) &ticket_key)); - EXPECT_BYTEARRAY_EQUAL(ticket_key->key_name, ticket_key_name1, s2n_array_len(ticket_key_name1)); - EXPECT_OK(s2n_set_get(server_config->ticket_keys, 2, (void **) &ticket_key)); - EXPECT_BYTEARRAY_EQUAL(ticket_key->key_name, ticket_key_name3, s2n_array_len(ticket_key_name3)); - EXPECT_SUCCESS(s2n_shutdown_test_server_and_client(server_conn, client_conn)); EXPECT_SUCCESS(s2n_connection_free(server_conn)); diff --git a/tests/unit/s2n_signature_scheme_test.c b/tests/unit/s2n_signature_scheme_test.c index ec4c052baab..b081af28df4 100644 --- a/tests/unit/s2n_signature_scheme_test.c +++ b/tests/unit/s2n_signature_scheme_test.c @@ -13,7 +13,7 @@ * permissions and limitations under the License. */ -#include "tls/s2n_signature_scheme.c" +#include "tls/s2n_signature_scheme.h" #include "s2n_test.h" @@ -21,11 +21,16 @@ int main(int argc, char **argv) { BEGIN_TEST(); + const struct s2n_signature_preferences *all_prefs = &s2n_signature_preferences_all; + /* Test all signature schemes */ size_t policy_i = 0; while (security_policy_selection[policy_i].version != NULL) { const struct s2n_signature_preferences *sig_prefs = security_policy_selection[policy_i].security_policy->signature_preferences; + + bool s2n_rsa_pkcs1_md5_sha1_found = false; + for (size_t sig_i = 0; sig_i < sig_prefs->count; sig_i++) { const struct s2n_signature_scheme *const sig_scheme = sig_prefs->signature_schemes[sig_i]; @@ -50,7 +55,36 @@ int main(int argc, char **argv) sig_prefs->signature_schemes[dup_i]; EXPECT_NOT_EQUAL(sig_scheme->iana_value, potential_duplicate->iana_value); } + + if (sig_scheme == &s2n_rsa_pkcs1_md5_sha1) { + s2n_rsa_pkcs1_md5_sha1_found = true; + } + + /* s2n_null_sig_scheme is not a real signature scheme and is just a placeholder. + * It should not appear in any policy. + */ + EXPECT_NOT_EQUAL(sig_scheme, &s2n_null_sig_scheme); + + /* Must be included in s2n_signature_preferences_all */ + bool in_all = false; + for (size_t all_i = 0; all_i < all_prefs->count; all_i++) { + if (sig_scheme == all_prefs->signature_schemes[all_i]) { + in_all = true; + } + } + EXPECT_TRUE(in_all); } + + /* Only s2n_signature_preferences_all should include s2n_rsa_pkcs1_md5_sha1 + * + * s2n_rsa_pkcs1_md5_sha1 is the implicit default for pre-TLS1.2 when no signature + * schemes are provided. Any code that needs to handle "all signature schemes" + * also needs to handle s2n_rsa_pkcs1_md5_sha1. It is not explicitly included + * in any real signature preferences, but should still be tracked by + * s2n_signature_preferences_all. + */ + EXPECT_EQUAL(s2n_rsa_pkcs1_md5_sha1_found, sig_prefs == all_prefs); + policy_i++; } diff --git a/tests/unit/s2n_ssl_prf_test.c b/tests/unit/s2n_ssl_prf_test.c index 2eaa262b113..b3bd3be0e24 100644 --- a/tests/unit/s2n_ssl_prf_test.c +++ b/tests/unit/s2n_ssl_prf_test.c @@ -22,8 +22,9 @@ #include "testlib/s2n_testlib.h" #include "tls/s2n_prf.h" -/* +int s2n_prf_tls_master_secret(struct s2n_connection *conn, struct s2n_blob *premaster_secret); +/* * Grabbed from gnutls-cli --insecure -d 9 www.example.com --ciphers AES --macs SHA --protocols SSLv3 * * |<9>| INT: PREMASTER SECRET[48]: 03009e8e006a7f1451d32164088a8cba5077d1b819160662a97e90a765cec244b5f8f98fd50cfe8e4fba97994a7a4843 @@ -86,7 +87,7 @@ int main(int argc, char **argv) conn->actual_protocol_version = S2N_SSLv3; pms.data = conn->secrets.version.tls12.rsa_premaster_secret; pms.size = sizeof(conn->secrets.version.tls12.rsa_premaster_secret); - EXPECT_SUCCESS(s2n_tls_prf_master_secret(conn, &pms)); + EXPECT_SUCCESS(s2n_prf_tls_master_secret(conn, &pms)); /* Convert the master secret to hex */ for (int i = 0; i < 48; i++) { diff --git a/tests/unit/s2n_tls_hybrid_prf_test.c b/tests/unit/s2n_tls_hybrid_prf_test.c index 0e93b8a64fb..acdc4dfed65 100644 --- a/tests/unit/s2n_tls_hybrid_prf_test.c +++ b/tests/unit/s2n_tls_hybrid_prf_test.c @@ -106,7 +106,7 @@ int main(int argc, char **argv) EXPECT_MEMCPY_SUCCESS(conn->kex_params.client_key_exchange_message.data, client_key_exchange_message, client_key_exchange_message_length); - EXPECT_SUCCESS(s2n_hybrid_prf_master_secret(conn, &combined_pms)); + EXPECT_SUCCESS(s2n_prf_hybrid_master_secret(conn, &combined_pms)); EXPECT_BYTEARRAY_EQUAL(expected_master_secret, conn->secrets.version.tls12.master_secret, S2N_TLS_SECRET_LEN); EXPECT_SUCCESS(s2n_free(&conn->kex_params.client_key_exchange_message)); EXPECT_SUCCESS(s2n_connection_free(conn)); diff --git a/tests/unit/s2n_tls_prf_test.c b/tests/unit/s2n_tls_prf_test.c index faca2f3b543..d8ec146b047 100644 --- a/tests/unit/s2n_tls_prf_test.c +++ b/tests/unit/s2n_tls_prf_test.c @@ -26,6 +26,16 @@ #define TEST_BLOB_SIZE 64 +bool s2n_libcrypto_supports_tls_prf(); +int s2n_prf(struct s2n_connection *conn, struct s2n_blob *secret, + struct s2n_blob *label, struct s2n_blob *seed_a, + struct s2n_blob *seed_b, struct s2n_blob *seed_c, struct s2n_blob *out); +S2N_RESULT s2n_prf_get_digest_for_ems(struct s2n_connection *conn, + struct s2n_blob *message, s2n_hash_algorithm hash_alg, struct s2n_blob *output); +S2N_RESULT s2n_prf_tls_extended_master_secret(struct s2n_connection *conn, + struct s2n_blob *premaster_secret, struct s2n_blob *session_hash, struct s2n_blob *sha1_hash); +int s2n_prf_tls_master_secret(struct s2n_connection *conn, struct s2n_blob *premaster_secret); + /* * Grabbed from gnutls-cli --insecure -d 9 www.example.com --ciphers AES --macs SHA --protocols TLS1.0 * @@ -63,7 +73,7 @@ int main(int argc, char **argv) struct s2n_blob pms = { 0 }; EXPECT_SUCCESS(s2n_blob_init(&pms, conn->secrets.version.tls12.rsa_premaster_secret, sizeof(conn->secrets.version.tls12.rsa_premaster_secret))); - EXPECT_SUCCESS(s2n_tls_prf_master_secret(conn, &pms)); + EXPECT_SUCCESS(s2n_prf_tls_master_secret(conn, &pms)); EXPECT_EQUAL(memcmp(conn->secrets.version.tls12.master_secret, master_secret_in.data, master_secret_in.size), 0); EXPECT_SUCCESS(s2n_connection_free(conn)); @@ -100,7 +110,7 @@ int main(int argc, char **argv) *# session_hash) *# [0..47]; */ - EXPECT_OK(s2n_tls_prf_extended_master_secret(conn, &premaster_secret, &hash_digest, NULL)); + EXPECT_OK(s2n_prf_tls_extended_master_secret(conn, &premaster_secret, &hash_digest, NULL)); EXPECT_BYTEARRAY_EQUAL(extended_master_secret.data, conn->secrets.version.tls12.master_secret, S2N_TLS_SECRET_LEN); EXPECT_SUCCESS(s2n_connection_free(conn)); @@ -293,20 +303,20 @@ int main(int argc, char **argv) EXPECT_MEMCPY_SUCCESS(conn->secrets.version.tls12.rsa_premaster_secret, premaster_secret_in.data, premaster_secret_in.size); EXPECT_MEMCPY_SUCCESS(conn->handshake_params.client_random, client_random_in.data, client_random_in.size); EXPECT_MEMCPY_SUCCESS(conn->handshake_params.server_random, server_random_in.data, server_random_in.size); - EXPECT_SUCCESS(s2n_tls_prf_master_secret(conn, &pms)); + EXPECT_SUCCESS(s2n_prf_tls_master_secret(conn, &pms)); EXPECT_EQUAL(memcmp(conn->secrets.version.tls12.master_secret, master_secret_in.data, master_secret_in.size), 0); EXPECT_SUCCESS(s2n_connection_free_handshake(conn)); EXPECT_MEMCPY_SUCCESS(conn->secrets.version.tls12.rsa_premaster_secret, premaster_secret_in.data, premaster_secret_in.size); EXPECT_MEMCPY_SUCCESS(conn->handshake_params.client_random, client_random_in.data, client_random_in.size); EXPECT_MEMCPY_SUCCESS(conn->handshake_params.server_random, server_random_in.data, server_random_in.size); - EXPECT_FAILURE_WITH_ERRNO(s2n_tls_prf_master_secret(conn, &pms), S2N_ERR_NULL); + EXPECT_FAILURE_WITH_ERRNO(s2n_prf_tls_master_secret(conn, &pms), S2N_ERR_NULL); EXPECT_SUCCESS(s2n_connection_wipe(conn)); EXPECT_MEMCPY_SUCCESS(conn->secrets.version.tls12.rsa_premaster_secret, premaster_secret_in.data, premaster_secret_in.size); EXPECT_MEMCPY_SUCCESS(conn->handshake_params.client_random, client_random_in.data, client_random_in.size); EXPECT_MEMCPY_SUCCESS(conn->handshake_params.server_random, server_random_in.data, server_random_in.size); - EXPECT_SUCCESS(s2n_tls_prf_master_secret(conn, &pms)); + EXPECT_SUCCESS(s2n_prf_tls_master_secret(conn, &pms)); EXPECT_EQUAL(memcmp(conn->secrets.version.tls12.master_secret, master_secret_in.data, master_secret_in.size), 0); EXPECT_SUCCESS(s2n_connection_free(conn)); diff --git a/tls/s2n_client_cert.c b/tls/s2n_client_cert.c index 8d4af88645c..2dc540670bb 100644 --- a/tls/s2n_client_cert.c +++ b/tls/s2n_client_cert.c @@ -57,8 +57,12 @@ static S2N_RESULT s2n_client_cert_chain_store(struct s2n_connection *conn, RESULT_ENSURE_REF(conn); RESULT_ENSURE_REF(raw_cert_chain); - /* There shouldn't already be a client cert chain, but free just in case */ - RESULT_GUARD_POSIX(s2n_free(&conn->handshake_params.client_cert_chain)); + /* If a client cert chain has already been stored (e.g. on the re-entry case + * of an async callback), no need to store it again. + */ + if (conn->handshake_params.client_cert_chain.size > 0) { + return S2N_RESULT_OK; + } /* Earlier versions are a basic copy */ if (conn->actual_protocol_version < S2N_TLS13) { @@ -101,23 +105,26 @@ static S2N_RESULT s2n_client_cert_chain_store(struct s2n_connection *conn, int s2n_client_cert_recv(struct s2n_connection *conn) { + /* s2n_client_cert_recv() may be re-entered due to handling an async callback. + * We operate on a copy of `handshake.io` to ensure the stuffer is initilized properly on the re-entry case. + */ + struct s2n_stuffer in = conn->handshake.io; + if (conn->actual_protocol_version == S2N_TLS13) { uint8_t certificate_request_context_len = 0; - POSIX_GUARD(s2n_stuffer_read_uint8(&conn->handshake.io, &certificate_request_context_len)); + POSIX_GUARD(s2n_stuffer_read_uint8(&in, &certificate_request_context_len)); S2N_ERROR_IF(certificate_request_context_len != 0, S2N_ERR_BAD_MESSAGE); } - struct s2n_stuffer *in = &conn->handshake.io; - uint32_t cert_chain_size = 0; - POSIX_GUARD(s2n_stuffer_read_uint24(in, &cert_chain_size)); - POSIX_ENSURE(cert_chain_size <= s2n_stuffer_data_available(in), S2N_ERR_BAD_MESSAGE); + POSIX_GUARD(s2n_stuffer_read_uint24(&in, &cert_chain_size)); + POSIX_ENSURE(cert_chain_size <= s2n_stuffer_data_available(&in), S2N_ERR_BAD_MESSAGE); if (cert_chain_size == 0) { POSIX_GUARD(s2n_conn_set_handshake_no_client_cert(conn)); return S2N_SUCCESS; } - uint8_t *cert_chain_data = s2n_stuffer_raw_read(in, cert_chain_size); + uint8_t *cert_chain_data = s2n_stuffer_raw_read(&in, cert_chain_size); POSIX_ENSURE_REF(cert_chain_data); struct s2n_blob cert_chain = { 0 }; @@ -139,6 +146,9 @@ int s2n_client_cert_recv(struct s2n_connection *conn) POSIX_GUARD(s2n_pkey_check_key_exists(&public_key)); conn->handshake_params.client_public_key = public_key; + /* Update handshake.io to reflect the true stuffer state after all async callbacks are handled. */ + conn->handshake.io = in; + return S2N_SUCCESS; } diff --git a/tls/s2n_config.c b/tls/s2n_config.c index f0bbb623266..778dce920e8 100644 --- a/tls/s2n_config.c +++ b/tls/s2n_config.c @@ -308,28 +308,10 @@ struct s2n_config *s2n_config_new(void) return new_config; } -static int s2n_config_store_ticket_key_comparator(const void *a, const void *b) -{ - if (((const struct s2n_ticket_key *) a)->intro_timestamp >= ((const struct s2n_ticket_key *) b)->intro_timestamp) { - return S2N_GREATER_OR_EQUAL; - } else { - return S2N_LESS_THAN; - } -} - -static int s2n_verify_unique_ticket_key_comparator(const void *a, const void *b) -{ - return memcmp(a, b, SHA_DIGEST_LENGTH); -} - int s2n_config_init_session_ticket_keys(struct s2n_config *config) { if (config->ticket_keys == NULL) { - POSIX_ENSURE_REF(config->ticket_keys = s2n_set_new(sizeof(struct s2n_ticket_key), s2n_config_store_ticket_key_comparator)); - } - - if (config->ticket_key_hashes == NULL) { - POSIX_ENSURE_REF(config->ticket_key_hashes = s2n_set_new(SHA_DIGEST_LENGTH, s2n_verify_unique_ticket_key_comparator)); + POSIX_ENSURE_REF(config->ticket_keys = s2n_array_new_with_capacity(sizeof(struct s2n_ticket_key), S2N_MAX_TICKET_KEYS)); } return 0; @@ -338,11 +320,7 @@ int s2n_config_init_session_ticket_keys(struct s2n_config *config) int s2n_config_free_session_ticket_keys(struct s2n_config *config) { if (config->ticket_keys != NULL) { - POSIX_GUARD_RESULT(s2n_set_free_p(&config->ticket_keys)); - } - - if (config->ticket_key_hashes != NULL) { - POSIX_GUARD_RESULT(s2n_set_free_p(&config->ticket_key_hashes)); + POSIX_GUARD_RESULT(s2n_array_free_p(&config->ticket_keys)); } return 0; @@ -956,7 +934,7 @@ int s2n_config_add_ticket_crypto_key(struct s2n_config *config, POSIX_ENSURE(key_len != 0, S2N_ERR_INVALID_TICKET_KEY_LENGTH); uint32_t ticket_keys_len = 0; - POSIX_GUARD_RESULT(s2n_set_len(config->ticket_keys, &ticket_keys_len)); + POSIX_GUARD_RESULT(s2n_array_num_elements(config->ticket_keys, &ticket_keys_len)); POSIX_ENSURE(ticket_keys_len < S2N_MAX_TICKET_KEYS, S2N_ERR_TICKET_KEY_LIMIT); POSIX_ENSURE(name_len != 0, S2N_ERR_INVALID_TICKET_KEY_NAME_OR_NAME_LENGTH); @@ -967,9 +945,6 @@ int s2n_config_add_ticket_crypto_key(struct s2n_config *config, uint8_t name_data[S2N_TICKET_KEY_NAME_LEN] = { 0 }; POSIX_CHECKED_MEMCPY(name_data, name, name_len); - /* ensure the ticket name is not already present */ - POSIX_ENSURE(s2n_find_ticket_key(config, name_data) == NULL, S2N_ERR_INVALID_TICKET_KEY_NAME_OR_NAME_LENGTH); - uint8_t output_pad[S2N_AES256_KEY_LEN + S2N_TICKET_AAD_IMPLICIT_LEN] = { 0 }; struct s2n_blob out_key = { 0 }; POSIX_GUARD(s2n_blob_init(&out_key, output_pad, s2n_array_len(output_pad))); @@ -990,23 +965,6 @@ int s2n_config_add_ticket_crypto_key(struct s2n_config *config, POSIX_GUARD(s2n_hmac_new(&hmac)); POSIX_GUARD(s2n_hkdf(&hmac, S2N_HMAC_SHA256, &salt, &in_key, &info, &out_key)); - DEFER_CLEANUP(struct s2n_hash_state hash = { 0 }, s2n_hash_free); - uint8_t hash_output[SHA_DIGEST_LENGTH] = { 0 }; - - POSIX_GUARD(s2n_hash_new(&hash)); - POSIX_GUARD(s2n_hash_init(&hash, S2N_HASH_SHA1)); - POSIX_GUARD(s2n_hash_update(&hash, out_key.data, out_key.size)); - POSIX_GUARD(s2n_hash_digest(&hash, hash_output, SHA_DIGEST_LENGTH)); - - POSIX_GUARD_RESULT(s2n_set_len(config->ticket_keys, &ticket_keys_len)); - if (ticket_keys_len >= S2N_MAX_TICKET_KEY_HASHES) { - POSIX_GUARD_RESULT(s2n_set_free_p(&config->ticket_key_hashes)); - POSIX_ENSURE_REF(config->ticket_key_hashes = s2n_set_new(SHA_DIGEST_LENGTH, s2n_verify_unique_ticket_key_comparator)); - } - - /* Insert hash key into a sorted array at known index */ - POSIX_GUARD_RESULT(s2n_set_add(config->ticket_key_hashes, hash_output)); - POSIX_CHECKED_MEMCPY(session_ticket_key->key_name, name_data, s2n_array_len(name_data)); POSIX_CHECKED_MEMCPY(session_ticket_key->aes_key, out_key.data, S2N_AES256_KEY_LEN); out_key.data = output_pad + S2N_AES256_KEY_LEN; diff --git a/tls/s2n_config.h b/tls/s2n_config.h index 07d6166d762..1f780b041b6 100644 --- a/tls/s2n_config.h +++ b/tls/s2n_config.h @@ -31,8 +31,7 @@ #include "utils/s2n_blob.h" #include "utils/s2n_set.h" -#define S2N_MAX_TICKET_KEYS 48 -#define S2N_MAX_TICKET_KEY_HASHES 500 /* 10KB */ +#define S2N_MAX_TICKET_KEYS 48 /* * TLS1.3 does not allow alert messages to be fragmented, and some TLS @@ -133,8 +132,7 @@ struct s2n_config { uint64_t session_state_lifetime_in_nanos; - struct s2n_set *ticket_keys; - struct s2n_set *ticket_key_hashes; + struct s2n_array *ticket_keys; uint64_t encrypt_decrypt_key_lifetime_in_nanos; uint64_t decrypt_key_lifetime_in_nanos; diff --git a/tls/s2n_kex.c b/tls/s2n_kex.c index 7da8c7b2586..84e0fd61647 100644 --- a/tls/s2n_kex.c +++ b/tls/s2n_kex.c @@ -245,7 +245,7 @@ const struct s2n_kex s2n_hybrid_ecdhe_kem = { .server_key_send = &s2n_hybrid_server_key_send, .client_key_recv = &s2n_hybrid_client_key_recv, .client_key_send = &s2n_hybrid_client_key_send, - .prf = &s2n_hybrid_prf_master_secret, + .prf = &s2n_prf_hybrid_master_secret, }; /* TLS1.3 key exchange is implemented differently from previous versions and does diff --git a/tls/s2n_prf.c b/tls/s2n_prf.c index 763c8058f2e..73db2366723 100644 --- a/tls/s2n_prf.c +++ b/tls/s2n_prf.c @@ -34,6 +34,29 @@ #include "utils/s2n_mem.h" #include "utils/s2n_safety.h" +#if defined(OPENSSL_IS_AWSLC) + #define S2N_LIBCRYPTO_SUPPORTS_TLS_PRF 1 +#else + #define S2N_LIBCRYPTO_SUPPORTS_TLS_PRF 0 +#endif + +/* The s2n p_hash implementation is abstracted to allow for separate implementations, using + * either s2n's formally verified HMAC or OpenSSL's EVP HMAC, for use by the TLS PRF. */ +struct s2n_p_hash_hmac { + int (*alloc)(struct s2n_prf_working_space *ws); + int (*init)(struct s2n_prf_working_space *ws, s2n_hmac_algorithm alg, struct s2n_blob *secret); + int (*update)(struct s2n_prf_working_space *ws, const void *data, uint32_t size); + int (*final)(struct s2n_prf_working_space *ws, void *digest, uint32_t size); + int (*reset)(struct s2n_prf_working_space *ws); + int (*cleanup)(struct s2n_prf_working_space *ws); + int (*free)(struct s2n_prf_working_space *ws); +}; + +S2N_RESULT s2n_prf_get_digest_for_ems(struct s2n_connection *conn, struct s2n_blob *message, + s2n_hash_algorithm hash_alg, struct s2n_blob *output); +S2N_RESULT s2n_prf_tls_extended_master_secret(struct s2n_connection *conn, + struct s2n_blob *premaster_secret, struct s2n_blob *session_hash, struct s2n_blob *sha1_hash); + S2N_RESULT s2n_key_material_init(struct s2n_key_material *key_material, struct s2n_connection *conn) { RESULT_ENSURE_REF(key_material); @@ -114,7 +137,7 @@ S2N_RESULT s2n_key_material_init(struct s2n_key_material *key_material, struct s return S2N_RESULT_OK; } -static int s2n_sslv3_prf(struct s2n_connection *conn, struct s2n_blob *secret, struct s2n_blob *seed_a, +static int s2n_prf_sslv3(struct s2n_connection *conn, struct s2n_blob *secret, struct s2n_blob *seed_a, struct s2n_blob *seed_b, struct s2n_blob *seed_c, struct s2n_blob *out) { POSIX_ENSURE_REF(conn); @@ -509,7 +532,7 @@ bool s2n_libcrypto_supports_tls_prf() #endif } -S2N_RESULT s2n_custom_prf(struct s2n_connection *conn, struct s2n_blob *secret, struct s2n_blob *label, +S2N_RESULT s2n_prf_custom(struct s2n_connection *conn, struct s2n_blob *secret, struct s2n_blob *label, struct s2n_blob *seed_a, struct s2n_blob *seed_b, struct s2n_blob *seed_c, struct s2n_blob *out) { /* We zero the out blob because p_hash works by XOR'ing with the existing @@ -553,7 +576,7 @@ int CRYPTO_tls1_prf(const EVP_MD *digest, const uint8_t *seed1, size_t seed1_len, const uint8_t *seed2, size_t seed2_len); -S2N_RESULT s2n_libcrypto_prf(struct s2n_connection *conn, struct s2n_blob *secret, struct s2n_blob *label, +S2N_RESULT s2n_prf_libcrypto(struct s2n_connection *conn, struct s2n_blob *secret, struct s2n_blob *label, struct s2n_blob *seed_a, struct s2n_blob *seed_b, struct s2n_blob *seed_c, struct s2n_blob *out) { const EVP_MD *digest = NULL; @@ -601,7 +624,7 @@ S2N_RESULT s2n_libcrypto_prf(struct s2n_connection *conn, struct s2n_blob *secre return S2N_RESULT_OK; } #else -S2N_RESULT s2n_libcrypto_prf(struct s2n_connection *conn, struct s2n_blob *secret, struct s2n_blob *label, +S2N_RESULT s2n_prf_libcrypto(struct s2n_connection *conn, struct s2n_blob *secret, struct s2n_blob *label, struct s2n_blob *seed_a, struct s2n_blob *seed_b, struct s2n_blob *seed_c, struct s2n_blob *out) { RESULT_BAIL(S2N_ERR_UNIMPLEMENTED); @@ -624,7 +647,7 @@ int s2n_prf(struct s2n_connection *conn, struct s2n_blob *secret, struct s2n_blo POSIX_ENSURE(S2N_IMPLIES(seed_c != NULL, seed_b != NULL), S2N_ERR_PRF_INVALID_SEED); if (conn->actual_protocol_version == S2N_SSLv3) { - POSIX_GUARD(s2n_sslv3_prf(conn, secret, seed_a, seed_b, seed_c, out)); + POSIX_GUARD(s2n_prf_sslv3(conn, secret, seed_a, seed_b, seed_c, out)); return S2N_SUCCESS; } @@ -632,16 +655,16 @@ int s2n_prf(struct s2n_connection *conn, struct s2n_blob *secret, struct s2n_blo * FIPS-validated libcrypto implementation is used instead, if an implementation is provided. */ if (s2n_is_in_fips_mode() && s2n_libcrypto_supports_tls_prf()) { - POSIX_GUARD_RESULT(s2n_libcrypto_prf(conn, secret, label, seed_a, seed_b, seed_c, out)); + POSIX_GUARD_RESULT(s2n_prf_libcrypto(conn, secret, label, seed_a, seed_b, seed_c, out)); return S2N_SUCCESS; } - POSIX_GUARD_RESULT(s2n_custom_prf(conn, secret, label, seed_a, seed_b, seed_c, out)); + POSIX_GUARD_RESULT(s2n_prf_custom(conn, secret, label, seed_a, seed_b, seed_c, out)); return S2N_SUCCESS; } -int s2n_tls_prf_master_secret(struct s2n_connection *conn, struct s2n_blob *premaster_secret) +int s2n_prf_tls_master_secret(struct s2n_connection *conn, struct s2n_blob *premaster_secret) { POSIX_ENSURE_REF(conn); @@ -659,7 +682,7 @@ int s2n_tls_prf_master_secret(struct s2n_connection *conn, struct s2n_blob *prem return s2n_prf(conn, premaster_secret, &label, &client_random, &server_random, NULL, &master_secret); } -int s2n_hybrid_prf_master_secret(struct s2n_connection *conn, struct s2n_blob *premaster_secret) +int s2n_prf_hybrid_master_secret(struct s2n_connection *conn, struct s2n_blob *premaster_secret) { POSIX_ENSURE_REF(conn); @@ -685,7 +708,7 @@ int s2n_prf_calculate_master_secret(struct s2n_connection *conn, struct s2n_blob POSIX_ENSURE_EQ(s2n_conn_get_current_message_type(conn), CLIENT_KEY); if (!conn->ems_negotiated) { - POSIX_GUARD(s2n_tls_prf_master_secret(conn, premaster_secret)); + POSIX_GUARD(s2n_prf_tls_master_secret(conn, premaster_secret)); return S2N_SUCCESS; } @@ -708,13 +731,13 @@ int s2n_prf_calculate_master_secret(struct s2n_connection *conn, struct s2n_blob POSIX_GUARD(s2n_blob_init(&sha1_digest, sha1_data, sizeof(sha1_data))); POSIX_GUARD_RESULT(s2n_prf_get_digest_for_ems(conn, &client_key_blob, S2N_HASH_MD5, &digest)); POSIX_GUARD_RESULT(s2n_prf_get_digest_for_ems(conn, &client_key_blob, S2N_HASH_SHA1, &sha1_digest)); - POSIX_GUARD_RESULT(s2n_tls_prf_extended_master_secret(conn, premaster_secret, &digest, &sha1_digest)); + POSIX_GUARD_RESULT(s2n_prf_tls_extended_master_secret(conn, premaster_secret, &digest, &sha1_digest)); } else { s2n_hmac_algorithm prf_alg = conn->secure->cipher_suite->prf_alg; s2n_hash_algorithm hash_alg = 0; POSIX_GUARD(s2n_hmac_hash_alg(prf_alg, &hash_alg)); POSIX_GUARD_RESULT(s2n_prf_get_digest_for_ems(conn, &client_key_blob, hash_alg, &digest)); - POSIX_GUARD_RESULT(s2n_tls_prf_extended_master_secret(conn, premaster_secret, &digest, NULL)); + POSIX_GUARD_RESULT(s2n_prf_tls_extended_master_secret(conn, premaster_secret, &digest, NULL)); } return S2N_SUCCESS; } @@ -728,7 +751,7 @@ int s2n_prf_calculate_master_secret(struct s2n_connection *conn, struct s2n_blob *# session_hash) *# [0..47]; */ -S2N_RESULT s2n_tls_prf_extended_master_secret(struct s2n_connection *conn, struct s2n_blob *premaster_secret, struct s2n_blob *session_hash, struct s2n_blob *sha1_hash) +S2N_RESULT s2n_prf_tls_extended_master_secret(struct s2n_connection *conn, struct s2n_blob *premaster_secret, struct s2n_blob *session_hash, struct s2n_blob *sha1_hash) { RESULT_ENSURE_REF(conn); @@ -765,7 +788,7 @@ S2N_RESULT s2n_prf_get_digest_for_ems(struct s2n_connection *conn, struct s2n_bl return S2N_RESULT_OK; } -static int s2n_sslv3_finished(struct s2n_connection *conn, uint8_t prefix[4], struct s2n_hash_state *hash_workspace, uint8_t *out) +static int s2n_prf_sslv3_finished(struct s2n_connection *conn, uint8_t prefix[4], struct s2n_hash_state *hash_workspace, uint8_t *out) { POSIX_ENSURE_REF(conn); POSIX_ENSURE_REF(conn->handshake.hashes); @@ -808,24 +831,24 @@ static int s2n_sslv3_finished(struct s2n_connection *conn, uint8_t prefix[4], st return 0; } -static int s2n_sslv3_client_finished(struct s2n_connection *conn) +static int s2n_prf_sslv3_client_finished(struct s2n_connection *conn) { POSIX_ENSURE_REF(conn); POSIX_ENSURE_REF(conn->handshake.hashes); uint8_t prefix[4] = { 0x43, 0x4c, 0x4e, 0x54 }; - return s2n_sslv3_finished(conn, prefix, &conn->handshake.hashes->hash_workspace, conn->handshake.client_finished); + return s2n_prf_sslv3_finished(conn, prefix, &conn->handshake.hashes->hash_workspace, conn->handshake.client_finished); } -static int s2n_sslv3_server_finished(struct s2n_connection *conn) +static int s2n_prf_sslv3_server_finished(struct s2n_connection *conn) { POSIX_ENSURE_REF(conn); POSIX_ENSURE_REF(conn->handshake.hashes); uint8_t prefix[4] = { 0x53, 0x52, 0x56, 0x52 }; - return s2n_sslv3_finished(conn, prefix, &conn->handshake.hashes->hash_workspace, conn->handshake.server_finished); + return s2n_prf_sslv3_finished(conn, prefix, &conn->handshake.hashes->hash_workspace, conn->handshake.server_finished); } int s2n_prf_client_finished(struct s2n_connection *conn) @@ -842,7 +865,7 @@ int s2n_prf_client_finished(struct s2n_connection *conn) struct s2n_blob label = { 0 }; if (conn->actual_protocol_version == S2N_SSLv3) { - return s2n_sslv3_client_finished(conn); + return s2n_prf_sslv3_client_finished(conn); } client_finished.data = conn->handshake.client_finished; @@ -900,7 +923,7 @@ int s2n_prf_server_finished(struct s2n_connection *conn) struct s2n_blob label = { 0 }; if (conn->actual_protocol_version == S2N_SSLv3) { - return s2n_sslv3_server_finished(conn); + return s2n_prf_sslv3_server_finished(conn); } server_finished.data = conn->handshake.server_finished; diff --git a/tls/s2n_prf.h b/tls/s2n_prf.h index 4b3b99d3702..c5ade9ca838 100644 --- a/tls/s2n_prf.h +++ b/tls/s2n_prf.h @@ -17,19 +17,13 @@ #include -#include "crypto/s2n_hash.h" #include "crypto/s2n_hmac.h" +#include "tls/s2n_connection.h" #include "utils/s2n_blob.h" /* Enough to support TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, 2*SHA384_DIGEST_LEN + 2*AES256_KEY_SIZE */ #define S2N_MAX_KEY_BLOCK_LEN 160 -#if defined(OPENSSL_IS_AWSLC) - #define S2N_LIBCRYPTO_SUPPORTS_TLS_PRF 1 -#else - #define S2N_LIBCRYPTO_SUPPORTS_TLS_PRF 0 -#endif - union p_hash_state { struct s2n_hmac_state s2n_hmac; struct s2n_evp_hmac_state evp_hmac; @@ -41,18 +35,6 @@ struct s2n_prf_working_space { uint8_t digest1[S2N_MAX_DIGEST_LEN]; }; -/* The s2n p_hash implementation is abstracted to allow for separate implementations, using - * either s2n's formally verified HMAC or OpenSSL's EVP HMAC, for use by the TLS PRF. */ -struct s2n_p_hash_hmac { - int (*alloc)(struct s2n_prf_working_space *ws); - int (*init)(struct s2n_prf_working_space *ws, s2n_hmac_algorithm alg, struct s2n_blob *secret); - int (*update)(struct s2n_prf_working_space *ws, const void *data, uint32_t size); - int (*final)(struct s2n_prf_working_space *ws, void *digest, uint32_t size); - int (*reset)(struct s2n_prf_working_space *ws); - int (*cleanup)(struct s2n_prf_working_space *ws); - int (*free)(struct s2n_prf_working_space *ws); -}; - /* TLS key expansion results in an array of contiguous data which is then * interpreted as the MAC, KEY and IV for the client and server. * @@ -75,27 +57,13 @@ struct s2n_key_material { S2N_RESULT s2n_key_material_init(struct s2n_key_material *key_material, struct s2n_connection *conn); -#include "tls/s2n_connection.h" - S2N_RESULT s2n_prf_new(struct s2n_connection *conn); S2N_RESULT s2n_prf_wipe(struct s2n_connection *conn); S2N_RESULT s2n_prf_free(struct s2n_connection *conn); -int s2n_prf(struct s2n_connection *conn, struct s2n_blob *secret, struct s2n_blob *label, struct s2n_blob *seed_a, - struct s2n_blob *seed_b, struct s2n_blob *seed_c, struct s2n_blob *out); int s2n_prf_calculate_master_secret(struct s2n_connection *conn, struct s2n_blob *premaster_secret); -int s2n_tls_prf_master_secret(struct s2n_connection *conn, struct s2n_blob *premaster_secret); -int s2n_hybrid_prf_master_secret(struct s2n_connection *conn, struct s2n_blob *premaster_secret); -S2N_RESULT s2n_tls_prf_extended_master_secret(struct s2n_connection *conn, struct s2n_blob *premaster_secret, struct s2n_blob *session_hash, struct s2n_blob *sha1_hash); -S2N_RESULT s2n_prf_get_digest_for_ems(struct s2n_connection *conn, struct s2n_blob *message, s2n_hash_algorithm hash_alg, struct s2n_blob *output); +int s2n_prf_hybrid_master_secret(struct s2n_connection *conn, struct s2n_blob *premaster_secret); S2N_RESULT s2n_prf_generate_key_material(struct s2n_connection *conn, struct s2n_key_material *key_material); int s2n_prf_key_expansion(struct s2n_connection *conn); int s2n_prf_server_finished(struct s2n_connection *conn); int s2n_prf_client_finished(struct s2n_connection *conn); - -bool s2n_libcrypto_supports_tls_prf(); - -S2N_RESULT s2n_custom_prf(struct s2n_connection *conn, struct s2n_blob *secret, struct s2n_blob *label, - struct s2n_blob *seed_a, struct s2n_blob *seed_b, struct s2n_blob *seed_c, struct s2n_blob *out); -S2N_RESULT s2n_libcrypto_prf(struct s2n_connection *conn, struct s2n_blob *secret, struct s2n_blob *label, - struct s2n_blob *seed_a, struct s2n_blob *seed_b, struct s2n_blob *seed_c, struct s2n_blob *out); diff --git a/tls/s2n_resume.c b/tls/s2n_resume.c index a8b368e9d0b..8966e31d2ec 100644 --- a/tls/s2n_resume.c +++ b/tls/s2n_resume.c @@ -637,11 +637,11 @@ S2N_RESULT s2n_config_is_encrypt_key_available(struct s2n_config *config) RESULT_ENSURE_REF(config->ticket_keys); uint32_t ticket_keys_len = 0; - RESULT_GUARD(s2n_set_len(config->ticket_keys, &ticket_keys_len)); + RESULT_GUARD(s2n_array_num_elements(config->ticket_keys, &ticket_keys_len)); for (uint32_t i = ticket_keys_len; i > 0; i--) { uint32_t idx = i - 1; - RESULT_GUARD(s2n_set_get(config->ticket_keys, idx, (void **) &ticket_key)); + RESULT_GUARD(s2n_array_get(config->ticket_keys, idx, (void **) &ticket_key)); uint64_t key_intro_time = ticket_key->intro_timestamp; if (key_intro_time <= now @@ -668,7 +668,7 @@ int s2n_compute_weight_of_encrypt_decrypt_keys(struct s2n_config *config, /* Compute weight of encrypt-decrypt keys */ for (int i = 0; i < num_encrypt_decrypt_keys; i++) { - POSIX_GUARD_RESULT(s2n_set_get(config->ticket_keys, encrypt_decrypt_keys_index[i], (void **) &ticket_key)); + POSIX_GUARD_RESULT(s2n_array_get(config->ticket_keys, encrypt_decrypt_keys_index[i], (void **) &ticket_key)); uint64_t key_intro_time = ticket_key->intro_timestamp; uint64_t key_encryption_peak_time = key_intro_time + (config->encrypt_decrypt_key_lifetime_in_nanos / 2); @@ -720,11 +720,11 @@ struct s2n_ticket_key *s2n_get_ticket_encrypt_decrypt_key(struct s2n_config *con PTR_ENSURE_REF(config->ticket_keys); uint32_t ticket_keys_len = 0; - PTR_GUARD_RESULT(s2n_set_len(config->ticket_keys, &ticket_keys_len)); + PTR_GUARD_RESULT(s2n_array_num_elements(config->ticket_keys, &ticket_keys_len)); for (uint32_t i = ticket_keys_len; i > 0; i--) { uint32_t idx = i - 1; - PTR_GUARD_RESULT(s2n_set_get(config->ticket_keys, idx, (void **) &ticket_key)); + PTR_GUARD_RESULT(s2n_array_get(config->ticket_keys, idx, (void **) &ticket_key)); uint64_t key_intro_time = ticket_key->intro_timestamp; /* A key can be used at its intro time (<=) and it can be used up to (<) @@ -742,14 +742,14 @@ struct s2n_ticket_key *s2n_get_ticket_encrypt_decrypt_key(struct s2n_config *con } if (num_encrypt_decrypt_keys == 1) { - PTR_GUARD_RESULT(s2n_set_get(config->ticket_keys, encrypt_decrypt_keys_index[0], (void **) &ticket_key)); + PTR_GUARD_RESULT(s2n_array_get(config->ticket_keys, encrypt_decrypt_keys_index[0], (void **) &ticket_key)); return ticket_key; } int8_t idx = 0; PTR_GUARD_POSIX(idx = s2n_compute_weight_of_encrypt_decrypt_keys(config, encrypt_decrypt_keys_index, num_encrypt_decrypt_keys, now)); - PTR_GUARD_RESULT(s2n_set_get(config->ticket_keys, idx, (void **) &ticket_key)); + PTR_GUARD_RESULT(s2n_array_get(config->ticket_keys, idx, (void **) &ticket_key)); return ticket_key; } @@ -764,10 +764,10 @@ struct s2n_ticket_key *s2n_find_ticket_key(struct s2n_config *config, const uint PTR_ENSURE_REF(config->ticket_keys); uint32_t ticket_keys_len = 0; - PTR_GUARD_RESULT(s2n_set_len(config->ticket_keys, &ticket_keys_len)); + PTR_GUARD_RESULT(s2n_array_num_elements(config->ticket_keys, &ticket_keys_len)); for (uint32_t i = 0; i < ticket_keys_len; i++) { - PTR_GUARD_RESULT(s2n_set_get(config->ticket_keys, i, (void **) &ticket_key)); + PTR_GUARD_RESULT(s2n_array_get(config->ticket_keys, i, (void **) &ticket_key)); if (s2n_constant_time_equals(ticket_key->key_name, name, S2N_TICKET_KEY_NAME_LEN)) { /* Check to see if the key has expired */ @@ -1013,10 +1013,9 @@ int s2n_config_wipe_expired_ticket_crypto_keys(struct s2n_config *config, int8_t POSIX_ENSURE_REF(config->ticket_keys); uint32_t ticket_keys_len = 0; - POSIX_GUARD_RESULT(s2n_set_len(config->ticket_keys, &ticket_keys_len)); - + POSIX_GUARD_RESULT(s2n_array_num_elements(config->ticket_keys, &ticket_keys_len)); for (uint32_t i = 0; i < ticket_keys_len; i++) { - POSIX_GUARD_RESULT(s2n_set_get(config->ticket_keys, i, (void **) &ticket_key)); + POSIX_GUARD_RESULT(s2n_array_get(config->ticket_keys, i, (void **) &ticket_key)); if (now >= ticket_key->intro_timestamp + config->encrypt_decrypt_key_lifetime_in_nanos + config->decrypt_key_lifetime_in_nanos) { @@ -1027,7 +1026,7 @@ int s2n_config_wipe_expired_ticket_crypto_keys(struct s2n_config *config, int8_t end: for (int j = 0; j < num_of_expired_keys; j++) { - POSIX_GUARD_RESULT(s2n_set_remove(config->ticket_keys, expired_keys_index[j] - j)); + POSIX_GUARD_RESULT(s2n_array_remove(config->ticket_keys, expired_keys_index[j] - j)); } return 0; @@ -1035,8 +1034,20 @@ int s2n_config_wipe_expired_ticket_crypto_keys(struct s2n_config *config, int8_t int s2n_config_store_ticket_key(struct s2n_config *config, struct s2n_ticket_key *key) { - /* Keys are stored from oldest to newest */ - POSIX_GUARD_RESULT(s2n_set_add(config->ticket_keys, key)); + uint32_t ticket_keys_len = 0; + POSIX_GUARD_RESULT(s2n_array_num_elements(config->ticket_keys, &ticket_keys_len)); + + /* The ticket key name and secret must both be unique. */ + for (uint32_t i = 0; i < ticket_keys_len; i++) { + struct s2n_ticket_key *other_key = NULL; + POSIX_GUARD_RESULT(s2n_array_get(config->ticket_keys, i, (void **) &other_key)); + POSIX_ENSURE(!s2n_constant_time_equals(key->key_name, other_key->key_name, s2n_array_len(key->key_name)), + S2N_ERR_INVALID_TICKET_KEY_NAME_OR_NAME_LENGTH); + POSIX_ENSURE(!s2n_constant_time_equals(key->aes_key, other_key->aes_key, s2n_array_len(key->aes_key)), + S2N_ERR_TICKET_KEY_NOT_UNIQUE); + } + + POSIX_GUARD_RESULT(s2n_array_insert_and_copy(config->ticket_keys, ticket_keys_len, key)); return S2N_SUCCESS; } diff --git a/tls/s2n_security_policies.c b/tls/s2n_security_policies.c index 28e6f297475..d88dce76aa2 100644 --- a/tls/s2n_security_policies.c +++ b/tls/s2n_security_policies.c @@ -1154,7 +1154,7 @@ const struct s2n_security_policy security_policy_test_all = { .minimum_protocol_version = S2N_SSLv3, .cipher_preferences = &cipher_preferences_test_all, .kem_preferences = &kem_preferences_all, - .signature_preferences = &s2n_signature_preferences_20201021, + .signature_preferences = &s2n_signature_preferences_all, .ecc_preferences = &s2n_ecc_preferences_test_all, }; diff --git a/tls/s2n_server_cert.c b/tls/s2n_server_cert.c index 5c1882ceb29..7ba1564eb03 100644 --- a/tls/s2n_server_cert.c +++ b/tls/s2n_server_cert.c @@ -22,16 +22,21 @@ int s2n_server_cert_recv(struct s2n_connection *conn) { + /* s2n_server_cert_recv() may be re-entered due to handling an async callback. + * We operate on a copy of `handshake.io` to ensure the stuffer is initilized properly on the re-entry case. + */ + struct s2n_stuffer in = conn->handshake.io; + if (conn->actual_protocol_version == S2N_TLS13) { uint8_t certificate_request_context_len = 0; - POSIX_GUARD(s2n_stuffer_read_uint8(&conn->handshake.io, &certificate_request_context_len)); + POSIX_GUARD(s2n_stuffer_read_uint8(&in, &certificate_request_context_len)); S2N_ERROR_IF(certificate_request_context_len != 0, S2N_ERR_BAD_MESSAGE); } uint32_t size_of_all_certificates = 0; - POSIX_GUARD(s2n_stuffer_read_uint24(&conn->handshake.io, &size_of_all_certificates)); + POSIX_GUARD(s2n_stuffer_read_uint24(&in, &size_of_all_certificates)); - S2N_ERROR_IF(size_of_all_certificates > s2n_stuffer_data_available(&conn->handshake.io) || size_of_all_certificates < 3, + S2N_ERROR_IF(size_of_all_certificates > s2n_stuffer_data_available(&in) || size_of_all_certificates < 3, S2N_ERR_BAD_MESSAGE); s2n_cert_public_key public_key; @@ -40,7 +45,7 @@ int s2n_server_cert_recv(struct s2n_connection *conn) s2n_pkey_type actual_cert_pkey_type; struct s2n_blob cert_chain = { 0 }; cert_chain.size = size_of_all_certificates; - cert_chain.data = s2n_stuffer_raw_read(&conn->handshake.io, size_of_all_certificates); + cert_chain.data = s2n_stuffer_raw_read(&in, size_of_all_certificates); POSIX_ENSURE_REF(cert_chain.data); POSIX_GUARD_RESULT(s2n_x509_validator_validate_cert_chain(&conn->x509_validator, conn, cert_chain.data, @@ -50,6 +55,9 @@ int s2n_server_cert_recv(struct s2n_connection *conn) POSIX_GUARD_RESULT(s2n_pkey_setup_for_type(&public_key, actual_cert_pkey_type)); conn->handshake_params.server_public_key = public_key; + /* Update handshake.io to reflect the true stuffer state after all async callbacks are handled. */ + conn->handshake.io = in; + return 0; } diff --git a/tls/s2n_signature_scheme.c b/tls/s2n_signature_scheme.c index 93ca30fb602..797dcff8624 100644 --- a/tls/s2n_signature_scheme.c +++ b/tls/s2n_signature_scheme.c @@ -204,6 +204,34 @@ const struct s2n_signature_scheme s2n_rsa_pss_pss_sha512 = { .minimum_protocol_version = S2N_TLS12, }; +/* ALL signature schemes, including the legacy default s2n_rsa_pkcs1_md5_sha1 scheme. + * New signature schemes must be added to this list. + */ +const struct s2n_signature_scheme* const s2n_sig_scheme_pref_list_all[] = { + &s2n_rsa_pkcs1_md5_sha1, + &s2n_rsa_pkcs1_sha1, + &s2n_rsa_pkcs1_sha224, + &s2n_rsa_pkcs1_sha256, + &s2n_rsa_pkcs1_sha384, + &s2n_rsa_pkcs1_sha512, + &s2n_ecdsa_sha1, + &s2n_ecdsa_sha224, + &s2n_ecdsa_sha256, + &s2n_ecdsa_sha384, + &s2n_ecdsa_sha512, + &s2n_rsa_pss_rsae_sha256, + &s2n_rsa_pss_rsae_sha384, + &s2n_rsa_pss_rsae_sha512, + &s2n_rsa_pss_pss_sha256, + &s2n_rsa_pss_pss_sha384, + &s2n_rsa_pss_pss_sha512, +}; + +const struct s2n_signature_preferences s2n_signature_preferences_all = { + .count = s2n_array_len(s2n_sig_scheme_pref_list_all), + .signature_schemes = s2n_sig_scheme_pref_list_all, +}; + /* Chosen based on AWS server recommendations as of 05/24. * * The recommendations do not include PKCS1, but we must include it anyway for diff --git a/tls/s2n_signature_scheme.h b/tls/s2n_signature_scheme.h index 1bac827c72e..b475a740d45 100644 --- a/tls/s2n_signature_scheme.h +++ b/tls/s2n_signature_scheme.h @@ -82,5 +82,6 @@ extern const struct s2n_signature_preferences s2n_certificate_signature_preferen extern const struct s2n_signature_preferences s2n_signature_preferences_default_fips; extern const struct s2n_signature_preferences s2n_signature_preferences_null; extern const struct s2n_signature_preferences s2n_signature_preferences_test_all_fips; +extern const struct s2n_signature_preferences s2n_signature_preferences_all; extern const struct s2n_signature_preferences s2n_certificate_signature_preferences_20201110; diff --git a/tls/s2n_x509_validator.c b/tls/s2n_x509_validator.c index ad9ac61a3fd..efddae222c5 100644 --- a/tls/s2n_x509_validator.c +++ b/tls/s2n_x509_validator.c @@ -149,6 +149,8 @@ int s2n_x509_validator_init_no_x509_validation(struct s2n_x509_validator *valida validator->state = INIT; validator->cert_chain_from_wire = sk_X509_new_null(); validator->crl_lookup_list = NULL; + validator->cert_validation_info = (struct s2n_cert_validation_info){ 0 }; + validator->cert_validation_cb_invoked = false; return 0; } @@ -168,6 +170,8 @@ int s2n_x509_validator_init(struct s2n_x509_validator *validator, struct s2n_x50 validator->cert_chain_from_wire = sk_X509_new_null(); validator->state = INIT; validator->crl_lookup_list = NULL; + validator->cert_validation_info = (struct s2n_cert_validation_info){ 0 }; + validator->cert_validation_cb_invoked = false; return 0; } @@ -750,8 +754,8 @@ static S2N_RESULT s2n_x509_validator_parse_leaf_certificate_extensions(struct s2 return S2N_RESULT_OK; } -S2N_RESULT s2n_x509_validator_validate_cert_chain(struct s2n_x509_validator *validator, struct s2n_connection *conn, - uint8_t *cert_chain_in, uint32_t cert_chain_len, s2n_pkey_type *pkey_type, struct s2n_pkey *public_key_out) +S2N_RESULT s2n_x509_validator_validate_cert_chain_pre_cb(struct s2n_x509_validator *validator, struct s2n_connection *conn, + uint8_t *cert_chain_in, uint32_t cert_chain_len) { RESULT_ENSURE_REF(conn); RESULT_ENSURE_REF(conn->config); @@ -788,12 +792,37 @@ S2N_RESULT s2n_x509_validator_validate_cert_chain(struct s2n_x509_validator *val RESULT_GUARD_POSIX(s2n_extension_list_process(S2N_EXTENSION_LIST_CERTIFICATE, conn, &first_certificate_extensions)); } - if (conn->config->cert_validation_cb) { - struct s2n_cert_validation_info info = { 0 }; - RESULT_ENSURE(conn->config->cert_validation_cb(conn, &info, conn->config->cert_validation_ctx) >= S2N_SUCCESS, - S2N_ERR_CANCELLED); - RESULT_ENSURE(info.finished, S2N_ERR_INVALID_STATE); - RESULT_ENSURE(info.accepted, S2N_ERR_CERT_REJECTED); + return S2N_RESULT_OK; +} + +static S2N_RESULT s2n_x509_validator_handle_cert_validation_callback_result(struct s2n_x509_validator *validator) +{ + RESULT_ENSURE_REF(validator); + + if (!validator->cert_validation_info.finished) { + RESULT_BAIL(S2N_ERR_ASYNC_BLOCKED); + } + + RESULT_ENSURE(validator->cert_validation_info.accepted, S2N_ERR_CERT_REJECTED); + return S2N_RESULT_OK; +} + +S2N_RESULT s2n_x509_validator_validate_cert_chain(struct s2n_x509_validator *validator, struct s2n_connection *conn, + uint8_t *cert_chain_in, uint32_t cert_chain_len, s2n_pkey_type *pkey_type, struct s2n_pkey *public_key_out) +{ + RESULT_ENSURE_REF(validator); + + if (validator->cert_validation_cb_invoked) { + RESULT_GUARD(s2n_x509_validator_handle_cert_validation_callback_result(validator)); + } else { + RESULT_GUARD(s2n_x509_validator_validate_cert_chain_pre_cb(validator, conn, cert_chain_in, cert_chain_len)); + + if (conn->config->cert_validation_cb) { + RESULT_ENSURE(conn->config->cert_validation_cb(conn, &(validator->cert_validation_info), conn->config->cert_validation_ctx) == S2N_SUCCESS, + S2N_ERR_CANCELLED); + validator->cert_validation_cb_invoked = true; + RESULT_GUARD(s2n_x509_validator_handle_cert_validation_callback_result(validator)); + } } /* retrieve information from leaf cert */ diff --git a/tls/s2n_x509_validator.h b/tls/s2n_x509_validator.h index 7706fc0a031..3bdfaf9a959 100644 --- a/tls/s2n_x509_validator.h +++ b/tls/s2n_x509_validator.h @@ -52,6 +52,11 @@ struct s2n_x509_trust_store { unsigned loaded_system_certs : 1; }; +struct s2n_cert_validation_info { + unsigned finished : 1; + unsigned accepted : 1; +}; + /** * You should have one instance of this per connection. */ @@ -64,11 +69,8 @@ struct s2n_x509_validator { STACK_OF(X509) *cert_chain_from_wire; int state; struct s2n_array *crl_lookup_list; -}; - -struct s2n_cert_validation_info { - unsigned finished : 1; - unsigned accepted : 1; + struct s2n_cert_validation_info cert_validation_info; + bool cert_validation_cb_invoked; }; /** Some libcrypto implementations do not support OCSP validation. Returns 1 if supported, 0 otherwise. */