diff --git a/README.md b/README.md index 9469de22..a5c42aa7 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,8 @@ Alternatively, look at the [Cloudflare Go](https://github.com/cloudflare/go/tree [RFC-9496]: https://doi.org/10.17487/RFC9496 [RFC-9497]: https://doi.org/10.17487/RFC9497 [FIPS 202]: https://doi.org/10.6028/NIST.FIPS.202 +[FIPS 204]: https://doi.org/10.6028/NIST.FIPS.204 +[FIPS 205]: https://doi.org/10.6028/NIST.FIPS.205 [FIPS 186-5]: https://doi.org/10.6028/NIST.FIPS.186-5 [BLS12-381]: https://electriccoin.co/blog/new-snark-curve/ [ia.cr/2015/267]: https://ia.cr/2015/267 @@ -91,7 +93,8 @@ Alternatively, look at the [Cloudflare Go](https://github.com/cloudflare/go/tree |:---:| - [Dilithium](./sign/dilithium): modes 2, 3, 5 ([Dilithium](https://pq-crystals.org/dilithium/)). - - [ML-DSA](./sign/mldsa): modes 44, 65, 87 ([FIPS 204](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.204.pdf)). + - [ML-DSA](./sign/mldsa): modes 44, 65, 87 ([FIPS 204]). + - [SLH-DSA](./sign/slhdsa): twelve parameter sets, pure and pre-hash signing ([FIPS 205]). ### Zero-knowledge Proofs diff --git a/internal/test/test.go b/internal/test/test.go index 9ba73dd7..2017e61e 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -1,6 +1,8 @@ package test import ( + "bytes" + "encoding" "errors" "fmt" "strings" @@ -58,3 +60,26 @@ func CheckPanic(f func()) error { f() return hasPanicked } + +func CheckMarshal( + t *testing.T, + x, y interface { + encoding.BinaryMarshaler + encoding.BinaryUnmarshaler + }, +) { + t.Helper() + + want, err := x.MarshalBinary() + CheckNoErr(t, err, fmt.Sprintf("cannot marshal %T = %v", x, x)) + + err = y.UnmarshalBinary(want) + CheckNoErr(t, err, fmt.Sprintf("cannot unmarshal %T from %x", y, want)) + + got, err := y.MarshalBinary() + CheckNoErr(t, err, fmt.Sprintf("cannot marshal %T = %v", y, y)) + + if !bytes.Equal(got, want) { + ReportError(t, got, want, x, y) + } +} diff --git a/sign/schemes/schemes.go b/sign/schemes/schemes.go index d01d8ca2..cb240e25 100644 --- a/sign/schemes/schemes.go +++ b/sign/schemes/schemes.go @@ -6,23 +6,26 @@ // Ed448 // Ed25519-Dilithium2 // Ed448-Dilithium3 +// Dilithium +// ML-DSA +// SLH-DSA package schemes import ( "strings" "github.com/cloudflare/circl/sign" + dilithium2 "github.com/cloudflare/circl/sign/dilithium/mode2" + dilithium3 "github.com/cloudflare/circl/sign/dilithium/mode3" + dilithium5 "github.com/cloudflare/circl/sign/dilithium/mode5" "github.com/cloudflare/circl/sign/ed25519" "github.com/cloudflare/circl/sign/ed448" "github.com/cloudflare/circl/sign/eddilithium2" "github.com/cloudflare/circl/sign/eddilithium3" - - dilithium2 "github.com/cloudflare/circl/sign/dilithium/mode2" - dilithium3 "github.com/cloudflare/circl/sign/dilithium/mode3" - dilithium5 "github.com/cloudflare/circl/sign/dilithium/mode5" "github.com/cloudflare/circl/sign/mldsa/mldsa44" "github.com/cloudflare/circl/sign/mldsa/mldsa65" "github.com/cloudflare/circl/sign/mldsa/mldsa87" + "github.com/cloudflare/circl/sign/slhdsa" ) var allSchemes = [...]sign.Scheme{ @@ -36,6 +39,18 @@ var allSchemes = [...]sign.Scheme{ mldsa44.Scheme(), mldsa65.Scheme(), mldsa87.Scheme(), + slhdsa.SHA2_128s.Scheme(), + slhdsa.SHAKE_128s.Scheme(), + slhdsa.SHA2_128f.Scheme(), + slhdsa.SHAKE_128f.Scheme(), + slhdsa.SHA2_192s.Scheme(), + slhdsa.SHAKE_192s.Scheme(), + slhdsa.SHA2_192f.Scheme(), + slhdsa.SHAKE_192f.Scheme(), + slhdsa.SHA2_256s.Scheme(), + slhdsa.SHAKE_256s.Scheme(), + slhdsa.SHA2_256f.Scheme(), + slhdsa.SHAKE_256f.Scheme(), } var allSchemeNames map[string]sign.Scheme diff --git a/sign/schemes/schemes_test.go b/sign/schemes/schemes_test.go index 3242c796..9c9cc35e 100644 --- a/sign/schemes/schemes_test.go +++ b/sign/schemes/schemes_test.go @@ -122,6 +122,18 @@ func Example() { // ML-DSA-44 // ML-DSA-65 // ML-DSA-87 + // SLH-DSA-SHA2-128s + // SLH-DSA-SHAKE-128s + // SLH-DSA-SHA2-128f + // SLH-DSA-SHAKE-128f + // SLH-DSA-SHA2-192s + // SLH-DSA-SHAKE-192s + // SLH-DSA-SHA2-192f + // SLH-DSA-SHAKE-192f + // SLH-DSA-SHA2-256s + // SLH-DSA-SHAKE-256s + // SLH-DSA-SHA2-256f + // SLH-DSA-SHAKE-256f } func BenchmarkGenerateKeyPair(b *testing.B) { diff --git a/sign/slhdsa/acvp_test.go b/sign/slhdsa/acvp_test.go new file mode 100644 index 00000000..899f2ace --- /dev/null +++ b/sign/slhdsa/acvp_test.go @@ -0,0 +1,381 @@ +package slhdsa + +import ( + "archive/zip" + "bytes" + "crypto" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "strings" + "testing" + + "github.com/cloudflare/circl/internal/test" + "github.com/cloudflare/circl/xof" +) + +type acvpKeyGenPrompt struct { + TestGroups []struct { + TestType string `json:"testType"` + ParameterSet string `json:"parameterSet"` + Tests []keyGenInput `json:"tests"` + TgID int `json:"tgId"` + } `json:"testGroups"` +} + +type keyGenInput struct { + SkSeed Hex `json:"skSeed"` + SkPrf Hex `json:"skPrf"` + PkSeed Hex `json:"pkSeed"` + TcID int `json:"tcId"` +} + +type acvpKeyGenResult struct { + TestGroups []struct { + Tests []struct { + Sk Hex `json:"sk"` + Pk Hex `json:"pk"` + TcID int `json:"tcId"` + } `json:"tests"` + TgID int `json:"tgId"` + } `json:"testGroups"` +} + +type acvpSigGenPrompt struct { + TestGroups []struct { + sigGenParams + TestType string `json:"testType"` + Tests []sigGenInput `json:"tests"` + TgID int `json:"tgId"` + } `json:"testGroups"` +} + +type sigParams struct { + ParameterSet string `json:"parameterSet"` + SigInterface string `json:"signatureInterface"` + PreHash string `json:"preHash"` +} + +type sigGenParams struct { + sigParams + IsDeterministic bool `json:"deterministic"` +} + +type sigGenInput struct { + HashAlg string `json:"hashAlg,omitempty"` + Sk Hex `json:"sk"` + Msg Hex `json:"message"` + Ctx Hex `json:"context,omitempty"` + AddRand Hex `json:"additionalRandomness,omitempty"` + TcID int `json:"tcId"` +} + +type acvpSigGenResult struct { + TestGroups []struct { + Tests []struct { + Signature Hex `json:"signature"` + TcID int `json:"tcId"` + } `json:"tests"` + TgID int `json:"tgId"` + } `json:"testGroups"` +} + +type acvpVerifyInput struct { + TestGroups []struct { + sigParams + TestType string `json:"testType"` + Tests []verifyInput `json:"tests"` + TgID int `json:"tgId"` + } `json:"testGroups"` +} + +type verifyInput struct { + HashAlg string `json:"hashAlg,omitempty"` + Pk Hex `json:"pk"` + Msg Hex `json:"message"` + Sig Hex `json:"signature"` + Ctx Hex `json:"context,omitempty"` + TcID int `json:"tcId"` +} + +type acvpVerifyResult struct { + TestGroups []struct { + Tests []struct { + TcID int `json:"tcId"` + TestPassed bool `json:"testPassed"` + } `json:"tests"` + TgID int `json:"tgId"` + } `json:"testGroups"` +} + +func TestACVP(t *testing.T) { + t.Run("Keygen", testKeygen) + t.Run("Sign", testSign) + t.Run("Verify", testVerify) +} + +func testKeygen(t *testing.T) { + // https://github.com/usnistgov/ACVP-Server/tree/v1.1.0.38/gen-val/json-files/SLH-DSA-keyGen-FIPS205 + inputs := new(acvpKeyGenPrompt) + readVector(t, "testdata/keyGen_prompt.json.zip", inputs) + outputs := new(acvpKeyGenResult) + readVector(t, "testdata/keyGen_results.json.zip", outputs) + + for gi, group := range inputs.TestGroups { + t.Run(fmt.Sprintf("TgID_%v", group.TgID), func(t *testing.T) { + if strings.HasSuffix(group.ParameterSet, "s") { + SkipLongTest(t) + } + + for ti := range group.Tests { + test.CheckOk( + group.Tests[ti].TcID == outputs.TestGroups[gi].Tests[ti].TcID, + "mismatch of TcID", t, + ) + + t.Run(fmt.Sprintf("TcID_%v", group.Tests[ti].TcID), + func(t *testing.T) { + acvpKeygen(t, group.ParameterSet, &group.Tests[ti], + outputs.TestGroups[gi].Tests[ti].Sk, + outputs.TestGroups[gi].Tests[ti].Pk, + ) + }) + } + }) + } +} + +func testSign(t *testing.T) { + // https://github.com/usnistgov/ACVP-Server/tree/v1.1.0.38/gen-val/json-files/SLH-DSA-sigGen-FIPS205 + inputs := new(acvpSigGenPrompt) + readVector(t, "testdata/sigGen_prompt.json.zip", inputs) + outputs := new(acvpSigGenResult) + readVector(t, "testdata/sigGen_results.json.zip", outputs) + + for gi, group := range inputs.TestGroups { + test.CheckOk(group.TgID == outputs.TestGroups[gi].TgID, "mismatch of TgID", t) + + t.Run(fmt.Sprintf("TgID_%v", group.TgID), func(t *testing.T) { + if strings.HasSuffix(group.ParameterSet, "s") { + SkipLongTest(t) + } + + for ti := range group.Tests { + test.CheckOk( + group.Tests[ti].TcID == outputs.TestGroups[gi].Tests[ti].TcID, + "mismatch of TcID", t, + ) + + t.Run(fmt.Sprintf("TcID_%v", group.Tests[ti].TcID), + func(t *testing.T) { + acvpSign(t, &group.sigGenParams, &group.Tests[ti], + outputs.TestGroups[gi].Tests[ti].Signature) + }) + } + }) + } +} + +func testVerify(t *testing.T) { + // https://github.com/usnistgov/ACVP-Server/tree/v1.1.0.38/gen-val/json-files/SLH-DSA-sigVer-FIPS205 + inputs := new(acvpVerifyInput) + readVector(t, "testdata/verify_prompt.json.zip", inputs) + outputs := new(acvpVerifyResult) + readVector(t, "testdata/verify_results.json.zip", outputs) + + for gi, group := range inputs.TestGroups { + test.CheckOk(group.TgID == outputs.TestGroups[gi].TgID, "mismatch of TgID", t) + + t.Run(fmt.Sprintf("TgID_%v", group.TgID), func(t *testing.T) { + if strings.HasSuffix(group.ParameterSet, "s") { + SkipLongTest(t) + } + + for ti := range group.Tests { + test.CheckOk( + group.Tests[ti].TcID == outputs.TestGroups[gi].Tests[ti].TcID, + "mismatch of TcID", t, + ) + + t.Run(fmt.Sprintf("TcID_%v", group.Tests[ti].TcID), + func(t *testing.T) { + acvpVerify(t, &group.sigParams, &group.Tests[ti], + outputs.TestGroups[gi].Tests[ti].TestPassed, + ) + }) + } + }) + } +} + +func acvpKeygen( + t *testing.T, paramSet string, in *keyGenInput, wantSk, wantPk []byte, +) { + id, err := IDByName(paramSet) + test.CheckNoErr(t, err, "invalid ParameterSet") + + var buffer bytes.Buffer + _, _ = buffer.Write(in.SkSeed) + _, _ = buffer.Write(in.SkPrf) + _, _ = buffer.Write(in.PkSeed) + pk, sk, err := GenerateKey(&buffer, id) + test.CheckNoErr(t, err, "GenerateKey failed") + + skGot, err := sk.MarshalBinary() + test.CheckNoErr(t, err, "PrivateKey.MarshalBinary failed") + + if !bytes.Equal(skGot, wantSk) { + test.ReportError(t, skGot, wantSk) + } + + pkGot, err := pk.MarshalBinary() + test.CheckNoErr(t, err, "PublicKey.MarshalBinary failed") + + if !bytes.Equal(pkGot, wantPk) { + test.ReportError(t, pkGot, wantPk) + } +} + +func acvpSign(t *testing.T, p *sigGenParams, in *sigGenInput, wantSig []byte) { + id, err := IDByName(p.ParameterSet) + test.CheckNoErr(t, err, "invalid ParameterSet") + + sk := PrivateKey{ID: id} + err = sk.UnmarshalBinary(in.Sk) + test.CheckNoErr(t, err, "PrivateKey.UnmarshalBinary failed") + + var gotSig []byte + if p.SigInterface == "internal" { + SkipLongTest(t) + + addRand := sk.publicKey.seed + if !p.IsDeterministic { + addRand = in.AddRand + } + + gotSig, err = slhSignInternal(&sk, in.Msg, addRand) + test.CheckNoErr(t, err, "slhSignInternal failed") + + if !bytes.Equal(gotSig, wantSig) { + more := " ... (more bytes differ)" + got := hex.EncodeToString(gotSig[:10]) + more + want := hex.EncodeToString(wantSig[:10]) + more + test.ReportError(t, got, want) + } + + valid := slhVerifyInternal(&sk.publicKey, in.Msg, gotSig) + test.CheckOk(valid, "slhVerifyInternal failed", t) + } else if p.SigInterface == "external" { + var msg *Message + if p.PreHash == "pure" { + msg = NewMessage(in.Msg) + } else if p.PreHash == "preHash" { + ph := getPreHash(t, in.HashAlg) + _, err = ph.Write(in.Msg) + test.CheckNoErr(t, err, "PreHash Write failed") + + msg, err = ph.BuildMessage() + test.CheckNoErr(t, err, "PreHash GetMessage failed") + } + + if p.IsDeterministic { + gotSig, err = SignDeterministic(&sk, msg, in.Ctx) + test.CheckNoErr(t, err, "SignDeterministic failed") + } else { + gotSig, err = SignRandomized(&sk, bytes.NewReader(in.AddRand), msg, in.Ctx) + test.CheckNoErr(t, err, "SignRandomized failed") + } + + if !bytes.Equal(gotSig, wantSig) { + more := " ... (more bytes differ)" + got := hex.EncodeToString(gotSig[:10]) + more + want := hex.EncodeToString(wantSig[:10]) + more + test.ReportError(t, got, want) + } + + pk := sk.PublicKey() + valid := Verify(&pk, msg, gotSig, in.Ctx) + test.CheckOk(valid, "Verify failed", t) + } +} + +func acvpVerify(t *testing.T, p *sigParams, in *verifyInput, want bool) { + id, err := IDByName(p.ParameterSet) + test.CheckNoErr(t, err, "invalid ParameterSet") + + pk := PublicKey{ID: id} + err = pk.UnmarshalBinary(in.Pk) + test.CheckNoErr(t, err, "PublicKey.UnmarshalBinary failed") + + var got bool + if p.SigInterface == "internal" { + SkipLongTest(t) + got = slhVerifyInternal(&pk, in.Msg, in.Sig) + } else if p.SigInterface == "external" { + var msg *Message + if p.PreHash == "pure" { + msg = NewMessage(in.Msg) + } else if p.PreHash == "preHash" { + ph := getPreHash(t, in.HashAlg) + _, err = ph.Write(in.Msg) + test.CheckNoErr(t, err, "PreHash Write failed") + + msg, err = ph.BuildMessage() + test.CheckNoErr(t, err, "PreHash GetMessage failed") + } + + got = Verify(&pk, msg, in.Sig, in.Ctx) + } + + if got != want { + test.ReportError(t, got, want) + } +} + +type Hex []byte + +func (b *Hex) UnmarshalJSON(data []byte) (err error) { + var s string + err = json.Unmarshal(data, &s) + if err == nil { + *b, err = hex.DecodeString(s) + } + return +} + +func readVector(t *testing.T, fileName string, vector interface{}) { + zipFile, err := zip.OpenReader(fileName) + test.CheckNoErr(t, err, "error opening file") + defer zipFile.Close() + + jsonFile, err := zipFile.File[0].Open() + test.CheckNoErr(t, err, "error opening unzipping file") + defer jsonFile.Close() + + input, err := io.ReadAll(jsonFile) + test.CheckNoErr(t, err, "error reading bytes") + + err = json.Unmarshal(input, &vector) + test.CheckNoErr(t, err, "error unmarshalling JSON file") +} + +func getPreHash(t *testing.T, s string) *PreHash { + m := make(map[string]*PreHash) + m["SHA2-224"], _ = NewPreHashWithHash(crypto.SHA224) + m["SHA2-256"], _ = NewPreHashWithHash(crypto.SHA256) + m["SHA2-384"], _ = NewPreHashWithHash(crypto.SHA384) + m["SHA2-512"], _ = NewPreHashWithHash(crypto.SHA512) + m["SHA2-512/224"], _ = NewPreHashWithHash(crypto.SHA512_224) + m["SHA2-512/256"], _ = NewPreHashWithHash(crypto.SHA512_256) + m["SHA3-224"], _ = NewPreHashWithHash(crypto.SHA3_224) + m["SHA3-256"], _ = NewPreHashWithHash(crypto.SHA3_256) + m["SHA3-384"], _ = NewPreHashWithHash(crypto.SHA3_384) + m["SHA3-512"], _ = NewPreHashWithHash(crypto.SHA3_512) + m["SHAKE-128"], _ = NewPreHashWithXof(xof.SHAKE128) + m["SHAKE-256"], _ = NewPreHashWithXof(xof.SHAKE256) + + ph, ok := m[s] + test.CheckOk(ok, fmt.Sprintf("preHash algorithm not supported %v", s), t) + return ph +} diff --git a/sign/slhdsa/address.go b/sign/slhdsa/address.go new file mode 100644 index 00000000..8a0b2878 --- /dev/null +++ b/sign/slhdsa/address.go @@ -0,0 +1,92 @@ +package slhdsa + +import "encoding/binary" + +// See FIPS 205 -- Section 4.2 +// Functions and Addressing + +type addrType = uint32 + +const ( + addressWotsHash = addrType(iota) + addressWotsPk + addressTree + addressForsTree + addressForsRoots + addressWotsPrf + addressForsPrf +) + +const ( + addressSizeCompressed = 22 + addressSizeNonCompressed = 32 +) + +type address struct { + b []byte + o int +} + +func (p *params) addressSize() uint32 { + if p.isSHA2 { + return addressSizeCompressed + } else { + return addressSizeNonCompressed + } +} + +func (p *params) addressOffset() int { + if p.isSHA2 { + return 0 + } else { + return 10 + } +} + +func (p *params) NewAddress() (a address) { + var m [addressSizeNonCompressed]byte + a.b = m[:p.addressSize()] + a.o = p.addressOffset() + return +} + +func (a *address) fromBytes(p *params, c *cursor) { + a.b = c.Next(p.addressSize()) + a.o = p.addressOffset() +} + +func (a *address) Set(x address) { copy(a.b, x.b); a.o = x.o } +func (a *address) Clear() { clearSlice(&a.b); a.o = 0 } +func (a *address) SetKeyPairAddress(i uint32) { binary.BigEndian.PutUint32(a.b[a.o+10:], i) } +func (a *address) SetChainAddress(i uint32) { binary.BigEndian.PutUint32(a.b[a.o+14:], i) } +func (a *address) SetTreeHeight(i uint32) { binary.BigEndian.PutUint32(a.b[a.o+14:], i) } +func (a *address) SetHashAddress(i uint32) { binary.BigEndian.PutUint32(a.b[a.o+18:], i) } +func (a *address) SetTreeIndex(i uint32) { binary.BigEndian.PutUint32(a.b[a.o+18:], i) } +func (a *address) GetKeyPairAddress() uint32 { return binary.BigEndian.Uint32(a.b[a.o+10:]) } +func (a *address) SetLayerAddress(l addrType) { + if a.o == 0 { + a.b[0] = byte(l & 0xFF) + } else { + binary.BigEndian.PutUint32(a.b[0:], l) + } +} + +func (a *address) SetTreeAddress(t [3]uint32) { + if a.o == 0 { + binary.BigEndian.PutUint32(a.b[1:], t[1]) + binary.BigEndian.PutUint32(a.b[5:], t[0]) + } else { + binary.BigEndian.PutUint32(a.b[4:], t[2]) + binary.BigEndian.PutUint32(a.b[8:], t[1]) + binary.BigEndian.PutUint32(a.b[12:], t[0]) + } +} + +func (a *address) SetTypeAndClear(t uint32) { + if a.o == 0 { + a.b[9] = byte(t) + } else { + binary.BigEndian.PutUint32(a.b[16:], t) + } + clear(a.b[a.o+10:]) +} diff --git a/sign/slhdsa/all_test.go b/sign/slhdsa/all_test.go new file mode 100644 index 00000000..5e44dd9a --- /dev/null +++ b/sign/slhdsa/all_test.go @@ -0,0 +1,54 @@ +package slhdsa + +import ( + "crypto/rand" + "flag" + "testing" +) + +// RunLongTest indicates whether long tests should be run. +var RunLongTest = flag.Bool("long", false, "runs longer tests and benchmark") + +func SkipLongTest(t testing.TB) { + t.Helper() + if !*RunLongTest { + t.Skip("Skipped one long test, add -long flag to run longer tests") + } +} + +func InnerTest(t *testing.T, sigIDs []ID) { + SkipLongTest(t) + for _, id := range sigIDs { + param := id.params() + t.Run(id.String(), func(t *testing.T) { + t.Run("Wots", func(t *testing.T) { testWotsPlus(t, param) }) + t.Run("Xmss", func(t *testing.T) { testXmss(t, param) }) + t.Run("Ht", func(tt *testing.T) { testHyperTree(t, param) }) + t.Run("Fors", func(tt *testing.T) { testFors(t, param) }) + t.Run("Int", func(tt *testing.T) { testInternal(t, param) }) + }) + } +} + +func BenchInner(b *testing.B, sigIDs []ID) { + SkipLongTest(b) + for _, id := range sigIDs { + param := id.params() + b.Run(param.name, func(b *testing.B) { + b.Run("Wots", func(b *testing.B) { benchmarkWotsPlus(b, param) }) + b.Run("Xmss", func(b *testing.B) { benchmarkXmss(b, param) }) + b.Run("Ht", func(b *testing.B) { benchmarkHyperTree(b, param) }) + b.Run("Fors", func(b *testing.B) { benchmarkFors(b, param) }) + b.Run("Int", func(b *testing.B) { benchmarkInternal(b, param) }) + }) + } +} + +func mustRead(t testing.TB, size uint32) (out []byte) { + out = make([]byte, size) + _, err := rand.Read(out) + if err != nil { + t.Fatalf("rand reader error: %v", err) + } + return +} diff --git a/sign/slhdsa/fors.go b/sign/slhdsa/fors.go new file mode 100644 index 00000000..5d599dc2 --- /dev/null +++ b/sign/slhdsa/fors.go @@ -0,0 +1,173 @@ +package slhdsa + +// See FIPS 205 -- Section 8 +// Forest of Random Subsets (FORS) is a few-time signature scheme that is +// used to sign the digests of the actual messages. + +type ( + forsPublicKey []byte // n bytes + forsPrivateKey []byte // n bytes + forsSignature []forsPair // k*forsPairSize() bytes + forsPair struct { + sk forsPrivateKey // forsSkSize() bytes + auth [][]byte // a*n bytes + } // forsSkSize() + a*n bytes +) + +func (p *params) forsMsgSize() uint32 { return (p.k*p.a + 7) / 8 } +func (p *params) forsPkSize() uint32 { return p.n } +func (p *params) forsSkSize() uint32 { return p.n } +func (p *params) forsSigSize() uint32 { return p.k * p.forsPairSize() } +func (p *params) forsPairSize() uint32 { return p.forsSkSize() + p.a*p.n } + +func (fs *forsSignature) fromBytes(p *params, c *cursor) { + *fs = make([]forsPair, p.k) + for i := range *fs { + (*fs)[i].fromBytes(p, c) + } +} + +func (fp *forsPair) fromBytes(p *params, c *cursor) { + fp.sk = c.Next(p.forsSkSize()) + fp.auth = make([][]byte, p.a) + for i := range fp.auth { + fp.auth[i] = c.Next(p.n) + } +} + +// See FIPS 205 -- Section 8.1 -- Algorithm 14. +func (s *statePriv) forsSkGen(addr address, idx uint32) forsPrivateKey { + s.PRF.address.Set(addr) + s.PRF.address.SetTypeAndClear(addressForsPrf) + s.PRF.address.SetKeyPairAddress(addr.GetKeyPairAddress()) + s.PRF.address.SetTreeIndex(idx) + + return s.PRF.Final() +} + +// See FIPS 205 -- Section 8.2 -- Algorithm 15 -- Iterative version. +// +// This is a stack-based implementation that computes the tree leaves +// in order (from the left to the right). +// Its recursive version can be found at fors_test.go file. +func (s *statePriv) forsNodeIter( + stack stackNode, root []byte, i, z uint32, addr address, +) { + if !(z <= s.a && i < s.k<<(s.a-z)) { + panic(ErrTree) + } + + s.F.address.Set(addr) + s.F.address.SetTreeHeight(0) + + s.H.address.Set(addr) + + twoZ := uint32(1) << z + iTwoZ := i << z + for k := range twoZ { + li := iTwoZ + k + lz := uint32(0) + + sk := s.forsSkGen(addr, li) + s.F.address.SetTreeIndex(li) + s.F.SetMessage(sk) + node := s.F.Final() + + for !stack.isEmpty() && stack.top().z == lz { + left := stack.pop() + li = (li - 1) >> 1 + lz = lz + 1 + + s.H.address.SetTreeHeight(lz) + s.H.address.SetTreeIndex(li) + s.H.SetMsgs(left.node, node) + node = s.H.Final() + } + + stack.push(item{node, lz}) + } + + copy(root, stack.pop().node) +} + +// See FIPS 205 -- Section 8.3 -- Algorithm 16. +func (s *statePriv) forsSign(sig forsSignature, digest []byte, addr address) { + stack := s.NewStack(s.a - 1) + defer stack.Clear() + + in, bits, total := 0, uint32(0), uint32(0) + maskA := (uint32(1) << s.a) - 1 + + for i := range s.k { + for bits < s.a { + total = (total << 8) + uint32(digest[in]) + in++ + bits += 8 + } + + bits -= s.a + indicesI := (total >> bits) & maskA + treeIdx := (i << s.a) + indicesI + forsSk := s.forsSkGen(addr, treeIdx) + copy(sig[i].sk, forsSk) + + for j := range s.a { + shift := (indicesI >> j) ^ 1 + s.forsNodeIter(stack, sig[i].auth[j], (i<<(s.a-j))+shift, j, addr) + } + } +} + +// See FIPS 205 -- Section 8.4 -- Algorithm 17. +func (s *state) forsPkFromSig( + sig forsSignature, digest []byte, addr address, +) (pk forsPublicKey) { + pk = make([]byte, s.forsPkSize()) + + s.F.address.Set(addr) + s.F.address.SetTreeHeight(0) + + s.H.address.Set(addr) + + s.T.address.Set(addr) + s.T.address.SetTypeAndClear(addressForsRoots) + s.T.address.SetKeyPairAddress(addr.GetKeyPairAddress()) + s.T.Reset() + + in, bits, total := 0, uint32(0), uint32(0) + maskA := (uint32(1) << s.a) - 1 + + for i := range s.k { + for bits < s.a { + total = (total << 8) + uint32(digest[in]) + in++ + bits += 8 + } + + bits -= s.a + indicesI := (total >> bits) & maskA + treeIdx := (i << s.a) + indicesI + s.F.address.SetTreeIndex(treeIdx) + s.F.SetMessage(sig[i].sk) + node := s.F.Final() + + for j := range s.a { + if (indicesI>>j)&0x1 == 0 { + treeIdx = treeIdx >> 1 + s.H.SetMsgs(node, sig[i].auth[j]) + } else { + treeIdx = (treeIdx - 1) >> 1 + s.H.SetMsgs(sig[i].auth[j], node) + } + + s.H.address.SetTreeHeight(j + 1) + s.H.address.SetTreeIndex(treeIdx) + node = s.H.Final() + } + + s.T.WriteMessage(node) + } + + copy(pk, s.T.Final()) + return pk +} diff --git a/sign/slhdsa/fors_test.go b/sign/slhdsa/fors_test.go new file mode 100644 index 00000000..8e324ff2 --- /dev/null +++ b/sign/slhdsa/fors_test.go @@ -0,0 +1,116 @@ +package slhdsa + +import ( + "bytes" + "testing" + + "github.com/cloudflare/circl/internal/test" +) + +// See FIPS 205 -- Section 8.2 -- Algorithm 15 -- Recursive version. +func (s *statePriv) forsNodeRec(i, z uint32, addr address) (node []byte) { + if !(z <= s.a && i < s.k<<(s.a-z)) { + panic(ErrTree) + } + + node = make([]byte, s.n) + if z == 0 { + sk := s.forsSkGen(addr, i) + addr.SetTreeHeight(0) + addr.SetTreeIndex(i) + + s.F.address.Set(addr) + s.F.SetMessage(sk) + copy(node, s.F.Final()) + } else { + lnode := s.forsNodeRec(2*i, z-1, addr) + rnode := s.forsNodeRec(2*i+1, z-1, addr) + + s.H.address.Set(addr) + s.H.address.SetTreeHeight(z) + s.H.address.SetTreeIndex(i) + s.H.SetMsgs(lnode, rnode) + copy(node, s.H.Final()) + } + + return +} + +func testFors(t *testing.T, p *params) { + skSeed := mustRead(t, p.n) + pkSeed := mustRead(t, p.n) + msg := mustRead(t, p.forsMsgSize()) + + state := p.NewStatePriv(skSeed, pkSeed) + + idxTree := [3]uint32{0, 0, 0} + idxLeaf := uint32(0) + + addr := p.NewAddress() + addr.SetLayerAddress(p.d - 1) + + pkRoot := make([]byte, p.n) + state.xmssNodeIter(p.NewStack(p.hPrime), pkRoot, idxLeaf, p.hPrime, addr) + + n0 := state.forsNodeRec(idxLeaf, p.a, addr) + + n1 := make([]byte, p.n) + state.forsNodeIter(p.NewStack(p.a), n1, idxLeaf, p.a, addr) + + if !bytes.Equal(n0, n1) { + test.ReportError(t, n0, n1) + } + + var sig forsSignature + curSig := cursor(make([]byte, p.forsSigSize())) + sig.fromBytes(p, &curSig) + state.forsSign(sig, msg, addr) + pkFors := state.forsPkFromSig(sig, msg, addr) + + var htSig hyperTreeSignature + curHtSig := cursor(make([]byte, p.hyperTreeSigSize())) + htSig.fromBytes(p, &curHtSig) + state.htSign(htSig, pkFors, idxTree, idxLeaf) + + valid := state.htVerify(pkFors, pkRoot, idxTree, idxLeaf, htSig) + test.CheckOk(valid, "hypertree signature verification failed", t) +} + +func benchmarkFors(b *testing.B, p *params) { + skSeed := mustRead(b, p.n) + pkSeed := mustRead(b, p.n) + msg := mustRead(b, p.forsMsgSize()) + + state := p.NewStatePriv(skSeed, pkSeed) + + addr := p.NewAddress() + addr.SetLayerAddress(p.d - 1) + + var sig forsSignature + curSig := cursor(make([]byte, p.forsSigSize())) + sig.fromBytes(p, &curSig) + state.forsSign(sig, msg, addr) + + b.Run("NodeRec", func(b *testing.B) { + for range b.N { + _ = state.forsNodeRec(0, p.a, addr) + } + }) + b.Run("NodeIter", func(b *testing.B) { + node := make([]byte, p.n) + forsStack := p.NewStack(p.a) + for range b.N { + state.forsNodeIter(forsStack, node, 0, p.a, addr) + } + }) + b.Run("Sign", func(b *testing.B) { + for range b.N { + state.forsSign(sig, msg, addr) + } + }) + b.Run("PkFromSig", func(b *testing.B) { + for range b.N { + _ = state.forsPkFromSig(sig, msg, addr) + } + }) +} diff --git a/sign/slhdsa/hypertree.go b/sign/slhdsa/hypertree.go new file mode 100644 index 00000000..2345e9b3 --- /dev/null +++ b/sign/slhdsa/hypertree.go @@ -0,0 +1,68 @@ +package slhdsa + +import "bytes" + +// See FIPS 205 -- Section 7 +// SLH-DSA uses a hypertree to sign the FORS keys. + +type hyperTreeSignature []xmssSignature // d*xmssSigSize() bytes + +func (p *params) hyperTreeSigSize() uint32 { return p.d * p.xmssSigSize() } + +func (hts *hyperTreeSignature) fromBytes(p *params, c *cursor) { + *hts = make([]xmssSignature, p.d) + for i := range *hts { + (*hts)[i].fromBytes(p, c) + } +} + +func nextIndex(idxTree *[3]uint32, n uint32) (idxLeaf uint32) { + idxLeaf = idxTree[0] & ((1 << n) - 1) + idxTree[0] = (idxTree[0] >> n) | (idxTree[1] << (32 - n)) + idxTree[1] = (idxTree[1] >> n) | (idxTree[2] << (32 - n)) + idxTree[2] = (idxTree[2] >> n) + + return +} + +// See FIPS 205 -- Section 7.1 -- Algorithm 12. +func (s *statePriv) htSign( + sig hyperTreeSignature, msg []byte, idxTree [3]uint32, idxLeaf uint32, +) { + addr := s.NewAddress() + addr.SetTreeAddress(idxTree) + stack := s.NewStack(s.hPrime - 1) + defer stack.Clear() + + s.xmssSign(stack, sig[0], msg, idxLeaf, addr) + + root := make([]byte, s.xmssPkSize()) + copy(root, msg) + for j := uint32(1); j < s.d; j++ { + s.xmssPkFromSig(root, root, sig[j-1], idxLeaf, addr) + idxLeaf = nextIndex(&idxTree, s.hPrime) + addr.SetLayerAddress(j) + addr.SetTreeAddress(idxTree) + s.xmssSign(stack, sig[j], root, idxLeaf, addr) + } +} + +// See FIPS 205 -- Section 7.2 -- Algorithm 13. +func (s *state) htVerify( + msg, root []byte, idxTree [3]uint32, idxLeaf uint32, sig hyperTreeSignature, +) bool { + addr := s.NewAddress() + addr.SetTreeAddress(idxTree) + + node := make([]byte, s.xmssPkSize()) + s.xmssPkFromSig(node, msg, sig[0], idxLeaf, addr) + + for j := uint32(1); j < s.d; j++ { + idxLeaf = nextIndex(&idxTree, s.hPrime) + addr.SetLayerAddress(j) + addr.SetTreeAddress(idxTree) + s.xmssPkFromSig(node, node, sig[j], idxLeaf, addr) + } + + return bytes.Equal(node, root) +} diff --git a/sign/slhdsa/hypertree_test.go b/sign/slhdsa/hypertree_test.go new file mode 100644 index 00000000..02f93fff --- /dev/null +++ b/sign/slhdsa/hypertree_test.go @@ -0,0 +1,60 @@ +package slhdsa + +import ( + "testing" + + "github.com/cloudflare/circl/internal/test" +) + +func testHyperTree(t *testing.T, p *params) { + skSeed := mustRead(t, p.n) + pkSeed := mustRead(t, p.n) + msg := mustRead(t, p.n) + + state := p.NewStatePriv(skSeed, pkSeed) + + idxTree := [3]uint32{0, 0, 0} + idxLeaf := uint32(0) + + addr := p.NewAddress() + addr.SetLayerAddress(p.d - 1) + stack := p.NewStack(p.hPrime) + pkRoot := make([]byte, p.n) + state.xmssNodeIter(stack, pkRoot, idxLeaf, p.hPrime, addr) + + var sig hyperTreeSignature + curSig := cursor(make([]byte, p.hyperTreeSigSize())) + sig.fromBytes(p, &curSig) + state.htSign(sig, msg, idxTree, idxLeaf) + + valid := state.htVerify(msg, pkRoot, idxTree, idxLeaf, sig) + test.CheckOk(valid, "hypertree signature verification failed", t) +} + +func benchmarkHyperTree(b *testing.B, p *params) { + skSeed := mustRead(b, p.n) + pkSeed := mustRead(b, p.n) + pkRoot := mustRead(b, p.n) + msg := mustRead(b, p.n) + + state := p.NewStatePriv(skSeed, pkSeed) + + idxTree := [3]uint32{0, 0, 0} + idxLeaf := uint32(0) + + var sig hyperTreeSignature + curSig := cursor(make([]byte, p.hyperTreeSigSize())) + sig.fromBytes(p, &curSig) + state.htSign(sig, msg, idxTree, idxLeaf) + + b.Run("Sign", func(b *testing.B) { + for range b.N { + state.htSign(sig, msg, idxTree, idxLeaf) + } + }) + b.Run("Verify", func(b *testing.B) { + for range b.N { + _ = state.htVerify(msg, pkRoot, idxTree, idxLeaf, sig) + } + }) +} diff --git a/sign/slhdsa/internal.go b/sign/slhdsa/internal.go new file mode 100644 index 00000000..1972b9bf --- /dev/null +++ b/sign/slhdsa/internal.go @@ -0,0 +1,135 @@ +package slhdsa + +import "encoding/binary" + +// See FIPS 205 -- Section 9 +// SLH-DSA Internal Functions + +// See FIPS 205 -- Section 9.1 -- Algorithm 18. +func slhKeyGenInternal( + p *params, skSeed, skPrf, pkSeed []byte, +) (pub PublicKey, priv PrivateKey) { + s := p.NewStatePriv(skSeed, pkSeed) + defer s.Clear() + + stack := p.NewStack(p.hPrime) + defer stack.Clear() + + addr := p.NewAddress() + addr.SetLayerAddress(p.d - 1) + pkRoot := make([]byte, p.n) + s.xmssNodeIter(stack, pkRoot, 0, p.hPrime, addr) + + pub.ID = p.ID + pub.seed = pkSeed + pub.root = pkRoot + + priv.ID = p.ID + priv.prfKey = skPrf + priv.seed = skSeed + priv.publicKey = pub + + return +} + +func (p *params) parseMsg( + digest []byte, +) (md []byte, idxTree [3]uint32, idxLeaf uint32) { + l1 := (p.k*p.a + 7) / 8 + l2 := (p.h - p.h/p.d + 7) / 8 + l3 := (p.h + 8*p.d - 1) / (8 * p.d) + + c := cursor(digest) + md = c.Next(l1) + s2 := c.Next(l2) + s3 := c.Next(l3) + + var b2 [12]byte + copy(b2[12-len(s2):], s2) + n2 := p.h - p.h/p.d + idxTree[0] = binary.BigEndian.Uint32(b2[8:]) & ((1 << n2) - 1) + n2 -= 32 + idxTree[1] = binary.BigEndian.Uint32(b2[4:]) & ((1 << n2) - 1) + idxTree[2] = binary.BigEndian.Uint32(b2[0:]) + + var b3 [4]byte + copy(b3[4-len(s3):], s3) + mask32 := (uint32(1) << (p.h / p.d)) - 1 + idxLeaf = mask32 & binary.BigEndian.Uint32(b3[0:]) + + return +} + +// See FIPS 205 -- Section 9.2 -- Algorithm 19. +func slhSignInternal(sk *PrivateKey, message, addRand []byte) ([]byte, error) { + p := sk.ID.params() + sigBytes := make([]byte, p.SignatureSize()) + + var sig signature + curSig := cursor(sigBytes) + if !sig.fromBytes(p, &curSig) { + return nil, ErrSigParse + } + + p.PRFMsg(sig.rnd, sk.prfKey, addRand, message) + digest := make([]byte, p.m) + p.HashMsg(digest, sig.rnd, message, &sk.publicKey) + + md, idxTree, idxLeaf := p.parseMsg(digest) + addr := p.NewAddress() + addr.SetTreeAddress(idxTree) + addr.SetTypeAndClear(addressForsTree) + addr.SetKeyPairAddress(idxLeaf) + + s := p.NewStatePriv(sk.seed, sk.publicKey.seed) + defer s.Clear() + + s.forsSign(sig.forsSig, md, addr) + pkFors := s.forsPkFromSig(sig.forsSig, md, addr) + s.htSign(sig.htSig, pkFors, idxTree, idxLeaf) + + return sigBytes, nil +} + +// See FIPS 205 -- Section 9.3 -- Algorithm 20. +func slhVerifyInternal(pub *PublicKey, message, sigBytes []byte) bool { + p := pub.ID.params() + var sig signature + curSig := cursor(sigBytes) + if len(sigBytes) != p.SignatureSize() || !sig.fromBytes(p, &curSig) { + return false + } + + digest := make([]byte, p.m) + p.HashMsg(digest, sig.rnd, message, pub) + + md, idxTree, idxLeaf := p.parseMsg(digest) + addr := p.NewAddress() + addr.SetTreeAddress(idxTree) + addr.SetTypeAndClear(addressForsTree) + addr.SetKeyPairAddress(idxLeaf) + + s := p.NewStatePub(pub.seed) + defer s.Clear() + + pkFors := s.forsPkFromSig(sig.forsSig, md, addr) + return s.htVerify(pkFors, pub.root, idxTree, idxLeaf, sig.htSig) +} + +// signature represents a SLH-DSA signature. +type signature struct { + rnd []byte // n bytes + forsSig forsSignature // forsSigSize() bytes + htSig hyperTreeSignature // hyperTreeSigSize() bytes +} + +func (p *params) SignatureSize() int { + return int(p.n + p.forsSigSize() + p.hyperTreeSigSize()) +} + +func (s *signature) fromBytes(p *params, c *cursor) bool { + s.rnd = c.Next(p.n) + s.forsSig.fromBytes(p, c) + s.htSig.fromBytes(p, c) + return len(*c) == 0 +} diff --git a/sign/slhdsa/internal_test.go b/sign/slhdsa/internal_test.go new file mode 100644 index 00000000..1b666f42 --- /dev/null +++ b/sign/slhdsa/internal_test.go @@ -0,0 +1,50 @@ +package slhdsa + +import ( + "testing" + + "github.com/cloudflare/circl/internal/test" +) + +func testInternal(t *testing.T, p *params) { + skSeed := mustRead(t, p.n) + skPrf := mustRead(t, p.n) + pkSeed := mustRead(t, p.n) + msg := mustRead(t, p.m) + addRand := mustRead(t, p.n) + + pk, sk := slhKeyGenInternal(p, skSeed, skPrf, pkSeed) + sig, err := slhSignInternal(&sk, msg, addRand) + test.CheckNoErr(t, err, "slhSignInternal failed") + + valid := slhVerifyInternal(&pk, msg, sig) + test.CheckOk(valid, "slhVerifyInternal failed", t) +} + +func benchmarkInternal(b *testing.B, p *params) { + skSeed := mustRead(b, p.n) + skPrf := mustRead(b, p.n) + pkSeed := mustRead(b, p.n) + msg := mustRead(b, p.m) + addRand := mustRead(b, p.n) + + pk, sk := slhKeyGenInternal(p, skSeed, skPrf, pkSeed) + sig, err := slhSignInternal(&sk, msg, addRand) + test.CheckNoErr(b, err, "slhSignInternal failed") + + b.Run("Keygen", func(b *testing.B) { + for range b.N { + _, _ = slhKeyGenInternal(p, skSeed, skPrf, pkSeed) + } + }) + b.Run("Sign", func(b *testing.B) { + for range b.N { + _, _ = slhSignInternal(&sk, msg, addRand) + } + }) + b.Run("Verify", func(b *testing.B) { + for range b.N { + _ = slhVerifyInternal(&pk, msg, sig) + } + }) +} diff --git a/sign/slhdsa/keys.go b/sign/slhdsa/keys.go new file mode 100644 index 00000000..53e9ced8 --- /dev/null +++ b/sign/slhdsa/keys.go @@ -0,0 +1,137 @@ +package slhdsa + +import ( + "bytes" + "crypto" + "crypto/subtle" + + "github.com/cloudflare/circl/internal/conv" + "golang.org/x/crypto/cryptobyte" +) + +// [PrivateKey] stores a private key of the SLH-DSA scheme. +// It implements the [crypto.Signer] and [crypto.PrivateKey] interfaces. +// For serialization, it also implements [cryptobyte.MarshalingValue], +// [encoding.BinaryMarshaler], and [encoding.BinaryUnmarshaler]. +type PrivateKey struct { + seed, prfKey []byte + publicKey PublicKey + ID +} + +func (p *params) PrivateKeySize() int { return int(2*p.n) + p.PublicKeySize() } + +// Marshal serializes the key using a [cryptobyte.Builder]. +func (k PrivateKey) Marshal(b *cryptobyte.Builder) error { + b.AddBytes(k.seed) + b.AddBytes(k.prfKey) + b.AddValue(k.publicKey) + return nil +} + +// Unmarshal recovers a [PrivateKey] from a [cryptobyte.String]. +// Caller must specify the private key's [ID] in advance. +// Example: +// +// key := PrivateKey{ID: SHA2Small192} +// key.Unmarshal(str) // returns true +func (k *PrivateKey) Unmarshal(s *cryptobyte.String) bool { + params := k.ID.params() + b := make([]byte, params.PrivateKeySize()) + if !s.CopyBytes(b) { + return false + } + + c := cursor(b) + return k.fromBytes(params, &c) +} + +func (k *PrivateKey) fromBytes(p *params, c *cursor) bool { + k.ID = p.ID + k.seed = c.Next(p.n) + k.prfKey = c.Next(p.n) + return k.publicKey.fromBytes(p, c) && k.publicKey.ID == k.ID +} + +// UnmarshalBinary recovers a [PrivateKey] from a slice of bytes. +// Caller must specify the private key's [ID] in advance. +// Example: +// +// key := PrivateKey{ID: SHA2Small192} +// key.UnmarshalBinary(bytes) // returns nil +func (k *PrivateKey) UnmarshalBinary(b []byte) error { return conv.UnmarshalBinary(k, b) } +func (k PrivateKey) MarshalBinary() ([]byte, error) { return conv.MarshalBinary(k) } +func (k PrivateKey) Public() crypto.PublicKey { return k.PublicKey() } +func (k PrivateKey) PublicKey() (pub PublicKey) { + params := k.ID.params() + c := cursor(make([]byte, params.PublicKeySize())) + pub.fromBytes(params, &c) + copy(pub.seed, k.publicKey.seed) + copy(pub.root, k.publicKey.root) + return +} + +func (k PrivateKey) Equal(x crypto.PrivateKey) bool { + other, ok := x.(PrivateKey) + return ok && k.ID == other.ID && + subtle.ConstantTimeCompare(k.seed, other.seed) == 1 && + subtle.ConstantTimeCompare(k.prfKey, other.prfKey) == 1 && + k.publicKey.Equal(other.publicKey) +} + +// [PublicKey] stores a public key of the SLH-DSA scheme. +// It implements the [crypto.PublicKey] interface. +// For serialization, it also implements [cryptobyte.MarshalingValue], +// [encoding.BinaryMarshaler], and [encoding.BinaryUnmarshaler]. +type PublicKey struct { + seed, root []byte + ID +} + +func (p *params) PublicKeySize() int { return int(2 * p.n) } + +// Marshal serializes the key using a [cryptobyte.Builder]. +func (k PublicKey) Marshal(b *cryptobyte.Builder) error { + b.AddBytes(k.seed) + b.AddBytes(k.root) + return nil +} + +// Unmarshal recovers a [PublicKey] from a [cryptobyte.String]. +// Caller must specify the public key's [ID] in advance. +// Example: +// +// key := PublicKey{ID: SHA2Small192} +// key.Unmarshal(str) // returns true +func (k *PublicKey) Unmarshal(s *cryptobyte.String) bool { + params := k.ID.params() + b := make([]byte, params.PublicKeySize()) + if !s.CopyBytes(b) { + return false + } + + c := cursor(b) + return k.fromBytes(params, &c) +} + +func (k *PublicKey) fromBytes(p *params, c *cursor) bool { + k.ID = p.ID + k.seed = c.Next(p.n) + k.root = c.Next(p.n) + return len(*c) == 0 +} + +// UnmarshalBinary recovers a [PublicKey] from a slice of bytes. +// Caller must specify the public key's [ID] in advance. +// Example: +// +// key := PublicKey{ID: SHA2Small192} +// key.UnmarshalBinary(bytes) // returns nil +func (k *PublicKey) UnmarshalBinary(b []byte) error { return conv.UnmarshalBinary(k, b) } +func (k PublicKey) MarshalBinary() ([]byte, error) { return conv.MarshalBinary(k) } +func (k PublicKey) Equal(x crypto.PublicKey) bool { + other, ok := x.(PublicKey) + return ok && k.ID == other.ID && + bytes.Equal(k.seed, other.seed) && + bytes.Equal(k.root, other.root) +} diff --git a/sign/slhdsa/message.go b/sign/slhdsa/message.go new file mode 100644 index 00000000..4c819e76 --- /dev/null +++ b/sign/slhdsa/message.go @@ -0,0 +1,115 @@ +package slhdsa + +import ( + "crypto" + "hash" + "io" + + "github.com/cloudflare/circl/xof" + _ "golang.org/x/crypto/sha3" +) + +// [PreHash] is a helper for hashing a message before signing. +// It implements the [io.Writer] interface, so the message can be provided +// in chunks before calling the [SignDeterministic], [SignRandomized], or +// [Verify] functions. +// Pre-hash must not be used for generating pure signatures. +type PreHash struct { + writer interface { + io.Writer + Reset() + } + size int + oid byte +} + +// [NewPreHashWithHash] is used to prehash messages using either the SHA2 or +// SHA3 hash functions. +// Returns [ErrPreHash] if the function is not supported. +func NewPreHashWithHash(h crypto.Hash) (*PreHash, error) { + hash2oid := [...]byte{ + crypto.SHA256: 1, + crypto.SHA384: 2, + crypto.SHA512: 3, + crypto.SHA224: 4, + crypto.SHA512_224: 5, + crypto.SHA512_256: 6, + crypto.SHA3_224: 7, + crypto.SHA3_256: 8, + crypto.SHA3_384: 9, + crypto.SHA3_512: 10, + } + + oid := hash2oid[h] + if oid == 0 { + return nil, ErrPreHash + } + + return &PreHash{h.New(), h.Size(), oid}, nil +} + +// [NewPreHashWithXof] is used to prehash messages using either SHAKE-128 +// or SHAKE-256. +// Returns [ErrPreHash] if the function is not supported. +func NewPreHashWithXof(x xof.ID) (*PreHash, error) { + switch x { + case xof.SHAKE128: + return &PreHash{x.New(), 32, 11}, nil + case xof.SHAKE256: + return &PreHash{x.New(), 64, 12}, nil + default: + return nil, ErrPreHash + } +} + +func (ph *PreHash) Reset() { ph.writer.Reset() } +func (ph *PreHash) Write(b []byte) (int, error) { return ph.writer.Write(b) } + +// BuildMessage returns a [Message] for signing, and resets the writer. +func (ph *PreHash) BuildMessage() (*Message, error) { + // Source https://csrc.nist.gov/Projects/computer-security-objects-register/algorithm-registration + const oidLen = 11 + oid := [oidLen]byte{ + 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, ph.oid, + } + + msg := make([]byte, oidLen+ph.size) + copy(msg, oid[:]) + switch f := ph.writer.(type) { + case hash.Hash: + msg = f.Sum(msg[:oidLen]) + case xof.XOF: + _, err := f.Read(msg[oidLen:]) + if err != nil { + return nil, err + } + default: + return nil, ErrPreHash + } + + ph.Reset() + return &Message{msg, 1}, nil +} + +// [Message] wraps a message for signing. +type Message struct { + msg []byte + isPreHash byte +} + +// For pure signatures, use [NewMessage] to pass the message to be signed. +// For pre-hashed signatures, use [PreHash] to hash the message first, and +// then use [PreHash.BuildMessage] to get a [Message] to be signed. +func NewMessage(msg []byte) *Message { return &Message{msg, 0} } + +func (m *Message) getMsgPrime(context []byte) ([]byte, error) { + // See FIPS 205 -- Section 10.2 -- Algorithm 23 and Algorithm 25. + const MaxContextSize = 255 + if len(context) > MaxContextSize { + return nil, ErrContext + } + + return append(append( + []byte{m.isPreHash, byte(len(context))}, context...), m.msg..., + ), nil +} diff --git a/sign/slhdsa/params.go b/sign/slhdsa/params.go new file mode 100644 index 00000000..38e4172a --- /dev/null +++ b/sign/slhdsa/params.go @@ -0,0 +1,182 @@ +package slhdsa + +import ( + "crypto" + "crypto/hmac" + "crypto/sha256" + "crypto/sha512" + "encoding/binary" + "hash" + "io" + "strings" + + "github.com/cloudflare/circl/internal/sha3" +) + +// [ID] identifies the supported parameter sets of SLH-DSA. +// Note that the zero value is not a valid identifier. +type ID byte + +//nolint:stylecheck +const ( + SHA2_128s ID = iota + 1 // SLH-DSA-SHA2-128s + SHAKE_128s // SLH-DSA-SHAKE-128s + SHA2_128f // SLH-DSA-SHA2-128f + SHAKE_128f // SLH-DSA-SHAKE-128f + SHA2_192s // SLH-DSA-SHA2-192s + SHAKE_192s // SLH-DSA-SHAKE-192s + SHA2_192f // SLH-DSA-SHA2-192f + SHAKE_192f // SLH-DSA-SHAKE-192f + SHA2_256s // SLH-DSA-SHA2-256s + SHAKE_256s // SLH-DSA-SHAKE-256s + SHA2_256f // SLH-DSA-SHA2-256f + SHAKE_256f // SLH-DSA-SHAKE-256f + _MaxParams +) + +// [IDByName] returns the [ID] that corresponds to the given name, +// or an error if no parameter set was found. +// See [ID] documentation for the specific names of each parameter set. +// Names are case insensitive. +// +// Example: +// +// IDByName("SLH-DSA-SHAKE-256s") // returns (SHAKESmall256, nil) +func IDByName(name string) (ID, error) { + v := strings.ToLower(name) + for i := range supportedParams { + if strings.ToLower(supportedParams[i].name) == v { + return supportedParams[i].ID, nil + } + } + + return ID(0), ErrParam +} + +// IsValid returns true if the parameter set is supported. +func (id ID) IsValid() bool { return 0 < id && id < _MaxParams } + +func (id ID) String() string { + if !id.IsValid() { + return ErrParam.Error() + } + return supportedParams[id-1].name +} + +func (id ID) params() *params { + if !id.IsValid() { + panic(ErrParam) + } + return &supportedParams[id-1] +} + +// params contains all the relevant constants of a parameter set. +type params struct { + name string // Name of the parameter set. + n uint32 // Length of WOTS+ messages. + hPrime uint32 // XMSS Merkle tree height. + h uint32 // Total height of a hypertree. + d uint32 // Hypertree has d layers of XMSS trees. + a uint32 // FORS signs a-bit messages. + k uint32 // FORS generates k private keys. + m uint32 // Used by HashMSG function. + isSHA2 bool // True, if the hash function is SHA2, otherwise is SHAKE. + ID // Identifier of the parameter set. +} + +// Stores all the supported (read-only) parameter sets. +var supportedParams = [_MaxParams - 1]params{ + {ID: SHA2_128s, n: 16, h: 63, d: 7, hPrime: 9, a: 12, k: 14, m: 30, isSHA2: true, name: "SLH-DSA-SHA2-128s"}, + {ID: SHAKE_128s, n: 16, h: 63, d: 7, hPrime: 9, a: 12, k: 14, m: 30, isSHA2: false, name: "SLH-DSA-SHAKE-128s"}, + {ID: SHA2_128f, n: 16, h: 66, d: 22, hPrime: 3, a: 6, k: 33, m: 34, isSHA2: true, name: "SLH-DSA-SHA2-128f"}, + {ID: SHAKE_128f, n: 16, h: 66, d: 22, hPrime: 3, a: 6, k: 33, m: 34, isSHA2: false, name: "SLH-DSA-SHAKE-128f"}, + {ID: SHA2_192s, n: 24, h: 63, d: 7, hPrime: 9, a: 14, k: 17, m: 39, isSHA2: true, name: "SLH-DSA-SHA2-192s"}, + {ID: SHAKE_192s, n: 24, h: 63, d: 7, hPrime: 9, a: 14, k: 17, m: 39, isSHA2: false, name: "SLH-DSA-SHAKE-192s"}, + {ID: SHA2_192f, n: 24, h: 66, d: 22, hPrime: 3, a: 8, k: 33, m: 42, isSHA2: true, name: "SLH-DSA-SHA2-192f"}, + {ID: SHAKE_192f, n: 24, h: 66, d: 22, hPrime: 3, a: 8, k: 33, m: 42, isSHA2: false, name: "SLH-DSA-SHAKE-192f"}, + {ID: SHA2_256s, n: 32, h: 64, d: 8, hPrime: 8, a: 14, k: 22, m: 47, isSHA2: true, name: "SLH-DSA-SHA2-256s"}, + {ID: SHAKE_256s, n: 32, h: 64, d: 8, hPrime: 8, a: 14, k: 22, m: 47, isSHA2: false, name: "SLH-DSA-SHAKE-256s"}, + {ID: SHA2_256f, n: 32, h: 68, d: 17, hPrime: 4, a: 9, k: 35, m: 49, isSHA2: true, name: "SLH-DSA-SHA2-256f"}, + {ID: SHAKE_256f, n: 32, h: 68, d: 17, hPrime: 4, a: 9, k: 35, m: 49, isSHA2: false, name: "SLH-DSA-SHAKE-256f"}, +} + +// See FIPS-205, Section 11.1 and Section 11.2. +func (p *params) PRFMsg(out, skPrf, optRand, msg []byte) { + if p.isSHA2 { + var h crypto.Hash + if p.n == 16 { + h = crypto.SHA256 + } else { + h = crypto.SHA512 + } + + mac := hmac.New(h.New, skPrf) + concat(mac, optRand, msg) + mac.Sum(out[:0]) + } else { + state := sha3.NewShake256() + concat(&state, skPrf, optRand, msg) + _, _ = state.Read(out) + } +} + +// See FIPS-205, Section 11.1 and Section 11.2. +func (p *params) HashMsg(out, r, msg []byte, pk *PublicKey) { + if p.isSHA2 { + var hLen uint32 + var state hash.Hash + if p.n == 16 { + hLen = sha256.Size + state = sha256.New() + } else { + hLen = sha512.Size + state = sha512.New() + } + + mgfSeed := make([]byte, 2*p.n+hLen+4) + c := cursor(mgfSeed) + copy(c.Next(p.n), r) + copy(c.Next(p.n), pk.seed) + sumInter := c.Next(hLen) + + concat(state, r, pk.seed, pk.root, msg) + state.Sum(sumInter[:0]) + p.mgf1(out, mgfSeed, p.m) + } else { + state := sha3.NewShake256() + concat(&state, r, pk.seed, pk.root, msg) + _, _ = state.Read(out) + } +} + +// MGF1 described in Appendix B.2.1 of RFC 8017. +func (p *params) mgf1(out, mgfSeed []byte, maskLen uint32) { + var hLen uint32 + var hashFn func(out, in []byte) + if p.n == 16 { + hLen = sha256.Size + hashFn = sha256sum + } else { + hLen = sha512.Size + hashFn = sha512sum + } + + offset := uint32(0) + end := (maskLen + hLen - 1) / hLen + counterBytes := mgfSeed[len(mgfSeed)-4:] + + for counter := range end { + binary.BigEndian.PutUint32(counterBytes, counter) + hashFn(out[offset:], mgfSeed) + offset += hLen + } +} + +func concat(w io.Writer, list ...[]byte) { + for _, li := range list { + _, err := w.Write(li) + if err != nil { + panic(ErrWriting) + } + } +} diff --git a/sign/slhdsa/scheme.go b/sign/slhdsa/scheme.go new file mode 100644 index 00000000..e6210f12 --- /dev/null +++ b/sign/slhdsa/scheme.go @@ -0,0 +1,114 @@ +package slhdsa + +import ( + "crypto/rand" + + "github.com/cloudflare/circl/internal/sha3" + "github.com/cloudflare/circl/sign" +) + +func (id ID) Scheme() sign.Scheme { return scheme{id.params()} } + +type scheme struct{ *params } + +func (s scheme) Name() string { return s.name } +func (s scheme) SeedSize() int { return s.PrivateKeySize() } +func (s scheme) SupportsContext() bool { return true } + +// GenerateKey is similar to [GenerateKey] function, except it always reads +// random bytes from [rand.Reader]. +func (s scheme) GenerateKey() (sign.PublicKey, sign.PrivateKey, error) { + return GenerateKey(rand.Reader, s.ID) +} + +// Sign returns a randomized pure signature of the message with the context +// given. +// If options is nil, an empty context is used. +// It returns an empty slice if the signature generation fails. +// +// Panics if the key is not a [PrivateKey] or when the [ID] mismatches. +func (s scheme) Sign( + priv sign.PrivateKey, message []byte, options *sign.SignatureOpts, +) []byte { + k, ok := priv.(PrivateKey) + if !ok || s.ID != k.ID { + panic(sign.ErrTypeMismatch) + } + + var context []byte + if options != nil { + context = []byte(options.Context) + } + + sig, err := SignRandomized(&k, rand.Reader, NewMessage(message), context) + if err != nil { + return nil + } + + return sig +} + +// Verify returns true if the signature of the message with the specified +// context is valid. +// If options is nil, an empty context is used. +// +// Panics if the key is not a [PublicKey] or when the [ID] mismatches. +func (s scheme) Verify( + pub sign.PublicKey, message, signature []byte, options *sign.SignatureOpts, +) bool { + k, ok := pub.(PublicKey) + if !ok || s.ID != k.ID { + panic(sign.ErrTypeMismatch) + } + + var context []byte + if options != nil { + context = []byte(options.Context) + } + + return Verify(&k, NewMessage(message), signature, context) +} + +// DeriveKey deterministically generates a pair of keys from a seed. +// +// Panics if seed is not of length [sign.Scheme.SeedSize]. +func (s scheme) DeriveKey(seed []byte) (sign.PublicKey, sign.PrivateKey) { + if len(seed) != s.SeedSize() { + panic(sign.ErrSeedSize) + } + + n := s.n + buf := make([]byte, 3*n) + if s.isSHA2 { + s.mgf1(buf, seed, 3*n) + } else { + sha3.ShakeSum256(buf, seed) + } + + c := cursor(buf) + skSeed := c.Next(n) + skPrf := c.Next(n) + pkSeed := c.Next(n) + + return slhKeyGenInternal(s.params, skSeed, skPrf, pkSeed) +} + +func (s scheme) UnmarshalBinaryPublicKey(b []byte) (sign.PublicKey, error) { + k := PublicKey{ID: s.ID} + err := k.UnmarshalBinary(b) + if err != nil { + return nil, err + } + + return k, nil +} + +func (s scheme) UnmarshalBinaryPrivateKey(b []byte) (sign.PrivateKey, error) { + k := PrivateKey{ID: s.ID} + err := k.UnmarshalBinary(b) + if err != nil { + return nil, err + } + + return k, nil +} diff --git a/sign/slhdsa/slhdsa.go b/sign/slhdsa/slhdsa.go new file mode 100644 index 00000000..ab9a7b75 --- /dev/null +++ b/sign/slhdsa/slhdsa.go @@ -0,0 +1,131 @@ +// Package slhdsa provides Stateless Hash-based Digital Signature Algorithm. +// +// This package is compliant with [FIPS 205] and the [ID] represents +// the following parameter sets: +// +// Category 1 +// - Based on SHA2: [SHA2_128s] and [SHA2_128f]. +// - Based on SHAKE: [SHAKE_128s] and [SHAKE_128f]. +// +// Category 3 +// - Based on SHA2: [SHA2_192s] and [SHA2_192f] +// - Based on SHAKE: [SHAKE_192s] and [SHAKE_192f] +// +// Category 5 +// - Based on SHA2: [SHA2_256s] and [SHA2_256f]. +// - Based on SHAKE: [SHAKE_256s] and [SHAKE_256f]. +// +// [FIPS 205]: https://doi.org/10.6028/NIST.FIPS.205 +package slhdsa + +import ( + "crypto" + "crypto/rand" + "errors" + "io" +) + +// [GenerateKey] returns a pair of keys using the parameter set specified. +// It returns an error if it fails reading from the random source. +func GenerateKey( + random io.Reader, id ID, +) (pub PublicKey, priv PrivateKey, err error) { + // See FIPS 205 -- Section 10.1 -- Algorithm 21. + params := id.params() + + var skSeed, skPrf, pkSeed []byte + skSeed, err = readRandom(random, params.n) + if err != nil { + return + } + + skPrf, err = readRandom(random, params.n) + if err != nil { + return + } + + pkSeed, err = readRandom(random, params.n) + if err != nil { + return + } + + pub, priv = slhKeyGenInternal(params, skSeed, skPrf, pkSeed) + + return +} + +// [SignDeterministic] returns the signature of the message with the +// specified context. +func SignDeterministic( + priv *PrivateKey, message *Message, context []byte, +) (signature []byte, err error) { + return priv.doSign(message, context, priv.publicKey.seed) +} + +// [SignRandomized] returns a random signature of the message with the +// specified context. +// It returns an error if it fails reading from the random source. +func SignRandomized( + priv *PrivateKey, random io.Reader, message *Message, context []byte, +) (signature []byte, err error) { + params := priv.ID.params() + addRand, err := readRandom(random, params.n) + if err != nil { + return nil, err + } + + return priv.doSign(message, context, addRand) +} + +// [PrivateKey.Sign] returns a randomized signature of the message with an +// empty context. +// Any parameter passed in [crypto.SignerOpts] is discarded. +// It returns an error if it fails reading from the random source. +func (k PrivateKey) Sign( + random io.Reader, message []byte, _ crypto.SignerOpts, +) (signature []byte, err error) { + return SignRandomized(&k, random, NewMessage(message), nil) +} + +func (k *PrivateKey) doSign( + message *Message, context, addRand []byte, +) ([]byte, error) { + // See FIPS 205 -- Section 10.2 -- Algorithm 22 and Algorithm 23. + msgPrime, err := message.getMsgPrime(context) + if err != nil { + return nil, err + } + + return slhSignInternal(k, msgPrime, addRand) +} + +// [Verify] returns true if the signature of the message with the specified +// context is valid. +func Verify(key *PublicKey, message *Message, signature, context []byte) bool { + // See FIPS 205 -- Section 10.3 -- Algorithm 24. + msgPrime, err := message.getMsgPrime(context) + if err != nil { + return false + } + + return slhVerifyInternal(key, msgPrime, signature) +} + +func readRandom(random io.Reader, size uint32) (out []byte, err error) { + out = make([]byte, size) + if random == nil { + random = rand.Reader + } + _, err = random.Read(out) + return +} + +var ( + ErrContext = errors.New("sign/slhdsa: context is larger than 255 bytes") + ErrMsgLen = errors.New("sign/slhdsa: invalid message length") + ErrParam = errors.New("sign/slhdsa: invalid SLH-DSA parameter") + ErrPreHash = errors.New("sign/slhdsa: invalid prehash function") + ErrSigParse = errors.New("sign/slhdsa: failed to decode the signature") + ErrTree = errors.New("sign/slhdsa: invalid tree height or tree index") + ErrWriting = errors.New("sign/slhdsa: failed to write to a hash function") +) diff --git a/sign/slhdsa/slhdsa_test.go b/sign/slhdsa/slhdsa_test.go new file mode 100644 index 00000000..6238915d --- /dev/null +++ b/sign/slhdsa/slhdsa_test.go @@ -0,0 +1,148 @@ +package slhdsa_test + +import ( + "crypto" + "crypto/rand" + "io" + "testing" + + "github.com/cloudflare/circl/internal/sha3" + "github.com/cloudflare/circl/internal/test" + "github.com/cloudflare/circl/sign/slhdsa" + "github.com/cloudflare/circl/xof" +) + +var fastSign = [...]slhdsa.ID{ + slhdsa.SHA2_128f, slhdsa.SHAKE_128f, + slhdsa.SHA2_192f, slhdsa.SHAKE_192f, + slhdsa.SHA2_256f, slhdsa.SHAKE_256f, +} + +var smallSign = [...]slhdsa.ID{ + slhdsa.SHA2_128s, slhdsa.SHAKE_128s, + slhdsa.SHA2_192s, slhdsa.SHAKE_192s, + slhdsa.SHA2_256s, slhdsa.SHAKE_256s, +} + +func TestInnerFast(t *testing.T) { slhdsa.InnerTest(t, fastSign[:]) } +func TestInnerSmall(t *testing.T) { slhdsa.InnerTest(t, smallSign[:]) } +func TestSlhdsaFast(t *testing.T) { testSlhdsa(t, fastSign[:]) } +func TestSlhdsaSmall(t *testing.T) { + slhdsa.SkipLongTest(t) + testSlhdsa(t, smallSign[:]) +} + +func testSlhdsa(t *testing.T, sigIDs []slhdsa.ID) { + for _, id := range sigIDs { + t.Run(id.String(), func(t *testing.T) { + t.Run("Keys", func(t *testing.T) { testKeys(t, id) }) + t.Run("Sign", func(t *testing.T) { testSign(t, id) }) + }) + } +} + +func testKeys(t *testing.T, id slhdsa.ID) { + reader := sha3.NewShake128() + + reader.Reset() + pub0, priv0, err := slhdsa.GenerateKey(&reader, id) + test.CheckNoErr(t, err, "GenerateKey failed") + + reader.Reset() + pub1, priv1, err := slhdsa.GenerateKey(&reader, id) + test.CheckNoErr(t, err, "GenerateKey failed") + + test.CheckOk(pub0.Equal(pub1), "public key not equal", t) + test.CheckOk(priv0.Equal(priv1), "private key not equal", t) + + test.CheckMarshal(t, &priv0, &priv1) + test.CheckMarshal(t, &pub0, &pub1) + + scheme := id.Scheme() + seed := make([]byte, scheme.SeedSize()) + pub2, priv2 := scheme.DeriveKey(seed) + pub3, priv3 := scheme.DeriveKey(seed) + + test.CheckOk(priv2.Equal(priv3), "private key not equal", t) + test.CheckOk(pub2.Equal(pub3), "public key not equal", t) +} + +func testSign(t *testing.T, id slhdsa.ID) { + pub, priv, err := slhdsa.GenerateKey(rand.Reader, id) + test.CheckNoErr(t, err, "GenerateKey failed") + + msg := []byte("Alice and Bob") + sig, err := priv.Sign(rand.Reader, msg, nil) + test.CheckNoErr(t, err, "Sign randomized failed") + + valid := slhdsa.Verify(&pub, slhdsa.NewMessage(msg), sig, nil) + test.CheckOk(valid, "Verify failed", t) +} + +func BenchmarkInnerFast(b *testing.B) { slhdsa.BenchInner(b, fastSign[:]) } +func BenchmarkInnerSmall(b *testing.B) { slhdsa.BenchInner(b, smallSign[:]) } +func BenchmarkSlhdsaFast(b *testing.B) { benchmarkSlhdsa(b, fastSign[:]) } +func BenchmarkSlhdsaSmall(b *testing.B) { + slhdsa.SkipLongTest(b) + benchmarkSlhdsa(b, smallSign[:]) +} + +func BenchmarkPreHash(b *testing.B) { + b.Run("WithHash", func(b *testing.B) { + ph, err := slhdsa.NewPreHashWithHash(crypto.SHA512) + test.CheckNoErr(b, err, "NewPreHashWithHash failed") + benchmarkPreHash(b, ph) + }) + b.Run("WithXof", func(b *testing.B) { + ph, err := slhdsa.NewPreHashWithXof(xof.SHAKE256) + test.CheckNoErr(b, err, "NewPreHashWithXof failed") + benchmarkPreHash(b, ph) + }) +} + +func benchmarkPreHash(b *testing.B, ph *slhdsa.PreHash) { + s := sha3.NewShake128() + for range b.N { + _, err := io.Copy(ph, io.LimitReader(&s, 1024)) + test.CheckNoErr(b, err, "io.Copy failed") + + _, err = ph.BuildMessage() + test.CheckNoErr(b, err, "BuildMessage failed") + } +} + +func benchmarkSlhdsa(b *testing.B, sigIDs []slhdsa.ID) { + msg := slhdsa.NewMessage([]byte("Alice and Bob")) + ctx := []byte("this is a context string") + + for _, id := range sigIDs { + pub, priv, err := slhdsa.GenerateKey(rand.Reader, id) + test.CheckNoErr(b, err, "GenerateKey failed") + + sig, err := slhdsa.SignDeterministic(&priv, msg, ctx) + test.CheckNoErr(b, err, "SignDeterministic failed") + + b.Run(id.String(), func(b *testing.B) { + b.Run("GenerateKey", func(b *testing.B) { + for range b.N { + _, _, _ = slhdsa.GenerateKey(rand.Reader, id) + } + }) + b.Run("SignRandomized", func(b *testing.B) { + for range b.N { + _, _ = slhdsa.SignRandomized(&priv, rand.Reader, msg, ctx) + } + }) + b.Run("SignDeterministic", func(b *testing.B) { + for range b.N { + _, _ = slhdsa.SignDeterministic(&priv, msg, ctx) + } + }) + b.Run("Verify", func(b *testing.B) { + for range b.N { + _ = slhdsa.Verify(&pub, msg, sig, ctx) + } + }) + }) + } +} diff --git a/sign/slhdsa/state.go b/sign/slhdsa/state.go new file mode 100644 index 00000000..9613e6fc --- /dev/null +++ b/sign/slhdsa/state.go @@ -0,0 +1,357 @@ +package slhdsa + +import ( + "crypto/sha256" + "crypto/sha512" + "hash" + "io" + + "github.com/cloudflare/circl/internal/sha3" +) + +// statePriv encapsulates common data for performing a private operation. +type statePriv struct { + state + PRF statePRF +} + +func (s *statePriv) Size(p *params) uint32 { + return s.state.Size(p) + s.PRF.Size(p) +} + +func (p *params) NewStatePriv(skSeed, pkSeed []byte) (s statePriv) { + c := cursor(make([]byte, s.Size(p))) + s.state.init(p, &c, pkSeed) + s.PRF.Init(p, &c, skSeed, pkSeed) + + return +} + +func (s *statePriv) Clear() { + s.PRF.Clear() + s.state.Clear() +} + +// state encapsulates common data for performing a public operation. +type state struct { + *params + + F stateF + H stateH + T stateT +} + +func (s *state) Size(p *params) uint32 { + return s.F.Size(p) + s.H.Size(p) + s.T.Size(p) +} + +func (p *params) NewStatePub(pkSeed []byte) (s state) { + c := cursor(make([]byte, s.Size(p))) + s.init(p, &c, pkSeed) + + return +} + +func (s *state) init(p *params, c *cursor, pkSeed []byte) { + s.params = p + s.F.Init(p, c, pkSeed) + s.H.Init(p, c, pkSeed) + s.T.Init(p, c, pkSeed) +} + +func (s *state) Clear() { + s.F.Clear() + s.H.Clear() + s.T.Clear() + s.params = nil +} + +func sha256sum(out, in []byte) { s := sha256.Sum256(in); copy(out, s[:]) } +func sha512sum(out, in []byte) { s := sha512.Sum512(in); copy(out, s[:]) } + +type baseHasher struct { + hash func(out, in []byte) + input, output []byte + address +} + +func (b *baseHasher) Size(p *params) uint32 { + return p.n + p.addressSize() +} + +func (b *baseHasher) Clear() { + clearSlice(&b.input) + clearSlice(&b.output) + b.address.Clear() +} + +func (b *baseHasher) Final() []byte { + b.hash(b.output, b.input) + return b.output +} + +type statePRF struct{ baseHasher } + +func (s *statePRF) Init(p *params, cur *cursor, skSeed, pkSeed []byte) { + c := cursor(cur.Next(s.Size(p))) + s.output = c.Next(p.n) + s.input = c.Rest() + copy(c.Next(p.n), pkSeed) + _ = c.Next(s.padSize(p)) + s.address.fromBytes(p, &c) + copy(c.Next(p.n), skSeed) + + if p.isSHA2 { + s.hash = sha256sum + } else { + s.hash = sha3.ShakeSum256 + } +} + +func (s *statePRF) Size(p *params) uint32 { + return 2*p.n + s.padSize(p) + s.baseHasher.Size(p) +} + +func (s *statePRF) padSize(p *params) uint32 { + if p.isSHA2 { + return 64 - p.n + } else { + return 0 + } +} + +type stateF struct { + msg []byte + baseHasher +} + +func (s *stateF) Init(p *params, cur *cursor, pkSeed []byte) { + c := cursor(cur.Next(s.Size(p))) + s.output = c.Next(p.n) + s.input = c.Rest() + copy(c.Next(p.n), pkSeed) + _ = c.Next(s.padSize(p)) + s.address.fromBytes(p, &c) + s.msg = c.Next(p.n) + + if p.isSHA2 { + s.hash = sha256sum + } else { + s.hash = sha3.ShakeSum256 + } +} + +func (s *stateF) SetMessage(msg []byte) { copy(s.msg, msg) } + +func (s *stateF) Clear() { + s.baseHasher.Clear() + clearSlice(&s.msg) +} + +func (s *stateF) Size(p *params) uint32 { + return 2*p.n + s.padSize(p) + s.baseHasher.Size(p) +} + +func (s *stateF) padSize(p *params) uint32 { + if p.isSHA2 { + return 64 - p.n + } else { + return 0 + } +} + +type stateH struct { + msg0, msg1 []byte + baseHasher +} + +func (s *stateH) Init(p *params, cur *cursor, pkSeed []byte) { + c := cursor(cur.Next(s.Size(p))) + s.output = c.Next(p.n) + s.input = c.Rest() + copy(c.Next(p.n), pkSeed) + _ = c.Next(s.padSize(p)) + s.address.fromBytes(p, &c) + s.msg0 = c.Next(p.n) + s.msg1 = c.Next(p.n) + + if p.isSHA2 { + if p.n == 16 { + s.hash = sha256sum + } else { + s.hash = sha512sum + } + } else { + s.hash = sha3.ShakeSum256 + } +} + +func (s *stateH) SetMsgs(m0, m1 []byte) { + copy(s.msg0, m0) + copy(s.msg1, m1) +} + +func (s *stateH) Clear() { + s.baseHasher.Clear() + clearSlice(&s.msg0) + clearSlice(&s.msg1) +} + +func (s *stateH) Size(p *params) uint32 { + return 3*p.n + s.padSize(p) + s.baseHasher.Size(p) +} + +func (s *stateH) padSize(p *params) uint32 { + if p.isSHA2 { + if p.n == 16 { + return 64 - p.n + } else { + return 128 - p.n + } + } else { + return 0 + } +} + +type stateT struct { + hash interface { + io.Writer + Reset() + Final([]byte) + } + input, output []byte + address +} + +func (s *stateT) Init(p *params, cur *cursor, pkSeed []byte) { + c := cursor(cur.Next(s.Size(p))) + s.output = c.Next(s.outputSize(p))[:p.n] + s.input = c.Rest() + copy(c.Next(p.n), pkSeed) + _ = c.Next(s.padSize(p)) + s.address.fromBytes(p, &c) + + if p.isSHA2 { + if p.n == 16 { + s.hash = &sha2rw{sha256.New()} + } else { + s.hash = &sha2rw{sha512.New()} + } + } else { + s.hash = &sha3rw{sha3.NewShake256()} + } +} + +func (s *stateT) Clear() { + clearSlice(&s.input) + clearSlice(&s.output) + s.address.Clear() + s.hash.Reset() +} + +func (s *stateT) Reset() { + s.hash.Reset() + _, _ = s.hash.Write(s.input) +} + +func (s *stateT) WriteMessage(msg []byte) { _, _ = s.hash.Write(msg) } + +func (s *stateT) Final() []byte { + s.hash.Final(s.output) + return s.output +} + +func (s *stateT) Size(p *params) uint32 { + return s.outputSize(p) + s.padSize(p) + p.n + p.addressSize() +} + +func (s *stateT) outputSize(p *params) uint32 { + if p.isSHA2 { + if p.n == 16 { + return sha256.Size + } else { + return sha512.Size + } + } else { + return p.n + } +} + +func (s *stateT) padSize(p *params) uint32 { + if p.isSHA2 { + if p.n == 16 { + return 64 - p.n + } else { + return 128 - p.n + } + } else { + return 0 + } +} + +type sha2rw struct{ hash.Hash } + +func (s *sha2rw) Final(out []byte) { s.Sum(out[:0]) } +func (s *sha2rw) SumIdempotent(out []byte) { s.Sum(out[:0]) } + +type sha3rw struct{ sha3.State } + +func (s *sha3rw) Final(out []byte) { _, _ = s.Read(out) } +func (s *sha3rw) SumIdempotent(out []byte) { _, _ = s.Clone().Read(out) } + +type ( + item struct { + node []byte + z uint32 + } + stackNode []item +) + +func (p *params) NewStack(z uint32) stackNode { + s := make([]item, z) + c := cursor(make([]byte, z*p.n)) + for i := range s { + s[i].node = c.Next(p.n) + } + + return s[:0] +} + +func (s stackNode) isEmpty() bool { return len(s) == 0 } +func (s stackNode) top() item { return s[len(s)-1] } +func (s *stackNode) push(v item) { + next := len(*s) + *s = (*s)[:next+1] + (*s)[next].z = v.z + copy((*s)[next].node, v.node) +} + +func (s *stackNode) pop() (v item) { + last := len(*s) - 1 + if last >= 0 { + v = (*s)[last] + *s = (*s)[:last] + } + return +} + +func (s *stackNode) Clear() { + *s = (*s)[:cap(*s)] + for i := range *s { + clearSlice(&(*s)[i].node) + } + clear((*s)[:]) +} + +type cursor []byte + +func (c *cursor) Rest() []byte { return (*c)[:] } +func (c *cursor) Next(n uint32) (out []byte) { + if len(*c) >= int(n) { + out = (*c)[:n] + *c = (*c)[n:] + } + return +} + +func clearSlice(s *[]byte) { clear(*s); *s = nil } diff --git a/sign/slhdsa/testdata/keyGen_prompt.json.zip b/sign/slhdsa/testdata/keyGen_prompt.json.zip new file mode 100644 index 00000000..60090b00 Binary files /dev/null and b/sign/slhdsa/testdata/keyGen_prompt.json.zip differ diff --git a/sign/slhdsa/testdata/keyGen_results.json.zip b/sign/slhdsa/testdata/keyGen_results.json.zip new file mode 100644 index 00000000..396c15a3 Binary files /dev/null and b/sign/slhdsa/testdata/keyGen_results.json.zip differ diff --git a/sign/slhdsa/testdata/sigGen_prompt.json.zip b/sign/slhdsa/testdata/sigGen_prompt.json.zip new file mode 100644 index 00000000..24bf262b Binary files /dev/null and b/sign/slhdsa/testdata/sigGen_prompt.json.zip differ diff --git a/sign/slhdsa/testdata/sigGen_results.json.zip b/sign/slhdsa/testdata/sigGen_results.json.zip new file mode 100644 index 00000000..c76f9142 Binary files /dev/null and b/sign/slhdsa/testdata/sigGen_results.json.zip differ diff --git a/sign/slhdsa/testdata/verify_prompt.json.zip b/sign/slhdsa/testdata/verify_prompt.json.zip new file mode 100644 index 00000000..a4d33221 Binary files /dev/null and b/sign/slhdsa/testdata/verify_prompt.json.zip differ diff --git a/sign/slhdsa/testdata/verify_results.json.zip b/sign/slhdsa/testdata/verify_results.json.zip new file mode 100644 index 00000000..d7c2ea5b Binary files /dev/null and b/sign/slhdsa/testdata/verify_results.json.zip differ diff --git a/sign/slhdsa/wotsp.go b/sign/slhdsa/wotsp.go new file mode 100644 index 00000000..5438c5ca --- /dev/null +++ b/sign/slhdsa/wotsp.go @@ -0,0 +1,142 @@ +package slhdsa + +// See FIPS 205 -- Section 5 +// Winternitz One-Time Signature Plus Scheme + +const ( + wotsW uint32 = 16 // wotsW is w = 2^lg_w, where lg_w = 4. + wotsLen2 uint32 = 3 // wotsLen2 is len_2 fixed to 3. +) + +type ( + wotsPublicKey []byte // n bytes + wotsSignature []byte // wotsLen()*n bytes +) + +func (p *params) wotsSigSize() uint32 { return p.wotsLen() * p.n } +func (p *params) wotsLen() uint32 { return p.wotsLen1() + wotsLen2 } +func (p *params) wotsLen1() uint32 { return 2 * p.n } + +func (ws *wotsSignature) fromBytes(p *params, c *cursor) { + *ws = c.Next(p.wotsSigSize()) +} + +// See FIPS 205 -- Section 5 -- Algorithm 5. +func (s *state) chain( + x []byte, index, steps uint32, addr address, +) (out []byte) { + out = x + s.F.address.Set(addr) + for j := index; j < index+steps; j++ { + s.F.address.SetHashAddress(j) + s.F.SetMessage(out) + out = s.F.Final() + } + return +} + +// See FIPS 205 -- Section 5.1 -- Algorithm 6. +func (s *statePriv) wotsPkGen(addr address) wotsPublicKey { + s.PRF.address.Set(addr) + s.PRF.address.SetTypeAndClear(addressWotsPrf) + s.PRF.address.SetKeyPairAddress(addr.GetKeyPairAddress()) + + s.T.address.Set(addr) + s.T.address.SetTypeAndClear(addressWotsPk) + s.T.address.SetKeyPairAddress(addr.GetKeyPairAddress()) + + s.T.Reset() + wotsLen := s.wotsLen() + for i := range wotsLen { + s.PRF.address.SetChainAddress(i) + sk := s.PRF.Final() + + addr.SetChainAddress(i) + tmpi := s.chain(sk, 0, wotsW-1, addr) + + s.T.WriteMessage(tmpi) + } + + return s.T.Final() +} + +// See FIPS 205 -- Section 5.2 -- Algorithm 7. +func (s *statePriv) wotsSign(sig wotsSignature, msg []byte, addr address) { + if len(msg) != int(s.wotsLen1()/2) { + panic(ErrMsgLen) + } + + curSig := cursor(sig) + wotsLen1 := s.wotsLen1() + csum := wotsLen1 * (wotsW - 1) + + s.PRF.address.Set(addr) + s.PRF.address.SetTypeAndClear(addressWotsPrf) + s.PRF.address.SetKeyPairAddress(addr.GetKeyPairAddress()) + + // Signs every nibble of the message and computes the checksum. + for i := range wotsLen1 { + s.PRF.address.SetChainAddress(i) + sk := s.PRF.Final() + + addr.SetChainAddress(i) + msgi := uint32((msg[i/2] >> ((1 - (i & 1)) << 2)) & 0xF) + sigi := s.chain(sk, 0, msgi, addr) + copy(curSig.Next(s.n), sigi) + csum -= msgi + } + + // Lastly, every nibble of the checksum is also signed. + for i := range wotsLen2 { + s.PRF.address.SetChainAddress(wotsLen1 + i) + sk := s.PRF.Final() + + addr.SetChainAddress(wotsLen1 + i) + csumi := (csum >> (8 - 4*i)) & 0xF + sigi := s.chain(sk, 0, csumi, addr) + copy(curSig.Next(s.n), sigi) + } +} + +// See FIPS 205 -- Section 5.3 -- Algorithm 8. +func (s *state) wotsPkFromSig( + sig wotsSignature, msg []byte, addr address, +) wotsPublicKey { + if len(msg) != int(s.wotsLen1()/2) { + panic(ErrMsgLen) + } + + wotsLen1 := s.wotsLen1() + csum := wotsLen1 * (wotsW - 1) + + s.T.address.Set(addr) + s.T.address.SetTypeAndClear(addressWotsPk) + s.T.address.SetKeyPairAddress(addr.GetKeyPairAddress()) + + s.T.Reset() + curSig := cursor(sig) + + // Signs every nibble of the message, computes the checksum, and + // feeds each signature to the T function. + for i := range wotsLen1 { + addr.SetChainAddress(i) + msgi := uint32((msg[i/2] >> ((1 - (i & 1)) << 2)) & 0xF) + sigi := s.chain(curSig.Next(s.n), msgi, wotsW-1-msgi, addr) + + s.T.WriteMessage(sigi) + csum -= msgi + } + + // Every nibble of the checksum is also signed feeding the signature + // to the T function. + for i := range wotsLen2 { + addr.SetChainAddress(wotsLen1 + i) + csumi := (csum >> (8 - 4*i)) & 0xF + sigi := s.chain(curSig.Next(s.n), csumi, wotsW-1-csumi, addr) + + s.T.WriteMessage(sigi) + } + + // Generates the public key as the output of the T function. + return s.T.Final() +} diff --git a/sign/slhdsa/wotsp_test.go b/sign/slhdsa/wotsp_test.go new file mode 100644 index 00000000..9639a37d --- /dev/null +++ b/sign/slhdsa/wotsp_test.go @@ -0,0 +1,64 @@ +package slhdsa + +import ( + "bytes" + "testing" + + "github.com/cloudflare/circl/internal/test" +) + +func testWotsPlus(t *testing.T, p *params) { + skSeed := mustRead(t, p.n) + pkSeed := mustRead(t, p.n) + msg := mustRead(t, p.n) + + state := p.NewStatePriv(skSeed, pkSeed) + + addr := p.NewAddress() + addr.SetTypeAndClear(addressWotsHash) + + pk0 := state.wotsPkGen(addr) + + var sig wotsSignature + curSig := cursor(make([]byte, p.wotsSigSize())) + sig.fromBytes(p, &curSig) + state.wotsSign(sig, msg, addr) + + pk1 := state.wotsPkFromSig(sig, msg, addr) + + if !bytes.Equal(pk0, pk1) { + test.ReportError(t, pk0, pk1, skSeed, pkSeed, msg) + } +} + +func benchmarkWotsPlus(b *testing.B, p *params) { + skSeed := mustRead(b, p.n) + pkSeed := mustRead(b, p.n) + msg := mustRead(b, p.n) + + state := p.NewStatePriv(skSeed, pkSeed) + + addr := p.NewAddress() + addr.SetTypeAndClear(addressWotsHash) + + var sig wotsSignature + curSig := cursor(make([]byte, p.wotsSigSize())) + sig.fromBytes(p, &curSig) + state.wotsSign(sig, msg, addr) + + b.Run("PkGen", func(b *testing.B) { + for range b.N { + _ = state.wotsPkGen(addr) + } + }) + b.Run("Sign", func(b *testing.B) { + for range b.N { + state.wotsSign(sig, msg, addr) + } + }) + b.Run("PkFromSig", func(b *testing.B) { + for range b.N { + _ = state.wotsPkFromSig(sig, msg, addr) + } + }) +} diff --git a/sign/slhdsa/xmss.go b/sign/slhdsa/xmss.go new file mode 100644 index 00000000..ef896f4b --- /dev/null +++ b/sign/slhdsa/xmss.go @@ -0,0 +1,111 @@ +package slhdsa + +// See FIPS 205 -- Section 6 +// eXtended Merkle Signature Scheme (XMSS) extends the WOTS+ signature +// scheme into one that can sign multiple messages. + +type ( + xmssPublicKey []byte // n bytes + xmssSignature struct { + wotsSig wotsSignature // wotsSigSize() bytes + authPath []byte // hPrime*n bytes + } // wotsSigSize() + hPrime*n bytes +) + +func (p *params) xmssPkSize() uint32 { return p.n } +func (p *params) xmssAuthPathSize() uint32 { return p.hPrime * p.n } +func (p *params) xmssSigSize() uint32 { + return p.wotsSigSize() + p.xmssAuthPathSize() +} + +func (xs *xmssSignature) fromBytes(p *params, c *cursor) { + xs.wotsSig.fromBytes(p, c) + xs.authPath = c.Next(p.xmssAuthPathSize()) +} + +// See FIPS 205 -- Section 6.1 -- Algorithm 9 -- Iterative version. +// +// This is a stack-based implementation that computes the tree leaves +// in order (from the left to the right). +// Its recursive version can be found at xmss_test.go file. +func (s *statePriv) xmssNodeIter( + stack stackNode, root []byte, i, z uint32, addr address, +) { + if !(z <= s.hPrime && i < (1<<(s.hPrime-z))) { + panic(ErrTree) + } + + s.H.address.Set(addr) + s.H.address.SetTypeAndClear(addressTree) + + twoZ := uint32(1) << z + iTwoZ := i << z + for k := range twoZ { + li := iTwoZ + k + lz := uint32(0) + + addr.SetTypeAndClear(addressWotsHash) + addr.SetKeyPairAddress(li) + node := s.wotsPkGen(addr) + + for !stack.isEmpty() && stack.top().z == lz { + left := stack.pop() + li = (li - 1) >> 1 + lz = lz + 1 + + s.H.address.SetTreeHeight(lz) + s.H.address.SetTreeIndex(li) + s.H.SetMsgs(left.node, node) + node = s.H.Final() + } + + stack.push(item{node, lz}) + } + + copy(root, stack.pop().node) +} + +// See FIPS 205 -- Section 6.2 -- Algorithm 10. +func (s *statePriv) xmssSign( + stack stackNode, sig xmssSignature, msg []byte, idx uint32, addr address, +) { + authPath := cursor(sig.authPath) + for j := range s.hPrime { + k := (idx >> j) ^ 1 + s.xmssNodeIter(stack, authPath.Next(s.n), k, j, addr) + } + + addr.SetTypeAndClear(addressWotsHash) + addr.SetKeyPairAddress(idx) + s.wotsSign(sig.wotsSig, msg, addr) +} + +// See FIPS 205 -- Section 6.3 -- Algorithm 11. +func (s *state) xmssPkFromSig( + out xmssPublicKey, msg []byte, sig xmssSignature, idx uint32, addr address, +) { + addr.SetTypeAndClear(addressWotsHash) + addr.SetKeyPairAddress(idx) + pk := xmssPublicKey(s.wotsPkFromSig(sig.wotsSig, msg, addr)) + + treeIdx := idx + s.H.address.Set(addr) + s.H.address.SetTypeAndClear(addressTree) + + authPath := cursor(sig.authPath) + for k := range s.hPrime { + if (idx>>k)&0x1 == 0 { + treeIdx = treeIdx >> 1 + s.H.SetMsgs(pk, authPath.Next(s.n)) + } else { + treeIdx = (treeIdx - 1) >> 1 + s.H.SetMsgs(authPath.Next(s.n), pk) + } + + s.H.address.SetTreeHeight(k + 1) + s.H.address.SetTreeIndex(treeIdx) + pk = s.H.Final() + } + + copy(out, pk) +} diff --git a/sign/slhdsa/xmss_test.go b/sign/slhdsa/xmss_test.go new file mode 100644 index 00000000..7123ea22 --- /dev/null +++ b/sign/slhdsa/xmss_test.go @@ -0,0 +1,116 @@ +package slhdsa + +import ( + "bytes" + "fmt" + "testing" + + "github.com/cloudflare/circl/internal/test" +) + +// See FIPS 205 -- Section 6.1 -- Algorithm 9 -- Recursive version. +func (s *statePriv) xmssNodeRec(i, z uint32, addr address) (node []byte) { + if !(z <= s.hPrime && i < (1<<(s.hPrime-z))) { + panic(ErrTree) + } + + node = make([]byte, s.n) + if z == 0 { + addr.SetTypeAndClear(addressWotsHash) + addr.SetKeyPairAddress(i) + copy(node, s.wotsPkGen(addr)) + } else { + lnode := s.xmssNodeRec(2*i, z-1, addr) + rnode := s.xmssNodeRec(2*i+1, z-1, addr) + + s.H.address.Set(addr) + s.H.address.SetTypeAndClear(addressTree) + s.H.address.SetTreeHeight(z) + s.H.address.SetTreeIndex(i) + s.H.SetMsgs(lnode, rnode) + copy(node, s.H.Final()) + } + + return +} + +func testXmss(t *testing.T, p *params) { + skSeed := mustRead(t, p.n) + pkSeed := mustRead(t, p.n) + msg := mustRead(t, p.n) + + state := p.NewStatePriv(skSeed, pkSeed) + + addr := p.NewAddress() + addr.SetTypeAndClear(addressWotsHash) + idx := uint32(0) + + rootRec := state.xmssNodeRec(idx, p.hPrime, addr) + test.CheckOk( + len(rootRec) == int(p.n), + fmt.Sprintf("bad xmss rootRec length: %v", len(rootRec)), + t, + ) + + stack := p.NewStack(p.hPrime) + rootIter := make([]byte, p.n) + state.xmssNodeIter(stack, rootIter, idx, p.hPrime, addr) + + if !bytes.Equal(rootRec, rootIter) { + test.ReportError(t, rootRec, rootIter, skSeed, pkSeed, msg) + } + + var sig xmssSignature + curSig := cursor(make([]byte, p.xmssSigSize())) + sig.fromBytes(p, &curSig) + state.xmssSign(stack, sig, msg, idx, addr) + + node := make([]byte, p.xmssPkSize()) + state.xmssPkFromSig(node, msg, sig, idx, addr) + + if !bytes.Equal(rootRec, node) { + test.ReportError(t, rootRec, node, skSeed, pkSeed, msg) + } +} + +func benchmarkXmss(b *testing.B, p *params) { + skSeed := mustRead(b, p.n) + pkSeed := mustRead(b, p.n) + msg := mustRead(b, p.n) + + state := p.NewStatePriv(skSeed, pkSeed) + + addr := p.NewAddress() + addr.SetTypeAndClear(addressWotsHash) + idx := uint32(0) + + var sig xmssSignature + curSig := cursor(make([]byte, p.xmssSigSize())) + sig.fromBytes(p, &curSig) + state.xmssSign(state.NewStack(p.hPrime), sig, msg, idx, addr) + + b.Run("NodeRec", func(b *testing.B) { + for range b.N { + _ = state.xmssNodeRec(idx, p.hPrime, addr) + } + }) + b.Run("NodeIter", func(b *testing.B) { + node := make([]byte, p.n) + stack := state.NewStack(p.hPrime) + for range b.N { + state.xmssNodeIter(stack, node, idx, p.hPrime, addr) + } + }) + b.Run("Sign", func(b *testing.B) { + stack := state.NewStack(p.hPrime) + for range b.N { + state.xmssSign(stack, sig, msg, idx, addr) + } + }) + b.Run("PkFromSig", func(b *testing.B) { + node := make([]byte, p.xmssPkSize()) + for range b.N { + state.xmssPkFromSig(node, msg, sig, idx, addr) + } + }) +}