Skip to content

Commit

Permalink
ACVP ML-KEM testing
Browse files Browse the repository at this point in the history
  • Loading branch information
skmcgrail committed Sep 9, 2024
1 parent e4092fb commit 476c6c6
Show file tree
Hide file tree
Showing 6 changed files with 330 additions and 4 deletions.
194 changes: 194 additions & 0 deletions util/fipstools/acvp/acvptool/subprocess/ml_kem.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 OR ISC

package subprocess

import (
"encoding/json"
"fmt"
"strings"
)

type mlKem struct{}

func (*mlKem) Process(vectorSet []byte, m Transactable) (interface{}, error) {
var vs struct {
Mode string `json:"mode"`
TestGroups json.RawMessage `json:"testGroups"`
}

if err := json.Unmarshal(vectorSet, &vs); err != nil {
return nil, err
}

switch {
case strings.EqualFold(vs.Mode, "keyGen"):
return processMlKemKeyGen(vs.TestGroups, m)
case strings.EqualFold(vs.Mode, "encapDecap"):
return processMlKemEncapDecap(vs.TestGroups, m)
}

return nil, fmt.Errorf("unknown ML-KEM mode: %v", vs.Mode)
}

type mlKemKeyGenTestGroup struct {
ID uint64 `json:"tgId"`
Type string `json:"testType"`
ParameterSet string `json:"parameterSet"`
Tests []struct {
ID uint64 `json:"tcId"`
D hexEncodedByteString `json:"d"`
Z hexEncodedByteString `json:"z"`
}
}

type mlKemKeyGenTestGroupResponse struct {
ID uint64 `json:"tgId"`
Tests []mlKemKeyGenTestCaseResponse `json:"tests"`
}

type mlKemKeyGenTestCaseResponse struct {
ID uint64 `json:"tcId"`
EK hexEncodedByteString `json:"ek"`
DK hexEncodedByteString `json:"dk"`
}

func processMlKemKeyGen(vectors json.RawMessage, m Transactable) (interface{}, error) {
var groups []mlKemKeyGenTestGroup

if err := json.Unmarshal(vectors, &groups); err != nil {
return nil, err
}

var responses []mlKemKeyGenTestGroupResponse

for _, group := range groups {
if !strings.EqualFold(group.Type, "AFT") {
return nil, fmt.Errorf("unsupported keyGen test type: %v", group.Type)
}

response := mlKemKeyGenTestGroupResponse{
ID: group.ID,
}

for _, test := range group.Tests {
results, err := m.Transact("ML-KEM/"+group.ParameterSet+"/keyGen", 2, test.D, test.Z)
if err != nil {
return nil, err
}

ek := results[0]
dk := results[1]

response.Tests = append(response.Tests, mlKemKeyGenTestCaseResponse{
ID: test.ID,
EK: ek,
DK: dk,
})
}

responses = append(responses, response)
}

return responses, nil
}

type mlKemEncapDecapTestGroup struct {
ID uint64 `json:"tgId"`
Type string `json:"testType"`
ParameterSet string `json:"parameterSet"`
Function string `json:"function"`
DK hexEncodedByteString `json:"dk"`
Tests []struct {
ID uint64 `json:"tcId"`
EK hexEncodedByteString `json:"ek"`
M hexEncodedByteString `json:"m"`
C hexEncodedByteString `json:"c"`
}
}

type mlKemEncDecapTestGroupResponse struct {
ID uint64 `json:"tgId"`
Tests []mlKemEncDecapTestCaseResponse `json:"tests"`
}

type mlKemEncDecapTestCaseResponse struct {
ID uint64 `json:"tcId"`
C hexEncodedByteString `json:"c,omitempty"`
K hexEncodedByteString `json:"k,omitempty"`
}

func processMlKemEncapDecap(vectors json.RawMessage, m Transactable) (interface{}, error) {
var groups []mlKemEncapDecapTestGroup

if err := json.Unmarshal(vectors, &groups); err != nil {
return nil, err
}

var responses []mlKemEncDecapTestGroupResponse

for _, group := range groups {
if (strings.EqualFold(group.Function, "encapsulation") && !strings.EqualFold(group.Type, "AFT")) ||
(strings.EqualFold(group.Function, "decapsulation") && !strings.EqualFold(group.Type, "VAL")) {
return nil, fmt.Errorf("unsupported encapDecap function and test group type pair: (%v, %v)", group.Function, group.Type)
}

response := mlKemEncDecapTestGroupResponse{
ID: group.ID,
}

for _, test := range group.Tests {
var (
err error
testResponse mlKemEncDecapTestCaseResponse
)

switch {
case strings.EqualFold(group.Function, "encapsulation"):
testResponse, err = processMlKemEncapTestCase(test.ID, group.ParameterSet, test.EK, test.M, m)
case strings.EqualFold(group.Function, "decapsulation"):
testResponse, err = processMlKemDecapTestCase(test.ID, group.ParameterSet, group.DK, test.C, m)
default:
return nil, fmt.Errorf("unknown encDecap function: %v", group.Function)
}
if err != nil {
return nil, err
}

response.Tests = append(response.Tests, testResponse)
}

responses = append(responses, response)
}
return responses, nil
}

func processMlKemEncapTestCase(id uint64, algorithm string, ek []byte, m []byte, t Transactable) (mlKemEncDecapTestCaseResponse, error) {
results, err := t.Transact("ML-KEM/"+algorithm+"/encap", 2, ek, m)
if err != nil {
return mlKemEncDecapTestCaseResponse{}, err
}

c := results[0]
k := results[1]

return mlKemEncDecapTestCaseResponse{
ID: id,
C: c,
K: k,
}, nil
}

func processMlKemDecapTestCase(id uint64, algorithm string, dk []byte, c []byte, t Transactable) (mlKemEncDecapTestCaseResponse, error) {
results, err := t.Transact("ML-KEM/"+algorithm+"/decap", 1, dk, c)
if err != nil {
return mlKemEncDecapTestCaseResponse{}, err
}

k := results[0]

return mlKemEncDecapTestCaseResponse{
ID: id,
K: k,
}, nil
}
1 change: 1 addition & 0 deletions util/fipstools/acvp/acvptool/subprocess/subprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ func NewWithIO(cmd *exec.Cmd, in io.WriteCloser, out io.ReadCloser) *Subprocess
"KAS-ECC-SSC": &kas{},
"KAS-FFC-SSC": &kasDH{},
"PBKDF": &pbkdf{},
"ML-KEM": &mlKem{},
}
m.primitives["ECDSA"] = &ecdsa{"ECDSA", map[string]bool{"P-224": true, "P-256": true, "P-384": true, "P-521": true}, m.primitives}

Expand Down
Binary file not shown.
3 changes: 2 additions & 1 deletion util/fipstools/acvp/acvptool/test/tests.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@
{"Wrapper": "modulewrapper", "In": "vectors/TLS-1.2-KDF.bz2", "Out": "expected/TLS-1.2-KDF.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/PBKDF.bz2", "Out": "expected/PBKDF.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/KDA-HKDF.bz2", "Out": "expected/KDA-HKDF.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/KDA-OneStep.bz2", "Out": "expected/KDA-OneStep.bz2"}
{"Wrapper": "modulewrapper", "In": "vectors/KDA-OneStep.bz2", "Out": "expected/KDA-OneStep.bz2"},
{"Wrapper": "modulewrapper", "In": "vectors/ML-KEM.bz2", "Out": "expected/ML-KEM.bz2"}
]
Binary file not shown.
136 changes: 133 additions & 3 deletions util/fipstools/acvp/modulewrapper/modulewrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
* OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
* CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */

#include <signal.h>
#include <algorithm>
#include <map>
#include <string>
#include <vector>
#include <signal.h>
#include <cstring>

#include <sstream>

Expand All @@ -41,6 +42,7 @@
#include <openssl/ecdsa.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/experimental/kem_deterministic_api.h>
#include <openssl/hkdf.h>
#include <openssl/hmac.h>
#include <openssl/kdf.h>
Expand Down Expand Up @@ -732,7 +734,7 @@ static bool GetConfig(const Span<const uint8_t> args[],
}]
}]
},)"
R"({
R"({
"sigType": "pss",
"properties": [{
"modulo": 2048,
Expand Down Expand Up @@ -988,7 +990,7 @@ static bool GetConfig(const Span<const uint8_t> args[],
}]
}]
},)"
R"({
R"({
"sigType": "pss",
"properties": [{
"modulo": 2048,
Expand Down Expand Up @@ -1307,6 +1309,19 @@ static bool GetConfig(const Span<const uint8_t> args[],
"encoding": ["concatenation"],
"z": [{"min": 224, "max": 8192, "increment": 8}],
"l": 2048
},)"
R"({
"algorithm": "ML-KEM",
"mode": "keyGen",
"revision": "FIPS203",
"parameterSets": ["ML-KEM-512", "ML-KEM-768", "ML-KEM-1024"]
},
{
"algorithm": "ML-KEM",
"mode": "encapDecap",
"revision": "FIPS203",
"parameterSets": ["ML-KEM-512", "ML-KEM-768", "ML-KEM-1024"],
"functions": ["encapsulation", "decapsulation"]
}
])";
return write_reply({Span<const uint8_t>(
Expand Down Expand Up @@ -2830,6 +2845,112 @@ static bool KBKDF_CTR_HMAC(const Span<const uint8_t> args[],
return write_reply({Span<const uint8_t>(out)});
}

template <int nid>
static bool ML_KEM_KEYGEN(const Span<const uint8_t> args[],
ReplyCallback write_reply) {
const Span<const uint8_t> d = args[0];
const Span<const uint8_t> z = args[1];

std::vector<uint8_t> seed(d.size() + z.size());
std::memcpy(seed.data(), d.data(), d.size());
std::memcpy(seed.data() + d.size(), z.data(), z.size());

EVP_PKEY *raw = NULL;
size_t seed_len = 0;

bssl::UniquePtr<EVP_PKEY_CTX> ctx(EVP_PKEY_CTX_new_id(EVP_PKEY_KEM, nullptr));
if (!EVP_PKEY_CTX_kem_set_params(ctx.get(), nid) ||
!EVP_PKEY_keygen_init(ctx.get()) ||
!EVP_PKEY_keygen_deterministic(ctx.get(), &raw, NULL, &seed_len) ||
seed_len != seed.size() ||
!EVP_PKEY_keygen_deterministic(ctx.get(), &raw, seed.data(), &seed_len)) {
return false;
}
bssl::UniquePtr<EVP_PKEY> pkey(raw);

size_t decaps_key_size = 0;
size_t encaps_key_size = 0;

if (!EVP_PKEY_get_raw_private_key(pkey.get(), nullptr, &decaps_key_size) ||
!EVP_PKEY_get_raw_public_key(pkey.get(), nullptr, &encaps_key_size)) {
return false;
}

std::vector<uint8_t> decaps_key(decaps_key_size);
std::vector<uint8_t> encaps_key(encaps_key_size);

if (!EVP_PKEY_get_raw_private_key(pkey.get(), decaps_key.data(),
&decaps_key_size) ||
!EVP_PKEY_get_raw_public_key(pkey.get(), encaps_key.data(),
&encaps_key_size)) {
return false;
}

return write_reply({Span<const uint8_t>(encaps_key.data(), encaps_key_size),
Span<const uint8_t>(decaps_key.data(), decaps_key_size)});
}

template <int nid>
static bool ML_KEM_ENCAP(const Span<const uint8_t> args[],
ReplyCallback write_reply) {
const Span<const uint8_t> ek = args[0];
const Span<const uint8_t> m = args[1];

bssl::UniquePtr<EVP_PKEY> pkey(
EVP_PKEY_kem_new_raw_public_key(nid, ek.data(), ek.size()));
bssl::UniquePtr<EVP_PKEY_CTX> ctx(EVP_PKEY_CTX_new(pkey.get(), nullptr));

size_t ciphertext_len = 0;
size_t shared_secret_len = 0;
size_t seed_len = 0;
if (!EVP_PKEY_encapsulate_deterministic(ctx.get(), nullptr, &ciphertext_len,
nullptr, &shared_secret_len, nullptr,
&seed_len) ||
seed_len != m.size()) {
return false;
}

std::vector<uint8_t> ciphertext(ciphertext_len);
std::vector<uint8_t> shared_secret(shared_secret_len);

if (!EVP_PKEY_encapsulate_deterministic(
ctx.get(), ciphertext.data(), &ciphertext_len, shared_secret.data(),
&shared_secret_len, m.data(), &seed_len)) {
return false;
}

return write_reply(
{Span<const uint8_t>(ciphertext.data(), ciphertext_len),
Span<const uint8_t>(shared_secret.data(), shared_secret_len)});
}

template <int nid>
static bool ML_KEM_DECAP(const Span<const uint8_t> args[],
ReplyCallback write_reply) {
const Span<const uint8_t> dk = args[0];
const Span<const uint8_t> c = args[1];

bssl::UniquePtr<EVP_PKEY> pkey(
EVP_PKEY_kem_new_raw_secret_key(nid, dk.data(), dk.size()));
bssl::UniquePtr<EVP_PKEY_CTX> ctx(EVP_PKEY_CTX_new(pkey.get(), nullptr));

size_t shared_secret_len = 0;
if (!EVP_PKEY_decapsulate(ctx.get(), nullptr, &shared_secret_len, c.data(),
c.size())) {
return false;
}

std::vector<uint8_t> shared_secret(shared_secret_len);

if (!EVP_PKEY_decapsulate(ctx.get(), shared_secret.data(), &shared_secret_len,
c.data(), c.size())) {
return false;
}

return write_reply(
{Span<const uint8_t>(shared_secret.data(), shared_secret_len)});
}

static struct {
char name[kMaxNameLength + 1];
uint8_t num_expected_args;
Expand Down Expand Up @@ -3064,6 +3185,15 @@ static struct {
{"KDF/Counter/HMAC-SHA2-512", 3, KBKDF_CTR_HMAC<EVP_sha512>},
{"KDF/Counter/HMAC-SHA2-512/224", 3, KBKDF_CTR_HMAC<EVP_sha512_224>},
{"KDF/Counter/HMAC-SHA2-512/256", 3, KBKDF_CTR_HMAC<EVP_sha512_256>},
{"ML-KEM/ML-KEM-512/keyGen", 2, ML_KEM_KEYGEN<NID_MLKEM512>},
{"ML-KEM/ML-KEM-768/keyGen", 2, ML_KEM_KEYGEN<NID_MLKEM768>},
{"ML-KEM/ML-KEM-1024/keyGen", 2, ML_KEM_KEYGEN<NID_MLKEM1024>},
{"ML-KEM/ML-KEM-512/encap", 2, ML_KEM_ENCAP<NID_MLKEM512>},
{"ML-KEM/ML-KEM-768/encap", 2, ML_KEM_ENCAP<NID_MLKEM768>},
{"ML-KEM/ML-KEM-1024/encap", 2, ML_KEM_ENCAP<NID_MLKEM1024>},
{"ML-KEM/ML-KEM-512/decap", 2, ML_KEM_DECAP<NID_MLKEM512>},
{"ML-KEM/ML-KEM-768/decap", 2, ML_KEM_DECAP<NID_MLKEM768>},
{"ML-KEM/ML-KEM-1024/decap", 2, ML_KEM_DECAP<NID_MLKEM1024>},
};

Handler FindHandler(Span<const Span<const uint8_t>> args) {
Expand Down

0 comments on commit 476c6c6

Please sign in to comment.