diff --git a/README.md b/README.md index a4fd64f95..57406615e 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ Alternatively, look at the [Cloudflare Go](https://github.com/cloudflare/go/tree [RFC-9496]: https://doi.org/10.17487/RFC9496 [RFC-9497]: https://doi.org/10.17487/RFC9497 [FIPS 202]: https://doi.org/10.6028/NIST.FIPS.202 +[FIPS 205]: https://doi.org/10.6028/NIST.FIPS.205 [FIPS 186-5]: https://doi.org/10.6028/NIST.FIPS.186-5 [BLS12-381]: https://electriccoin.co/blog/new-snark-curve/ [ia.cr/2015/267]: https://ia.cr/2015/267 @@ -91,6 +92,7 @@ Alternatively, look at the [Cloudflare Go](https://github.com/cloudflare/go/tree |:---:| - [Dilithium](./sign/dilithium): modes 2, 3, 5 ([Dilithium](https://pq-crystals.org/dilithium/)). + - [SLH-DSA](./sign/slhdsa): twelve parameter sets, pure and pre-hash signing ([FIPS 205]). ### Zero-knowledge Proofs diff --git a/internal/conv/conv.go b/internal/conv/conv.go index 649a8e931..7c401b166 100644 --- a/internal/conv/conv.go +++ b/internal/conv/conv.go @@ -5,6 +5,8 @@ import ( "fmt" "math/big" "strings" + + "golang.org/x/crypto/cryptobyte" ) // BytesLe2Hex returns an hexadecimal string of a number stored in a @@ -138,3 +140,26 @@ func BigInt2Uint64Le(z []uint64, x *big.Int) { z[i] = 0 } } + +// MarshalBinary encodes a value into a byte array in a format readable by UnmarshalBinary. +func MarshalBinary(v cryptobyte.MarshalingValue) ([]byte, error) { + var b cryptobyte.Builder + b.AddValue(v) + return b.Bytes() +} + +// A UnmarshalingValue decodes itself from a cryptobyte.String and advances the pointer. +// It reports whether the read was successful. +type UnmarshalingValue interface { + Unmarshal(*cryptobyte.String) bool +} + +// UnmarshalBinary recovers a value from a byte array. +// It returns an error if the read was unsuccessful. +func UnmarshalBinary(v UnmarshalingValue, data []byte) (err error) { + s := cryptobyte.String(data) + if !v.Unmarshal(&s) || !s.Empty() { + err = fmt.Errorf("cannot read %T from input string", v) + } + return +} diff --git a/internal/test/test.go b/internal/test/test.go index 576211a9f..72d37fdf8 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -1,6 +1,8 @@ package test import ( + "bytes" + "encoding" "errors" "fmt" "strings" @@ -58,3 +60,26 @@ func CheckPanic(f func()) error { f() return hasPanicked } + +func CheckMarshal( + t *testing.T, + x, y interface { + encoding.BinaryMarshaler + encoding.BinaryUnmarshaler + }, +) { + t.Helper() + + want, err := x.MarshalBinary() + CheckNoErr(t, err, fmt.Sprintf("cannot marshal %T = %v", x, x)) + + err = y.UnmarshalBinary(want) + CheckNoErr(t, err, fmt.Sprintf("cannot unmarshal %T from %x", y, want)) + + got, err := y.MarshalBinary() + CheckNoErr(t, err, fmt.Sprintf("cannot marshal %T = %v", y, y)) + + if !bytes.Equal(got, want) { + ReportError(t, got, want, x, y) + } +} diff --git a/sign/schemes/schemes.go b/sign/schemes/schemes.go index 66ee61253..7b38273a7 100644 --- a/sign/schemes/schemes.go +++ b/sign/schemes/schemes.go @@ -6,6 +6,7 @@ // Ed448 // Ed25519-Dilithium2 // Ed448-Dilithium3 +// SLH-DSA package schemes import ( @@ -16,6 +17,7 @@ import ( "github.com/cloudflare/circl/sign/ed448" "github.com/cloudflare/circl/sign/eddilithium2" "github.com/cloudflare/circl/sign/eddilithium3" + "github.com/cloudflare/circl/sign/slhdsa" ) var allSchemes = [...]sign.Scheme{ @@ -23,6 +25,18 @@ var allSchemes = [...]sign.Scheme{ ed448.Scheme(), eddilithium2.Scheme(), eddilithium3.Scheme(), + slhdsa.ParamIDSHA2Small128, + slhdsa.ParamIDSHAKESmall128, + slhdsa.ParamIDSHA2Fast128, + slhdsa.ParamIDSHAKEFast128, + slhdsa.ParamIDSHA2Small192, + slhdsa.ParamIDSHAKESmall192, + slhdsa.ParamIDSHA2Fast192, + slhdsa.ParamIDSHAKEFast192, + slhdsa.ParamIDSHA2Small256, + slhdsa.ParamIDSHAKESmall256, + slhdsa.ParamIDSHA2Fast256, + slhdsa.ParamIDSHAKEFast256, } var allSchemeNames map[string]sign.Scheme diff --git a/sign/schemes/schemes_test.go b/sign/schemes/schemes_test.go index 2d8e86512..417b074c0 100644 --- a/sign/schemes/schemes_test.go +++ b/sign/schemes/schemes_test.go @@ -117,6 +117,18 @@ func Example() { // Ed448 // Ed25519-Dilithium2 // Ed448-Dilithium3 + // SLH-DSA-SHA2-128s + // SLH-DSA-SHAKE-128s + // SLH-DSA-SHA2-128f + // SLH-DSA-SHAKE-128f + // SLH-DSA-SHA2-192s + // SLH-DSA-SHAKE-192s + // SLH-DSA-SHA2-192f + // SLH-DSA-SHAKE-192f + // SLH-DSA-SHA2-256s + // SLH-DSA-SHAKE-256s + // SLH-DSA-SHA2-256f + // SLH-DSA-SHAKE-256f } func BenchmarkGenerateKeyPair(b *testing.B) { diff --git a/sign/slhdsa/acvp_test.go b/sign/slhdsa/acvp_test.go new file mode 100644 index 000000000..9c19c834b --- /dev/null +++ b/sign/slhdsa/acvp_test.go @@ -0,0 +1,300 @@ +package slhdsa + +import ( + "archive/zip" + "bytes" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "testing" + + "github.com/cloudflare/circl/internal/test" +) + +type acvpHeader struct { + VsID int `json:"vsId"` + Algorithm string `json:"algorithm"` + Mode string `json:"mode"` + Revision string `json:"revision"` + IsSample bool `json:"isSample"` +} + +type acvpKeygenVector struct { + acvpHeader + TestGroups []struct { + TgID int `json:"tgId"` + TestType string `json:"testType"` + ParameterSet string `json:"parameterSet"` + Tests []keygenInput `json:"tests"` + } `json:"testGroups"` +} + +type keygenInput struct { + TcID int `json:"tcId"` + Deferred bool `json:"deferred"` + SkSeed hexBytes `json:"skSeed"` + SkPrf hexBytes `json:"skPrf"` + PkSeed hexBytes `json:"pkSeed"` + Sk hexBytes `json:"sk"` + Pk hexBytes `json:"pk"` +} + +type acvpSigGenPrompt struct { + acvpHeader + TestGroups []struct { + TgID int `json:"tgId"` + TestType string `json:"testType"` + ParameterSet string `json:"parameterSet"` + Deterministic bool `json:"deterministic"` + Tests []signInput `json:"tests"` + } `json:"testGroups"` +} + +type signInput struct { + TcID int `json:"tcId"` + Sk hexBytes `json:"sk"` + MsgLen int `json:"messageLength"` + Msg hexBytes `json:"message"` + AddRand hexBytes `json:"additionalRandomness,omitempty"` +} + +type acvpSigGenResult struct { + acvpHeader + TestGroups []struct { + TgID int `json:"tgId"` + Tests []struct { + TcID int `json:"tcId"` + Signature hexBytes `json:"signature"` + } `json:"tests"` + } `json:"testGroups"` +} + +type acvpVerifyInput struct { + acvpHeader + TestGroups []struct { + TgID int `json:"tgId"` + TestType string `json:"testType"` + ParameterSet string `json:"parameterSet"` + Tests []verifyInput `json:"tests"` + } `json:"testGroups"` +} + +type verifyInput struct { + TcID int `json:"tcId"` + Pk hexBytes `json:"pk"` + MessageLength int `json:"messageLength"` + Message hexBytes `json:"message"` + Signature hexBytes `json:"signature"` + Reason string `json:"reason"` +} + +type acvpVerifyResult struct { + acvpHeader + TestGroups []struct { + TgID int `json:"tgId"` + Tests []struct { + TcID int `json:"tcId"` + TestPassed bool `json:"testPassed"` + } `json:"tests"` + } `json:"testGroups"` +} + +func TestACVP(t *testing.T) { + t.Run("Keygen", testKeygen) + t.Run("Sign", testSign) + t.Run("Verify", testVerify) +} + +func testKeygen(t *testing.T) { + // https://github.com/usnistgov/ACVP-Server/tree/v1.1.0.35/gen-val/json-files/SLH-DSA-keyGen-FIPS205 + inputs := new(acvpKeygenVector) + readVector(t, "testdata/keygen.json.zip", inputs) + + for _, group := range inputs.TestGroups { + t.Run(fmt.Sprintf("TgID_%v", group.TgID), func(t *testing.T) { + for ti := range group.Tests { + t.Run(fmt.Sprintf("TcID_%v", group.Tests[ti].TcID), func(t *testing.T) { + acvpKeygen(t, group.ParameterSet, &group.Tests[ti]) + }) + } + }) + } +} + +func testSign(t *testing.T) { + // https://github.com/usnistgov/ACVP-Server/tree/v1.1.0.35/gen-val/json-files/SLH-DSA-sigGen-FIPS205 + inputs := new(acvpSigGenPrompt) + readVector(t, "testdata/sigGen_prompt.json.zip", inputs) + outputs := new(acvpSigGenResult) + readVector(t, "testdata/sigGen_results.json.zip", outputs) + + for gi, group := range inputs.TestGroups { + test.CheckOk(group.TgID == outputs.TestGroups[gi].TgID, "mismatch of TgID", t) + + t.Run(fmt.Sprintf("TgID_%v", group.TgID), func(t *testing.T) { + for ti := range group.Tests { + test.CheckOk( + group.Tests[ti].TcID == outputs.TestGroups[gi].Tests[ti].TcID, + "mismatch of TcID", t, + ) + + t.Run(fmt.Sprintf("TcID_%v", group.Tests[ti].TcID), func(t *testing.T) { + acvpSign( + t, group.ParameterSet, &group.Tests[ti], + outputs.TestGroups[gi].Tests[ti].Signature, + group.Deterministic, + ) + }) + } + }) + } +} + +func testVerify(t *testing.T) { + // https://github.com/usnistgov/ACVP-Server/tree/v1.1.0.35/gen-val/json-files/SLH-DSA-sigVer-FIPS205 + inputs := new(acvpVerifyInput) + readVector(t, "testdata/verify_prompt.json.zip", inputs) + outputs := new(acvpVerifyResult) + readVector(t, "testdata/verify_results.json.zip", outputs) + + for gi, group := range inputs.TestGroups { + test.CheckOk(group.TgID == outputs.TestGroups[gi].TgID, "mismatch of TgID", t) + + t.Run(fmt.Sprintf("TgID_%v", group.TgID), func(t *testing.T) { + for ti := range group.Tests { + test.CheckOk( + group.Tests[ti].TcID == outputs.TestGroups[gi].Tests[ti].TcID, + "mismatch of TcID", t, + ) + + t.Run(fmt.Sprintf("TcID_%v", group.Tests[ti].TcID), func(t *testing.T) { + acvpVerify( + t, group.ParameterSet, &group.Tests[ti], + outputs.TestGroups[gi].Tests[ti].TestPassed, + ) + }) + } + }) + } +} + +func acvpKeygen(t *testing.T, paramSet string, in *keygenInput) { + t.Parallel() + + id, err := ParamIDByName(paramSet) + test.CheckNoErr(t, err, "invalid param name") + + params := id.params() + pk, sk := slhKeyGenInternal(params, in.SkSeed, in.SkPrf, in.PkSeed) + + skGot, err := sk.MarshalBinary() + test.CheckNoErr(t, err, "PrivateKey.MarshalBinary failed") + + if !bytes.Equal(skGot, in.Sk) { + test.ReportError(t, skGot, in.Sk) + } + + skWant := &PrivateKey{ParamID: id} + err = skWant.UnmarshalBinary(in.Sk) + test.CheckNoErr(t, err, "PrivateKey.UnmarshalBinary failed") + + if !sk.Equal(skWant) { + test.ReportError(t, sk, skWant) + } + + pkGot, err := pk.MarshalBinary() + test.CheckNoErr(t, err, "PublicKey.MarshalBinary failed") + + if !bytes.Equal(pkGot, in.Pk) { + test.ReportError(t, pkGot, in.Pk) + } + + pkWant := &PublicKey{ParamID: id} + err = pkWant.UnmarshalBinary(in.Pk) + test.CheckNoErr(t, err, "PublicKey.UnmarshalBinary failed") + + if !pk.Equal(pkWant) { + test.ReportError(t, pk, pkWant) + } +} + +func acvpSign( + t *testing.T, + paramSet string, + in *signInput, + wantSignature []byte, + deterministic bool, +) { + t.Parallel() + + id, err := ParamIDByName(paramSet) + test.CheckNoErr(t, err, "invalid param name") + + sk := &PrivateKey{ParamID: id} + err = sk.UnmarshalBinary(in.Sk) + test.CheckNoErr(t, err, "PrivateKey.UnmarshalBinary failed") + + addRand := sk.publicKey.seed + if !deterministic { + addRand = in.AddRand + } + + params := id.params() + gotSignature, err := slhSignInternal(params, sk, in.Msg, addRand) + test.CheckNoErr(t, err, "slhSignInternal failed") + + if !bytes.Equal(gotSignature, wantSignature) { + more := " ... (more bytes differ)" + got := hex.EncodeToString(gotSignature[:10]) + more + want := hex.EncodeToString(wantSignature[:10]) + more + test.ReportError(t, got, want) + } + + valid := slhVerifyInternal(params, &sk.publicKey, in.Msg, gotSignature) + test.CheckOk(valid, "slhVerifyInternal failed", t) +} + +func acvpVerify(t *testing.T, paramSet string, in *verifyInput, want bool) { + id, err := ParamIDByName(paramSet) + test.CheckNoErr(t, err, "invalid param name") + + pk := &PublicKey{ParamID: id} + err = pk.UnmarshalBinary(in.Pk) + test.CheckNoErr(t, err, "PublicKey.UnmarshalBinary failed") + + params := id.params() + got := slhVerifyInternal(params, pk, in.Message, in.Signature) + + if got != want { + test.ReportError(t, got, want) + } +} + +type hexBytes []byte + +func (b *hexBytes) UnmarshalJSON(data []byte) (err error) { + var s string + err = json.Unmarshal(data, &s) + if err != nil { + return + } + *b, err = hex.DecodeString(s) + return +} + +func readVector(t *testing.T, fileName string, vector interface{}) { + zipFile, err := zip.OpenReader(fileName) + test.CheckNoErr(t, err, "error opening file") + defer zipFile.Close() + + jsonFile, err := zipFile.File[0].Open() + test.CheckNoErr(t, err, "error opening unzipping file") + defer jsonFile.Close() + + input, err := io.ReadAll(jsonFile) + test.CheckNoErr(t, err, "error reading bytes") + + err = json.Unmarshal(input, &vector) + test.CheckNoErr(t, err, "error unmarshalling JSON file") +} diff --git a/sign/slhdsa/address.go b/sign/slhdsa/address.go new file mode 100644 index 000000000..8a0b2878a --- /dev/null +++ b/sign/slhdsa/address.go @@ -0,0 +1,92 @@ +package slhdsa + +import "encoding/binary" + +// See FIPS 205 -- Section 4.2 +// Functions and Addressing + +type addrType = uint32 + +const ( + addressWotsHash = addrType(iota) + addressWotsPk + addressTree + addressForsTree + addressForsRoots + addressWotsPrf + addressForsPrf +) + +const ( + addressSizeCompressed = 22 + addressSizeNonCompressed = 32 +) + +type address struct { + b []byte + o int +} + +func (p *params) addressSize() uint32 { + if p.isSHA2 { + return addressSizeCompressed + } else { + return addressSizeNonCompressed + } +} + +func (p *params) addressOffset() int { + if p.isSHA2 { + return 0 + } else { + return 10 + } +} + +func (p *params) NewAddress() (a address) { + var m [addressSizeNonCompressed]byte + a.b = m[:p.addressSize()] + a.o = p.addressOffset() + return +} + +func (a *address) fromBytes(p *params, c *cursor) { + a.b = c.Next(p.addressSize()) + a.o = p.addressOffset() +} + +func (a *address) Set(x address) { copy(a.b, x.b); a.o = x.o } +func (a *address) Clear() { clearSlice(&a.b); a.o = 0 } +func (a *address) SetKeyPairAddress(i uint32) { binary.BigEndian.PutUint32(a.b[a.o+10:], i) } +func (a *address) SetChainAddress(i uint32) { binary.BigEndian.PutUint32(a.b[a.o+14:], i) } +func (a *address) SetTreeHeight(i uint32) { binary.BigEndian.PutUint32(a.b[a.o+14:], i) } +func (a *address) SetHashAddress(i uint32) { binary.BigEndian.PutUint32(a.b[a.o+18:], i) } +func (a *address) SetTreeIndex(i uint32) { binary.BigEndian.PutUint32(a.b[a.o+18:], i) } +func (a *address) GetKeyPairAddress() uint32 { return binary.BigEndian.Uint32(a.b[a.o+10:]) } +func (a *address) SetLayerAddress(l addrType) { + if a.o == 0 { + a.b[0] = byte(l & 0xFF) + } else { + binary.BigEndian.PutUint32(a.b[0:], l) + } +} + +func (a *address) SetTreeAddress(t [3]uint32) { + if a.o == 0 { + binary.BigEndian.PutUint32(a.b[1:], t[1]) + binary.BigEndian.PutUint32(a.b[5:], t[0]) + } else { + binary.BigEndian.PutUint32(a.b[4:], t[2]) + binary.BigEndian.PutUint32(a.b[8:], t[1]) + binary.BigEndian.PutUint32(a.b[12:], t[0]) + } +} + +func (a *address) SetTypeAndClear(t uint32) { + if a.o == 0 { + a.b[9] = byte(t) + } else { + binary.BigEndian.PutUint32(a.b[16:], t) + } + clear(a.b[a.o+10:]) +} diff --git a/sign/slhdsa/all_test.go b/sign/slhdsa/all_test.go new file mode 100644 index 000000000..2fcdd3f93 --- /dev/null +++ b/sign/slhdsa/all_test.go @@ -0,0 +1,46 @@ +package slhdsa + +import ( + "crypto/rand" + "io" + "testing" +) + +func TestInner(t *testing.T) { + for i := range supportedParams { + param := &supportedParams[i] + + t.Run(param.name, func(t *testing.T) { + t.Parallel() + + t.Run("Wots", func(t *testing.T) { testWotsPlus(t, param) }) + t.Run("Xmss", func(t *testing.T) { testXmss(t, param) }) + t.Run("Ht", func(tt *testing.T) { testHyperTree(tt, param) }) + t.Run("Fors", func(tt *testing.T) { testFors(tt, param) }) + t.Run("Int", func(tt *testing.T) { testInternal(tt, param) }) + }) + } +} + +func BenchmarkInner(b *testing.B) { + for i := range supportedParams { + param := &supportedParams[i] + + b.Run(param.name, func(b *testing.B) { + b.Run("Wots", func(b *testing.B) { benchmarkWotsPlus(b, param) }) + b.Run("Xmss", func(b *testing.B) { benchmarkXmss(b, param) }) + b.Run("Ht", func(b *testing.B) { benchmarkHyperTree(b, param) }) + b.Run("Fors", func(b *testing.B) { benchmarkFors(b, param) }) + b.Run("Int", func(b *testing.B) { benchmarkInternal(b, param) }) + }) + } +} + +func mustRead(t testing.TB, size uint32) (out []byte) { + out = make([]byte, size) + _, err := io.ReadFull(rand.Reader, out) + if err != nil { + t.Fatalf("rand reader error: %v", err) + } + return +} diff --git a/sign/slhdsa/fors.go b/sign/slhdsa/fors.go new file mode 100644 index 000000000..a003b77a2 --- /dev/null +++ b/sign/slhdsa/fors.go @@ -0,0 +1,168 @@ +package slhdsa + +// See FIPS 205 -- Section 8 +// Forest of Random Subsets (FORS) is a few-time signature scheme that is +// used to sign the digests of the actual messages. + +type ( + forsPublicKey []byte // n bytes + forsPrivateKey []byte // n bytes + forsSignature []forsPair // k*forsPairSize() bytes + forsPair struct { + sk forsPrivateKey // forsSkSize() bytes + auth [][]byte // a*n bytes + } // forsSkSize() + a*n bytes +) + +func (p *params) forsMsgSize() uint32 { return (p.k*p.a + 7) / 8 } +func (p *params) forsPkSize() uint32 { return p.n } +func (p *params) forsSkSize() uint32 { return p.n } +func (p *params) forsSigSize() uint32 { return p.k * p.forsPairSize() } +func (p *params) forsPairSize() uint32 { return p.forsSkSize() + p.a*p.n } + +func (fs *forsSignature) fromBytes(p *params, c *cursor) { + *fs = make([]forsPair, p.k) + for i := range *fs { + (*fs)[i].fromBytes(p, c) + } +} + +func (fp *forsPair) fromBytes(p *params, c *cursor) { + fp.sk = c.Next(p.forsSkSize()) + fp.auth = make([][]byte, p.a) + for i := range fp.auth { + fp.auth[i] = c.Next(p.n) + } +} + +// See FIPS 205 -- Section 8.1 -- Algorithm 14. +func (s *statePriv) forsSkGen(addr address, idx uint32) forsPrivateKey { + s.PRF.address.Set(addr) + s.PRF.address.SetTypeAndClear(addressForsPrf) + s.PRF.address.SetKeyPairAddress(addr.GetKeyPairAddress()) + s.PRF.address.SetTreeIndex(idx) + + return s.PRF.Final() +} + +// See FIPS 205 -- Section 8.2 -- Algorithm 15 -- Iterative version. +func (s *statePriv) forsNodeIter( + stack stackNode, root []byte, i, z uint32, addr address, +) { + if !(z <= s.a && i < s.k<<(s.a-z)) { + panic(ErrTree) + } + + s.F.address.Set(addr) + s.F.address.SetTreeHeight(0) + + s.H.address.Set(addr) + + twoZ := uint32(1) << z + iTwoZ := i << z + for k := uint32(0); k < twoZ; k++ { + li := iTwoZ + k + lz := uint32(0) + + sk := s.forsSkGen(addr, li) + s.F.address.SetTreeIndex(li) + s.F.SetMessage(sk) + node := s.F.Final() + + for !stack.isEmpty() && stack.top().z == lz { + left := stack.pop() + li = (li - 1) >> 1 + lz = lz + 1 + + s.H.address.SetTreeHeight(lz) + s.H.address.SetTreeIndex(li) + s.H.SetMsgs(left.node, node) + node = s.H.Final() + } + + stack.push(item{lz, node}) + } + + copy(root, stack.pop().node) +} + +// See FIPS 205 -- Section 8.3 -- Algorithm 16. +func (s *statePriv) forsSign(sig forsSignature, digest []byte, addr address) { + stack := s.NewStack(s.a - 1) + defer stack.Clear() + + in, bits, total := 0, uint32(0), uint32(0) + maskA := (uint32(1) << s.a) - 1 + + for i := uint32(0); i < s.k; i++ { + for bits < s.a { + total = (total << 8) + uint32(digest[in]) + in++ + bits += 8 + } + + bits -= s.a + indicesI := (total >> bits) & maskA + forsSk := s.forsSkGen(addr, (i<> j) ^ 1 + s.forsNodeIter(stack, sig[i].auth[j], (i<<(s.a-j))+shift, j, addr) + } + } +} + +// See FIPS 205 -- Section 8.4 -- Algorithm 17. +func (s *state) forsPkFromSig( + digest []byte, sig forsSignature, addr address, +) (pk forsPublicKey) { + pk = make([]byte, s.forsPkSize()) + + s.F.address.Set(addr) + s.F.address.SetTreeHeight(0) + + s.H.address.Set(addr) + + s.T.address.Set(addr) + s.T.address.SetTypeAndClear(addressForsRoots) + s.T.address.SetKeyPairAddress(addr.GetKeyPairAddress()) + s.T.Reset() + + in, bits, total := 0, uint32(0), uint32(0) + maskB := (uint32(1) << s.a) - 1 + + for i := uint32(0); i < s.k; i++ { + for bits < s.a { + total = (total << 8) + uint32(digest[in]) + in++ + bits += 8 + } + + bits -= s.a + indicesI := (total >> bits) & maskB + treeIdx := (i << s.a) + indicesI + s.F.address.SetTreeIndex(treeIdx) + s.F.SetMessage(sig[i].sk) + node := s.F.Final() + + for j := uint32(0); j < s.a; j++ { + if (indicesI>>j)&0x1 == 0 { + treeIdx = treeIdx >> 1 + s.H.SetMsgs(node, sig[i].auth[j]) + } else { + treeIdx = (treeIdx - 1) >> 1 + s.H.SetMsgs(sig[i].auth[j], node) + } + + s.H.address.SetTreeHeight(j + 1) + s.H.address.SetTreeIndex(treeIdx) + node = s.H.Final() + } + + s.T.WriteMessage(node) + } + + copy(pk, s.T.Final()) + return pk +} diff --git a/sign/slhdsa/fors_test.go b/sign/slhdsa/fors_test.go new file mode 100644 index 000000000..63460627e --- /dev/null +++ b/sign/slhdsa/fors_test.go @@ -0,0 +1,117 @@ +package slhdsa + +import ( + "bytes" + "testing" + + "github.com/cloudflare/circl/internal/test" +) + +// See FIPS 205 -- Section 8.2 -- Algorithm 15 -- Recursive version. +func (s *statePriv) forsNodeRec(i, z uint32, addr address) (node []byte) { + if !(z <= s.a && i < s.k<<(s.a-z)) { + panic(ErrTree) + } + + node = make([]byte, s.n) + if z == 0 { + sk := s.forsSkGen(addr, i) + addr.SetTreeHeight(0) + addr.SetTreeIndex(i) + + s.F.address.Set(addr) + s.F.SetMessage(sk) + copy(node, s.F.Final()) + } else { + lnode := s.forsNodeRec(2*i, z-1, addr) + rnode := s.forsNodeRec(2*i+1, z-1, addr) + + s.H.address.Set(addr) + s.H.address.SetTreeHeight(z) + s.H.address.SetTreeIndex(i) + s.H.SetMsgs(lnode, rnode) + copy(node, s.H.Final()) + } + + return +} + +func testFors(t *testing.T, p *params) { + skSeed := mustRead(t, p.n) + pkSeed := mustRead(t, p.n) + msg := mustRead(t, p.forsMsgSize()) + + state := p.NewStatePriv(skSeed, pkSeed) + + idxTree := [3]uint32{0, 0, 0} + idxLeaf := uint32(0) + + addr := p.NewAddress() + addr.SetLayerAddress(p.d - 1) + + pkRoot := make([]byte, p.n) + state.xmssNodeIter(p.NewStack(p.hPrime), pkRoot, idxLeaf, p.hPrime, addr) + + n0 := state.forsNodeRec(idxLeaf, p.a, addr) + + n1 := make([]byte, p.n) + state.forsNodeIter(p.NewStack(p.a), n1, idxLeaf, p.a, addr) + + if !bytes.Equal(n0, n1) { + test.ReportError(t, n0, n1) + } + + var sig forsSignature + curSig := cursor(make([]byte, p.forsSigSize())) + sig.fromBytes(p, &curSig) + state.forsSign(sig, msg, addr) + + pkFors := state.forsPkFromSig(msg, sig, addr) + + var htSig hyperTreeSignature + curHtSig := cursor(make([]byte, p.hyperTreeSigSize())) + htSig.fromBytes(p, &curHtSig) + state.htSign(htSig, pkFors, idxTree, idxLeaf) + + valid := state.htVerify(pkFors, pkRoot, idxTree, idxLeaf, htSig) + test.CheckOk(valid, "hypertree signature verification failed", t) +} + +func benchmarkFors(b *testing.B, p *params) { + skSeed := mustRead(b, p.n) + pkSeed := mustRead(b, p.n) + msg := mustRead(b, p.forsMsgSize()) + + state := p.NewStatePriv(skSeed, pkSeed) + + addr := p.NewAddress() + addr.SetLayerAddress(p.d - 1) + + var sig forsSignature + curSig := cursor(make([]byte, p.forsSigSize())) + sig.fromBytes(p, &curSig) + state.forsSign(sig, msg, addr) + + b.Run("NodeRec", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = state.forsNodeRec(0, p.a, addr) + } + }) + b.Run("NodeIter", func(b *testing.B) { + node := make([]byte, p.n) + forsStack := p.NewStack(p.a) + for i := 0; i < b.N; i++ { + state.forsNodeIter(forsStack, node, 0, p.a, addr) + } + }) + b.Run("Sign", func(b *testing.B) { + for i := 0; i < b.N; i++ { + state.forsSign(sig, msg, addr) + } + }) + b.Run("PkFromSig", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = state.forsPkFromSig(msg, sig, addr) + } + }) +} diff --git a/sign/slhdsa/hypertree.go b/sign/slhdsa/hypertree.go new file mode 100644 index 000000000..e178e1a90 --- /dev/null +++ b/sign/slhdsa/hypertree.go @@ -0,0 +1,68 @@ +package slhdsa + +import "crypto/subtle" + +// See FIPS 205 -- Section 7 +// SLH-DSA uses a hypertree to sign the FORS keys. + +type hyperTreeSignature []xmssSignature // d*xmssSigSize() bytes + +func (p *params) hyperTreeSigSize() uint32 { return p.d * p.xmssSigSize() } + +func (hts *hyperTreeSignature) fromBytes(p *params, c *cursor) { + *hts = make([]xmssSignature, p.d) + for i := range *hts { + (*hts)[i].fromBytes(p, c) + } +} + +func nextIndex(idxTree *[3]uint32, n uint32) (idxLeaf uint32) { + idxLeaf = idxTree[0] & ((1 << n) - 1) + idxTree[0] = (idxTree[0] >> n) | (idxTree[1] << (32 - n)) + idxTree[1] = (idxTree[1] >> n) | (idxTree[2] << (32 - n)) + idxTree[2] = (idxTree[2] >> n) + + return +} + +// See FIPS 205 -- Section 7.1 -- Algorithm 12. +func (s *statePriv) htSign( + sig hyperTreeSignature, msg []byte, idxTree [3]uint32, idxLeaf uint32, +) { + addr := s.NewAddress() + addr.SetTreeAddress(idxTree) + stack := s.NewStack(s.hPrime - 1) + defer stack.Clear() + + s.xmssSign(stack, sig[0], msg, idxLeaf, addr) + + root := make([]byte, s.xmssPkSize()) + copy(root, msg) + for j := uint32(1); j < s.d; j++ { + s.xmssPkFromSig(root, root, sig[j-1], idxLeaf, addr) + idxLeaf = nextIndex(&idxTree, s.hPrime) + addr.SetLayerAddress(j) + addr.SetTreeAddress(idxTree) + s.xmssSign(stack, sig[j], root, idxLeaf, addr) + } +} + +// See FIPS 205 -- Section 7.2 -- Algorithm 13. +func (s *state) htVerify( + msg, root []byte, idxTree [3]uint32, idxLeaf uint32, sig hyperTreeSignature, +) bool { + addr := s.NewAddress() + addr.SetTreeAddress(idxTree) + + node := make([]byte, s.xmssPkSize()) + s.xmssPkFromSig(node, msg, sig[0], idxLeaf, addr) + + for j := uint32(1); j < s.d; j++ { + idxLeaf = nextIndex(&idxTree, s.hPrime) + addr.SetLayerAddress(j) + addr.SetTreeAddress(idxTree) + s.xmssPkFromSig(node, node, sig[j], idxLeaf, addr) + } + + return subtle.ConstantTimeCompare(node, root) == 1 +} diff --git a/sign/slhdsa/hypertree_test.go b/sign/slhdsa/hypertree_test.go new file mode 100644 index 000000000..a6548b66a --- /dev/null +++ b/sign/slhdsa/hypertree_test.go @@ -0,0 +1,60 @@ +package slhdsa + +import ( + "testing" + + "github.com/cloudflare/circl/internal/test" +) + +func testHyperTree(t *testing.T, p *params) { + skSeed := mustRead(t, p.n) + pkSeed := mustRead(t, p.n) + msg := mustRead(t, p.n) + + state := p.NewStatePriv(skSeed, pkSeed) + + idxTree := [3]uint32{0, 0, 0} + idxLeaf := uint32(0) + + addr := p.NewAddress() + addr.SetLayerAddress(p.d - 1) + stack := p.NewStack(p.hPrime) + pkRoot := make([]byte, p.n) + state.xmssNodeIter(stack, pkRoot, idxLeaf, p.hPrime, addr) + + var sig hyperTreeSignature + curSig := cursor(make([]byte, p.hyperTreeSigSize())) + sig.fromBytes(p, &curSig) + state.htSign(sig, msg, idxTree, idxLeaf) + + valid := state.htVerify(msg, pkRoot, idxTree, idxLeaf, sig) + test.CheckOk(valid, "hypertree signature verification failed", t) +} + +func benchmarkHyperTree(b *testing.B, p *params) { + skSeed := mustRead(b, p.n) + pkSeed := mustRead(b, p.n) + pkRoot := mustRead(b, p.n) + msg := mustRead(b, p.n) + + state := p.NewStatePriv(skSeed, pkSeed) + + idxTree := [3]uint32{0, 0, 0} + idxLeaf := uint32(0) + + var sig hyperTreeSignature + curSig := cursor(make([]byte, p.hyperTreeSigSize())) + sig.fromBytes(p, &curSig) + state.htSign(sig, msg, idxTree, idxLeaf) + + b.Run("Sign", func(b *testing.B) { + for i := 0; i < b.N; i++ { + state.htSign(sig, msg, idxTree, idxLeaf) + } + }) + b.Run("Verify", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = state.htVerify(msg, pkRoot, idxTree, idxLeaf, sig) + } + }) +} diff --git a/sign/slhdsa/internal.go b/sign/slhdsa/internal.go new file mode 100644 index 000000000..dec0b6a24 --- /dev/null +++ b/sign/slhdsa/internal.go @@ -0,0 +1,137 @@ +package slhdsa + +import "encoding/binary" + +// See FIPS 205 -- Section 9 +// SLH-DSA Internal Functions + +// See FIPS 205 -- Section 9.1 -- Algorithm 18. +func slhKeyGenInternal( + p *params, skSeed, skPrf, pkSeed []byte, +) (pub PublicKey, priv PrivateKey) { + s := p.NewStatePriv(skSeed, pkSeed) + defer s.Clear() + + stack := p.NewStack(p.hPrime) + defer stack.Clear() + + addr := p.NewAddress() + addr.SetLayerAddress(p.d - 1) + pkRoot := make([]byte, p.n) + s.xmssNodeIter(stack, pkRoot, 0, p.hPrime, addr) + + pub.ParamID = p.id + pub.seed = pkSeed + pub.root = pkRoot + + priv.ParamID = p.id + priv.prfKey = skPrf + priv.seed = skSeed + priv.publicKey = pub + + return +} + +func (p *params) parseMsg( + digest []byte, +) (md []byte, idxTree [3]uint32, idxLeaf uint32) { + l1 := (p.k*p.a + 7) / 8 + l2 := (p.h - p.h/p.d + 7) / 8 + l3 := (p.h + 8*p.d - 1) / (8 * p.d) + + c := cursor(digest) + md = c.Next(l1) + s2 := c.Next(l2) + s3 := c.Next(l3) + + var b2 [12]byte + copy(b2[12-len(s2):], s2) + n2 := p.h - p.h/p.d + idxTree[0] = binary.BigEndian.Uint32(b2[8:]) & ((1 << n2) - 1) + n2 -= 32 + idxTree[1] = binary.BigEndian.Uint32(b2[4:]) & ((1 << n2) - 1) + idxTree[2] = binary.BigEndian.Uint32(b2[0:]) + + var b3 [4]byte + copy(b3[4-len(s3):], s3) + mask32 := (uint32(1) << (p.h / p.d)) - 1 + idxLeaf = mask32 & binary.BigEndian.Uint32(b3[0:]) + + return +} + +// See FIPS 205 -- Section 9.2 -- Algorithm 19. +func slhSignInternal( + p *params, sk *PrivateKey, message, addRand []byte, +) ([]byte, error) { + sigBytes := make([]byte, p.SignatureSize()) + + var sig signature + curSig := cursor(sigBytes) + if !sig.fromBytes(p, &curSig) { + return nil, ErrSigParse + } + + p.PRFMsg(sig.rnd, sk.prfKey, addRand, message) + digest := make([]byte, p.m) + p.HashMsg(digest, sig.rnd, message, &sk.publicKey) + + md, idxTree, idxLeaf := p.parseMsg(digest) + addr := p.NewAddress() + addr.SetTreeAddress(idxTree) + addr.SetTypeAndClear(addressForsTree) + addr.SetKeyPairAddress(idxLeaf) + + s := p.NewStatePriv(sk.seed, sk.publicKey.seed) + defer s.Clear() + + s.forsSign(sig.forsSig, md, addr) + pkFors := s.forsPkFromSig(md, sig.forsSig, addr) + s.htSign(sig.htSig, pkFors, idxTree, idxLeaf) + + return sigBytes, nil +} + +// See FIPS 205 -- Section 9.3 -- Algorithm 20. +func slhVerifyInternal( + p *params, pub *PublicKey, message, sigBytes []byte, +) bool { + var sig signature + curSig := cursor(sigBytes) + if len(sigBytes) != int(p.SignatureSize()) || !sig.fromBytes(p, &curSig) { + return false + } + + digest := make([]byte, p.m) + p.HashMsg(digest, sig.rnd, message, pub) + + md, idxTree, idxLeaf := p.parseMsg(digest) + addr := p.NewAddress() + addr.SetTreeAddress(idxTree) + addr.SetTypeAndClear(addressForsTree) + addr.SetKeyPairAddress(idxLeaf) + + s := p.NewStatePub(pub.seed) + defer s.Clear() + + pkFors := s.forsPkFromSig(md, sig.forsSig, addr) + return s.htVerify(pkFors, pub.root, idxTree, idxLeaf, sig.htSig) +} + +// signature represents a SLH-DSA signature. +type signature struct { + rnd []byte // n bytes + forsSig forsSignature // forsSigSize() bytes + htSig hyperTreeSignature // hyperTreeSigSize() bytes +} + +func (p *params) SignatureSize() uint32 { + return p.n + p.forsSigSize() + p.hyperTreeSigSize() +} + +func (s *signature) fromBytes(p *params, c *cursor) bool { + s.rnd = c.Next(p.n) + s.forsSig.fromBytes(p, c) + s.htSig.fromBytes(p, c) + return len(*c) == 0 +} diff --git a/sign/slhdsa/internal_test.go b/sign/slhdsa/internal_test.go new file mode 100644 index 000000000..d3278d394 --- /dev/null +++ b/sign/slhdsa/internal_test.go @@ -0,0 +1,50 @@ +package slhdsa + +import ( + "testing" + + "github.com/cloudflare/circl/internal/test" +) + +func testInternal(t *testing.T, p *params) { + skSeed := mustRead(t, p.n) + skPrf := mustRead(t, p.n) + pkSeed := mustRead(t, p.n) + msg := mustRead(t, p.m) + addRand := mustRead(t, p.n) + + pk, sk := slhKeyGenInternal(p, skSeed, skPrf, pkSeed) + sig, err := slhSignInternal(p, &sk, msg, addRand) + test.CheckNoErr(t, err, "slhSignInternal failed") + + valid := slhVerifyInternal(p, &pk, msg, sig) + test.CheckOk(valid, "slhVerifyInternal failed", t) +} + +func benchmarkInternal(b *testing.B, p *params) { + skSeed := mustRead(b, p.n) + skPrf := mustRead(b, p.n) + pkSeed := mustRead(b, p.n) + msg := mustRead(b, p.m) + addRand := mustRead(b, p.n) + + pk, sk := slhKeyGenInternal(p, skSeed, skPrf, pkSeed) + sig, err := slhSignInternal(p, &sk, msg, addRand) + test.CheckNoErr(b, err, "slhSignInternal failed") + + b.Run("Keygen", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = slhKeyGenInternal(p, skSeed, skPrf, pkSeed) + } + }) + b.Run("Sign", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = slhSignInternal(p, &sk, msg, addRand) + } + }) + b.Run("Verify", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = slhVerifyInternal(p, &pk, msg, sig) + } + }) +} diff --git a/sign/slhdsa/keys.go b/sign/slhdsa/keys.go new file mode 100644 index 000000000..8ac977ba5 --- /dev/null +++ b/sign/slhdsa/keys.go @@ -0,0 +1,144 @@ +package slhdsa + +import ( + "bytes" + "crypto" + "crypto/subtle" + + "github.com/cloudflare/circl/internal/conv" + "github.com/cloudflare/circl/sign" + "golang.org/x/crypto/cryptobyte" +) + +// [PrivateKey] stores a private key of the SLH-DSA scheme. +// It implements the [crypto.Signer] and [crypto.PrivateKey] interfaces. +// For serialization, it also implements [cryptobyte.MarshalingValue], +// [encoding.BinaryMarshaler], and [encoding.BinaryUnmarshaler]. +type PrivateKey struct { + ParamID ParamID + seed, prfKey []byte + publicKey PublicKey +} + +func (p *params) PrivateKeySize() uint32 { return 2*p.n + p.PublicKeySize() } + +// Marshal serializes the key using a Builder. +func (k *PrivateKey) Marshal(b *cryptobyte.Builder) error { + b.AddBytes(k.seed) + b.AddBytes(k.prfKey) + b.AddValue(&k.publicKey) + return nil +} + +// Unmarshal recovers a PrivateKey from a [cryptobyte.String]. Caller must +// specify the ParamID of the key in advance. +// Example: +// +// var key PrivateKey +// key.ParamID = ParamIDSHA2Small192 +// key.Unmarshal(str) // returns true +func (k *PrivateKey) Unmarshal(s *cryptobyte.String) bool { + params := k.ParamID.params() + var b []byte + if !s.ReadBytes(&b, int(params.PrivateKeySize())) { + return false + } + + c := cursor(b) + return k.fromBytes(params, &c) +} + +func (k *PrivateKey) fromBytes(p *params, c *cursor) bool { + k.ParamID = p.id + k.seed = c.Next(p.n) + k.prfKey = c.Next(p.n) + return k.publicKey.fromBytes(p, c) +} + +// UnmarshalBinary recovers a PrivateKey from a slice of bytes. Caller must +// specify the ParamID of the key in advance. +// Example: +// +// var key PrivateKey +// key.ParamID = ParamIDSHA2Small192 +// key.UnmarshalBinary(bytes) // returns nil +func (k *PrivateKey) UnmarshalBinary(b []byte) error { return conv.UnmarshalBinary(k, b) } +func (k *PrivateKey) MarshalBinary() ([]byte, error) { return conv.MarshalBinary(k) } +func (k *PrivateKey) Scheme() sign.Scheme { return k.ParamID } +func (k *PrivateKey) Public() crypto.PublicKey { pk := k.PublicKey(); return &pk } +func (k *PrivateKey) PublicKey() (out PublicKey) { + params := k.ParamID.params() + c := cursor(make([]byte, params.PublicKeySize())) + out.fromBytes(params, &c) + copy(out.seed, k.publicKey.seed) + copy(out.root, k.publicKey.root) + return +} + +func (k *PrivateKey) Equal(x crypto.PrivateKey) bool { + other, ok := x.(*PrivateKey) + return ok && k.ParamID == other.ParamID && + subtle.ConstantTimeCompare(k.seed, other.seed) == 1 && + subtle.ConstantTimeCompare(k.prfKey, other.prfKey) == 1 && + k.publicKey.Equal(&other.publicKey) +} + +// [PublicKey] stores a public key of the SLH-DSA scheme. +// It implements the [crypto.PublicKey] interface. +// For serialization, it also implements [cryptobyte.MarshalingValue], +// [encoding.BinaryMarshaler], and [encoding.BinaryUnmarshaler]. +type PublicKey struct { + ParamID ParamID + seed, root []byte +} + +func (p *params) PublicKeySize() uint32 { return 2 * p.n } + +// Marshal serializes the key using a Builder. +func (k *PublicKey) Marshal(b *cryptobyte.Builder) error { + b.AddBytes(k.seed) + b.AddBytes(k.root) + return nil +} + +// Unmarshal recovers a PublicKey from a [cryptobyte.String]. Caller must +// specify the ParamID of the key in advance. +// Example: +// +// var key PublicKey +// key.ParamID = ParamIDSHA2Small192 +// key.Unmarshal(str) // returns true +func (k *PublicKey) Unmarshal(s *cryptobyte.String) bool { + params := k.ParamID.params() + var b []byte + if !s.ReadBytes(&b, int(params.PublicKeySize())) { + return false + } + + c := cursor(b) + return k.fromBytes(params, &c) +} + +func (k *PublicKey) fromBytes(p *params, c *cursor) bool { + k.ParamID = p.id + k.seed = c.Next(p.n) + k.root = c.Next(p.n) + return len(*c) == 0 +} + +// UnmarshalBinary recovers a PublicKey from a slice of bytes. Caller must +// specify the ParamID of the key in advance. +// Example: +// +// var key PublicKey +// key.ParamID = ParamIDSHA2Small192 +// key.UnmarshalBinary(bytes) // returns nil +func (k *PublicKey) UnmarshalBinary(b []byte) error { return conv.UnmarshalBinary(k, b) } +func (k *PublicKey) MarshalBinary() ([]byte, error) { return conv.MarshalBinary(k) } +func (k *PublicKey) Scheme() sign.Scheme { return k.ParamID } +func (k *PublicKey) Equal(x crypto.PublicKey) bool { + other, ok := x.(*PublicKey) + return ok && k.ParamID == other.ParamID && + bytes.Equal(k.seed, other.seed) && + bytes.Equal(k.root, other.root) +} diff --git a/sign/slhdsa/message.go b/sign/slhdsa/message.go new file mode 100644 index 000000000..59dfad8aa --- /dev/null +++ b/sign/slhdsa/message.go @@ -0,0 +1,161 @@ +package slhdsa + +import ( + "bytes" + "crypto" + "crypto/sha256" + "crypto/sha512" + "io" + + "github.com/cloudflare/circl/internal/sha3" + "github.com/cloudflare/circl/xof" +) + +// [Message] wraps the message to be signed. +// It implements the [io.Writer] interface, so the message can be provided +// in chunks before calling the [PrivateKey.SignRandomized], +// [PrivateKey.SignDeterministic], or [Verify] functions. +// +// There are two cases depending on whether the message must be pre-hashed: +// - Hash Signing: Use [NewMessageWithPreHash] when the message is meant +// to be hashed before signing. The calls to [Message.Write] are +// directly passed to the specified pre-hash function. +// - Pure Signing. Use [NewMessage] or just create a [Message] variable, +// if the message must not be pre-hashed. +// Calling [NewMessageWithPreHash] with [NoPreHash] is equivalent. +// The calls to [Message.Write] copy the message into a internal buffer. +// To avoid copies of the message, use [NewMessage] instead. +type Message struct { + buffer bytes.Buffer + hasher interface { + io.Writer + SumIdempotent([]byte) + } + isPreHash bool + oid10 byte + outLen int +} + +// [NewMessage] wraps a message for signing, also known as pure signing. +// Use this function or just create a [Message] variable, if the message +// must not be pre-hashed. +// Calling [NewMessageWithPreHash] with [NoPreHash] is equivalent. +// The calls to [Message.Write] copy the message into a internal buffer. +// To avoid copies of the message, use [NewMessage] instead. +func NewMessage(msg []byte) (m Message) { + _ = m.init(NoPreHash, msg) + return +} + +// [NewMessageWithPreHash] wraps a message to be hashed before signing. +// The calls to [Message.Write] are directly passed to the specified +// pre-hash function. +// It returns an error if the pre-hash function is not supported. +func NewMessageWithPreHash(id PreHashID) (m Message, err error) { + err = m.init(id, nil) + return +} + +// Write allows to provide the message to be signed in chunks. +// Depending on how the receiver was generated, Write will either copy the +// chunks into an internal buffer, or pass them to the pre-hash function. +func (m *Message) Write(p []byte) (n int, err error) { + if m.isPreHash { + return m.hasher.Write(p) + } else { + return m.buffer.Write(p) + } +} + +func (m *Message) init(ph PreHashID, msg []byte) (err error) { + switch ph { + case NoPreHash: + m.isPreHash = false + m.buffer = *bytes.NewBuffer(msg) + case PreHashSHA256: + m.isPreHash = true + m.oid10 = 0x01 + m.outLen = crypto.SHA256.Size() + m.hasher = &sha2rw{sha256.New()} + case PreHashSHA512: + m.isPreHash = true + m.oid10 = 0x03 + m.outLen = crypto.SHA512.Size() + m.hasher = &sha2rw{sha512.New()} + case PreHashSHAKE128: + m.isPreHash = true + m.oid10 = 0x0B + m.outLen = 256 / 8 + m.hasher = &sha3rw{sha3.NewShake128()} + case PreHashSHAKE256: + m.isPreHash = true + m.oid10 = 0x0C + m.outLen = 512 / 8 + m.hasher = &sha3rw{sha3.NewShake256()} + default: + return ErrPreHash + } + + if m.isPreHash && msg != nil { + _, err = m.hasher.Write(msg) + } + + return err +} + +func (m *Message) getMsgPrime(context []byte) (msgPrime []byte, err error) { + // See FIPS 205 -- Section 10.2 -- Algorithm 23 and Algorithm 25. + if len(context) > MaxContextSize { + return nil, ErrContext + } + + msgPrime = append([]byte{0, byte(len(context))}, context...) + + var phMsg []byte + if !m.isPreHash { + msgPrime[0] = 0x0 + phMsg = m.buffer.Bytes() + } else { + msgPrime[0] = 0x1 + + oid := [11]byte{ + 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, + } + oid[10] = m.oid10 + msgPrime = append(msgPrime, oid[:]...) + + phMsg = make([]byte, m.outLen) + m.hasher.SumIdempotent(phMsg) + } + + return append(msgPrime, phMsg...), nil +} + +// PreHashID specifies a function for hashing the message before signing. +// The zero value is [NoPreHash] and stands for pure signing. +type PreHashID byte + +const ( + NoPreHash PreHashID = PreHashID(0) + PreHashSHA256 PreHashID = PreHashID(crypto.SHA256) + PreHashSHA512 PreHashID = PreHashID(crypto.SHA512) + PreHashSHAKE128 PreHashID = PreHashID(xof.SHAKE128) + PreHashSHAKE256 PreHashID = PreHashID(xof.SHAKE256) +) + +func (id PreHashID) String() string { + switch id { + case NoPreHash: + return "NoPreHash" + case PreHashSHA256: + return "PreHashSHA256" + case PreHashSHA512: + return "PreHashSHA512" + case PreHashSHAKE128: + return "PreHashSHAKE128" + case PreHashSHAKE256: + return "PreHashSHAKE256" + default: + return ErrPreHash.Error() + } +} diff --git a/sign/slhdsa/message_test.go b/sign/slhdsa/message_test.go new file mode 100644 index 000000000..c308196d5 --- /dev/null +++ b/sign/slhdsa/message_test.go @@ -0,0 +1,84 @@ +package slhdsa + +import ( + "bytes" + "testing" + + "github.com/cloudflare/circl/internal/test" +) + +func TestMessagePreHash(t *testing.T) { + const N = 128 + context := []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF} + + for _, ph := range []PreHashID{ + NoPreHash, + PreHashSHA256, + PreHashSHA512, + PreHashSHAKE128, + PreHashSHAKE256, + } { + var msg []byte + m0, err := NewMessageWithPreHash(ph) + test.CheckNoErr(t, err, "NewMessageWithPreHash failed") + + for i := byte(0); i < N; i++ { + _, errWrite := m0.Write([]byte{i}) + test.CheckNoErr(t, errWrite, "Write failed") + msg = append(msg, i) + } + + got, err := m0.getMsgPrime(context) + test.CheckNoErr(t, err, "getMsgPrime failed") + + m1, err := NewMessageWithPreHash(ph) + test.CheckNoErr(t, err, "NewMessageWithPreHash failed") + + _, err = m1.Write(msg) + test.CheckNoErr(t, err, "Write failed") + + want, err := m1.getMsgPrime(context) + test.CheckNoErr(t, err, "getMsgPrime failed") + + if !bytes.Equal(got, want) { + test.ReportError(t, got, want, ph) + } + } +} + +func TestMessageNoPreHash(t *testing.T) { + const N = 128 + context := []byte("context string") + + var msg []byte + var m0 Message + for i := byte(0); i < N; i++ { + _, errWrite := m0.Write([]byte{i}) + test.CheckNoErr(t, errWrite, "Write failed") + msg = append(msg, i) + } + + got, err := m0.getMsgPrime(context) + test.CheckNoErr(t, err, "getMsgPrime failed") + + m1 := NewMessage(msg) + want, err := m1.getMsgPrime(context) + test.CheckNoErr(t, err, "getMsgPrime failed") + + if !bytes.Equal(got, want) { + test.ReportError(t, got, want) + } + + m2, err := NewMessageWithPreHash(NoPreHash) + test.CheckNoErr(t, err, "NewMessageWithPreHash failed") + + _, err = m2.Write(msg) + test.CheckNoErr(t, err, "Write failed") + + want, err = m2.getMsgPrime(context) + test.CheckNoErr(t, err, "getMsgPrime failed") + + if !bytes.Equal(got, want) { + test.ReportError(t, got, want) + } +} diff --git a/sign/slhdsa/params.go b/sign/slhdsa/params.go new file mode 100644 index 000000000..66da1e6a8 --- /dev/null +++ b/sign/slhdsa/params.go @@ -0,0 +1,207 @@ +package slhdsa + +import ( + "crypto" + "crypto/hmac" + "crypto/sha256" + "crypto/sha512" + "encoding/binary" + "hash" + "io" + "strings" + + "github.com/cloudflare/circl/internal/sha3" + "github.com/cloudflare/circl/sign" +) + +// [ParamID] identifies the supported parameter sets of SLH-DSA. +// Note that the zero value is an invalid identifier. +// [ParamID] with a valid identifier also implements the [sign.Scheme] +// interface, but invalid identifiers cause methods panic. +type ParamID byte + +const ( + ParamIDSHA2Small128 ParamID = iota + 1 // SLH-DSA-SHA2-128s + ParamIDSHAKESmall128 // SLH-DSA-SHAKE-128s + ParamIDSHA2Fast128 // SLH-DSA-SHA2-128f + ParamIDSHAKEFast128 // SLH-DSA-SHAKE-128f + ParamIDSHA2Small192 // SLH-DSA-SHA2-192s + ParamIDSHAKESmall192 // SLH-DSA-SHAKE-192s + ParamIDSHA2Fast192 // SLH-DSA-SHA2-192f + ParamIDSHAKEFast192 // SLH-DSA-SHAKE-192f + ParamIDSHA2Small256 // SLH-DSA-SHA2-256s + ParamIDSHAKESmall256 // SLH-DSA-SHAKE-256s + ParamIDSHA2Fast256 // SLH-DSA-SHA2-256f + ParamIDSHAKEFast256 // SLH-DSA-SHAKE-256f + _MaxParams +) + +// [ParamIDByName] returns the [ParamID] that corresponds to the given name, +// or an error if no parameter set was found. +// See [ParamID] documentation for the specific names of each parameter set. +// Names are case insensitive. +// +// Example: +// +// ParamIDByName("SLH-DSA-SHAKE-256s") // returns (ParamIDSHAKESmall256, nil) +func ParamIDByName(name string) (id ParamID, err error) { + v := strings.ToLower(name) + for i := range supportedParams { + if strings.ToLower(supportedParams[i].name) == v { + return supportedParams[i].id, nil + } + } + + return id, ErrParam +} + +// IsValid returns true if the parameter set is supported. +func (id ParamID) IsValid() bool { return 0 < id && id < _MaxParams } +func (id ParamID) Name() string { return id.String() } +func (id ParamID) PublicKeySize() int { return int(id.params().PublicKeySize()) } +func (id ParamID) PrivateKeySize() int { return int(id.params().PrivateKeySize()) } +func (id ParamID) SignatureSize() int { return int(id.params().SignatureSize()) } +func (id ParamID) SeedSize() int { return id.PrivateKeySize() } +func (id ParamID) SupportsContext() bool { return true } +func (id ParamID) String() string { + if !id.IsValid() { + return ErrParam.Error() + } + return supportedParams[id-1].name +} + +func (id ParamID) params() *params { + if !id.IsValid() { + panic(ErrParam) + } + return &supportedParams[id-1] +} + +func (id ParamID) UnmarshalBinaryPublicKey(b []byte) (sign.PublicKey, error) { + k := PublicKey{ParamID: id} + err := k.UnmarshalBinary(b) + if err != nil { + return nil, err + } + return &k, nil +} + +func (id ParamID) UnmarshalBinaryPrivateKey(b []byte) (sign.PrivateKey, error) { + k := PrivateKey{ParamID: id} + err := k.UnmarshalBinary(b) + if err != nil { + return nil, err + } + return &k, nil +} + +// params contains all the relevant constants of a parameter set. +type params struct { + n uint32 // Length of WOTS+ messages. + hPrime uint32 // XMSS Merkle tree height. + h uint32 // Total height of a hypertree. + d uint32 // Hypertree has d layers of XMSS trees. + a uint32 // FORS signs a-bit messages. + k uint32 // FORS generates k private keys. + m uint32 // Used by HashMSG function. + isSHA2 bool // True, if the hash function is SHA2, otherwise is SHAKE. + name string // Name of the parameter set. + id ParamID // Identifier of the parameter set. +} + +// Stores all the supported (read-only) parameter sets. +var supportedParams = [_MaxParams - 1]params{ + {id: ParamIDSHA2Small128, n: 16, h: 63, d: 7, hPrime: 9, a: 12, k: 14, m: 30, isSHA2: true, name: "SLH-DSA-SHA2-128s"}, + {id: ParamIDSHAKESmall128, n: 16, h: 63, d: 7, hPrime: 9, a: 12, k: 14, m: 30, isSHA2: false, name: "SLH-DSA-SHAKE-128s"}, + {id: ParamIDSHA2Fast128, n: 16, h: 66, d: 22, hPrime: 3, a: 6, k: 33, m: 34, isSHA2: true, name: "SLH-DSA-SHA2-128f"}, + {id: ParamIDSHAKEFast128, n: 16, h: 66, d: 22, hPrime: 3, a: 6, k: 33, m: 34, isSHA2: false, name: "SLH-DSA-SHAKE-128f"}, + {id: ParamIDSHA2Small192, n: 24, h: 63, d: 7, hPrime: 9, a: 14, k: 17, m: 39, isSHA2: true, name: "SLH-DSA-SHA2-192s"}, + {id: ParamIDSHAKESmall192, n: 24, h: 63, d: 7, hPrime: 9, a: 14, k: 17, m: 39, isSHA2: false, name: "SLH-DSA-SHAKE-192s"}, + {id: ParamIDSHA2Fast192, n: 24, h: 66, d: 22, hPrime: 3, a: 8, k: 33, m: 42, isSHA2: true, name: "SLH-DSA-SHA2-192f"}, + {id: ParamIDSHAKEFast192, n: 24, h: 66, d: 22, hPrime: 3, a: 8, k: 33, m: 42, isSHA2: false, name: "SLH-DSA-SHAKE-192f"}, + {id: ParamIDSHA2Small256, n: 32, h: 64, d: 8, hPrime: 8, a: 14, k: 22, m: 47, isSHA2: true, name: "SLH-DSA-SHA2-256s"}, + {id: ParamIDSHAKESmall256, n: 32, h: 64, d: 8, hPrime: 8, a: 14, k: 22, m: 47, isSHA2: false, name: "SLH-DSA-SHAKE-256s"}, + {id: ParamIDSHA2Fast256, n: 32, h: 68, d: 17, hPrime: 4, a: 9, k: 35, m: 49, isSHA2: true, name: "SLH-DSA-SHA2-256f"}, + {id: ParamIDSHAKEFast256, n: 32, h: 68, d: 17, hPrime: 4, a: 9, k: 35, m: 49, isSHA2: false, name: "SLH-DSA-SHAKE-256f"}, +} + +// See FIPS-205, Section 11.1 and Section 11.2. +func (p *params) PRFMsg(out, skPrf, optRand, msg []byte) { + if p.isSHA2 { + var h crypto.Hash + if p.n == 16 { + h = crypto.SHA256 + } else { + h = crypto.SHA512 + } + + mac := hmac.New(h.New, skPrf) + concat(mac, optRand, msg) + mac.Sum(out[:0]) + } else { + state := sha3.NewShake256() + concat(&state, skPrf, optRand, msg) + _, _ = state.Read(out) + } +} + +// See FIPS-205, Section 11.1 and Section 11.2. +func (p *params) HashMsg(out, r, msg []byte, pk *PublicKey) { + if p.isSHA2 { + var hLen uint32 + var state hash.Hash + if p.n == 16 { + hLen = sha256.Size + state = sha256.New() + } else { + hLen = sha512.Size + state = sha512.New() + } + + mgfSeed := make([]byte, 2*p.n+hLen+4) + c := cursor(mgfSeed) + copy(c.Next(p.n), r) + copy(c.Next(p.n), pk.seed) + sumInter := c.Next(hLen) + + concat(state, r, pk.seed, pk.root, msg) + state.Sum(sumInter[:0]) + p.mgf1(out, mgfSeed, p.m) + } else { + state := sha3.NewShake256() + concat(&state, r, pk.seed, pk.root, msg) + _, _ = state.Read(out) + } +} + +// MGF1 described in Appendix B.2.1 of RFC 8017. +func (p *params) mgf1(out, mgfSeed []byte, maskLen uint32) { + var hLen uint32 + var hashFn func(out, in []byte) + if p.n == 16 { + hLen = sha256.Size + hashFn = sha256sum + } else { + hLen = sha512.Size + hashFn = sha512sum + } + + offset := uint32(0) + end := (maskLen + hLen - 1) / hLen + counterBytes := mgfSeed[len(mgfSeed)-4:] + + for counter := uint32(0); counter < end; counter++ { + binary.BigEndian.PutUint32(counterBytes, counter) + hashFn(out[offset:], mgfSeed) + offset += hLen + } +} + +func concat(w io.Writer, list ...[]byte) { + for i := range list { + _, err := w.Write(list[i]) + if err != nil { + panic(ErrWriting) + } + } +} diff --git a/sign/slhdsa/slhdsa.go b/sign/slhdsa/slhdsa.go new file mode 100644 index 000000000..aa1c6cc69 --- /dev/null +++ b/sign/slhdsa/slhdsa.go @@ -0,0 +1,282 @@ +// Package slhdsa provides Stateless Hash-based Digital Signature Algorithm. +// +// This package is compliant with [FIPS 205] and the [ParamID] represents +// the following parameter sets: +// +// Category 1 +// - Based on SHA2: [ParamIDSHA2Small128] and [ParamIDSHA2Fast128]. +// - Based on SHAKE: [ParamIDSHAKESmall128] and [ParamIDSHAKEFast128]. +// +// Category 3 +// - Based on SHA2: [ParamIDSHA2Small192] and [ParamIDSHA2Fast192] +// - Based on SHAKE: [ParamIDSHAKESmall192] and [ParamIDSHAKEFast192] +// +// Category 5 +// - Based on SHA2: [ParamIDSHA2Small256] and [ParamIDSHA2Fast256]. +// - Based on SHAKE: [ParamIDSHAKESmall256] and [ParamIDSHAKEFast256]. +// +// [FIPS 205]: https://doi.org/10.6028/NIST.FIPS.205 +package slhdsa + +import ( + "bytes" + "crypto" + "crypto/rand" + "errors" + "fmt" + "io" + + "github.com/cloudflare/circl/internal/sha3" + "github.com/cloudflare/circl/sign" +) + +// [MaxContextSize] is the maximum byte length of a context for signing. +const MaxContextSize = 255 + +// GenerateKey returns a pair of keys using the parameter set specified. +// It returns an error if it fails reading from the random source. +func GenerateKey( + random io.Reader, id ParamID, +) (pub PublicKey, priv PrivateKey, err error) { + // See FIPS 205 -- Section 10.1 -- Algorithm 21. + params := id.params() + + var skPrf, skSeed, pkSeed []byte + skSeed, err = readRandom(random, params.n) + if err != nil { + return + } + + skPrf, err = readRandom(random, params.n) + if err != nil { + return + } + + pkSeed, err = readRandom(random, params.n) + if err != nil { + return + } + + pub, priv = slhKeyGenInternal(params, skSeed, skPrf, pkSeed) + + return +} + +// GenerateKey is similar to [GenerateKey] function, except it always reads +// random bytes from [rand.Reader]. +func (id ParamID) GenerateKey() (sign.PublicKey, sign.PrivateKey, error) { + pub, priv, err := GenerateKey(rand.Reader, id) + if err != nil { + return nil, nil, err + } + + return &pub, &priv, nil +} + +// Deterministically derives a pair of keys from a seed. If you're unsure, +// you're better off using [GenerateKey] function. +// +// Panics if seed is not of length [ParamID.SeedSize]. +func (id ParamID) DeriveKey(seed []byte) (sign.PublicKey, sign.PrivateKey) { + params := id.params() + if len(seed) != id.SeedSize() { + panic(sign.ErrSeedSize) + } + + m := make([]byte, 3*params.n) + if params.isSHA2 { + params.mgf1(m, seed, 3*params.n) + } else { + sha3.ShakeSum256(m, seed) + } + + pub, priv, err := GenerateKey(bytes.NewReader(m), id) + if err != nil { + return nil, nil + } + + return &pub, &priv +} + +// SignRandomized returns a random signature of the message with the +// specified context. +// It returns an error if it fails reading from the random source. +func (k *PrivateKey) SignRandomized( + random io.Reader, message *Message, context []byte, +) (signature []byte, err error) { + params := k.ParamID.params() + addRand, err := readRandom(random, params.n) + if err != nil { + return nil, err + } + + return k.doSign(message, context, addRand) +} + +// SignDeterministic returns the signature of the message with the +// specified context. +// It returns an error if it fails reading from the random source. +func (k *PrivateKey) SignDeterministic( + message *Message, context []byte, +) (signature []byte, err error) { + return k.doSign(message, context, k.publicKey.seed) +} + +func (k *PrivateKey) doSign(msg *Message, ctx, addRnd []byte) ([]byte, error) { + // See FIPS 205 -- Section 10.2 -- Algorithm 22. + params := k.ParamID.params() + msgPrime, err := msg.getMsgPrime(ctx) + if err != nil { + return nil, err + } + + return slhSignInternal(params, k, msgPrime, addRnd) +} + +// [PrivateKey.Sign] returns a signature of the message with the specified +// options. +// +// When options is a [SignatureOpts] struct, the signature is generated as +// specified by the options. Otherwise, options.HashFunc is used as the +// pre-hash function (allowing only SHA256 or SHA512). +// If options is nil, the message is not prehased, and a randomized +// signature with an empty context is generated. +// It returns an error if it fails reading from the random source. +func (k *PrivateKey) Sign( + random io.Reader, message []byte, options crypto.SignerOpts, +) (signature []byte, err error) { + var signOptions SignatureOpts + if options != nil { + switch options.HashFunc() { + case crypto.SHA256: + signOptions.PreHashID = PreHashSHA256 + case crypto.SHA512: + signOptions.PreHashID = PreHashSHA512 + } + + otherOptions, ok := options.(SignatureOpts) + if ok { + signOptions = otherOptions + } + } + + msg := new(Message) + err = msg.init(signOptions.PreHashID, message) + if err != nil { + return nil, err + } + + if signOptions.IsDeterministic { + return k.SignDeterministic(msg, signOptions.Context) + } else { + return k.SignRandomized(random, msg, signOptions.Context) + } +} + +// [ParamID.Sign] returns a randomized signature of the message with the +// specified options. +// This function never pre-hashes the message and uses the context provided +// in options. If options is nil, an empty context is used. +// It returns an empty slice if it fails reading from the random source. +// +// Panics if the key is not a [*PrivateKey] or mismatches with the ParamID. +func (id ParamID) Sign( + key sign.PrivateKey, message []byte, options *sign.SignatureOpts, +) (signature []byte) { + k, ok := key.(*PrivateKey) + if !ok || id != k.ParamID { + panic(sign.ErrTypeMismatch) + } + + var context []byte + if options != nil { + context = []byte(options.Context) + } + + msg := NewMessage(message) + signature, err := k.SignRandomized(rand.Reader, &msg, context) + if err != nil { + return nil + } + + return +} + +// [Verify] returns true if the signature of the message with the specified +// context is valid. +func Verify(key *PublicKey, message *Message, context, signature []byte) bool { + // See FIPS 205 -- Section 10.3 -- Algorithm 24. + params := key.ParamID.params() + msgPrime, err := message.getMsgPrime(context) + if err != nil { + return false + } + + return slhVerifyInternal(params, key, msgPrime, signature) +} + +// [Verify] returns true if the signature of the message with the specified +// context is valid. +// This function never pre-hashes the message and uses the context provided +// in options. If options is nil, an empty context is used. +// +// Panics if the key is not a [*PublicKey] or mismatches with the ParamID. +func (id ParamID) Verify( + key sign.PublicKey, message, signature []byte, options *sign.SignatureOpts, +) bool { + k, ok := key.(*PublicKey) + if !ok || id != k.ParamID { + panic(sign.ErrTypeMismatch) + } + + var context []byte + if options != nil { + context = []byte(options.Context) + } + + msg := NewMessage(message) + return Verify(k, &msg, context, signature) +} + +// [SignatureOpts] is used to specify the generation and verification +// procedure of signatures. +type SignatureOpts struct { + // When set to [NoPreHash] (the zero value), the signature is generated + // over the original message. + // Otherwise, it specifies the function used to pre-hash the message + // before signing. + PreHashID PreHashID + // A context of at most MaxContextSize bytes. + Context []byte + // True for deterministic signatures, false for randomized signatures. + IsDeterministic bool +} + +// HashFunc returns a [crypto.Hash] function only when the PreHashID field +// in the options corresponds to either SHA256 or SHA512. +// Otherwise, it returns the zero value. +func (s SignatureOpts) HashFunc() (h crypto.Hash) { + switch s.PreHashID { + case PreHashSHA256, PreHashSHA512: + h = crypto.Hash(s.PreHashID) + } + return +} + +func readRandom(random io.Reader, size uint32) (out []byte, err error) { + out = make([]byte, size) + if random == nil { + random = rand.Reader + } + _, err = io.ReadFull(random, out) + return +} + +var ( + ErrContext = fmt.Errorf("sign/slhdsa: context is larger than MaxContextSize=%v bytes", MaxContextSize) + ErrParam = errors.New("sign/slhdsa: invalid SLH-DSA parameter") + ErrPreHash = errors.New("sign/slhdsa: invalid prehash function") + ErrSigParse = errors.New("sign/slhdsa: failed to decode the signature") + ErrTree = errors.New("sign/slhdsa: invalid tree height or tree index") + ErrWriting = errors.New("sign/slhdsa: failed to write to a hash function") +) diff --git a/sign/slhdsa/slhdsa_test.go b/sign/slhdsa/slhdsa_test.go new file mode 100644 index 000000000..7f8101553 --- /dev/null +++ b/sign/slhdsa/slhdsa_test.go @@ -0,0 +1,167 @@ +package slhdsa_test + +import ( + "crypto/rand" + "testing" + + "github.com/cloudflare/circl/internal/sha3" + "github.com/cloudflare/circl/internal/test" + "github.com/cloudflare/circl/sign/slhdsa" +) + +var supportedParameters = [12]slhdsa.ParamID{ + slhdsa.ParamIDSHA2Small128, + slhdsa.ParamIDSHAKESmall128, + slhdsa.ParamIDSHA2Fast128, + slhdsa.ParamIDSHAKEFast128, + slhdsa.ParamIDSHA2Small192, + slhdsa.ParamIDSHAKESmall192, + slhdsa.ParamIDSHA2Fast192, + slhdsa.ParamIDSHAKEFast192, + slhdsa.ParamIDSHA2Small256, + slhdsa.ParamIDSHAKESmall256, + slhdsa.ParamIDSHA2Fast256, + slhdsa.ParamIDSHAKEFast256, +} + +var supportedPrehashIDs = [5]slhdsa.PreHashID{ + slhdsa.NoPreHash, + slhdsa.PreHashSHA256, + slhdsa.PreHashSHA512, + slhdsa.PreHashSHAKE128, + slhdsa.PreHashSHAKE256, +} + +func TestSlhdsa(t *testing.T) { + for i := range supportedParameters { + id := supportedParameters[i] + + t.Run(id.Name(), func(t *testing.T) { + t.Parallel() + + t.Run("Keys", func(t *testing.T) { testKeys(t, id) }) + + for j := range supportedPrehashIDs { + ph := supportedPrehashIDs[j] + msg := []byte("Alice and Bob") + ctx := []byte("this is a context string") + pub, priv, err := slhdsa.GenerateKey(rand.Reader, id) + test.CheckNoErr(t, err, "keygen failed") + + t.Run("Sign/"+ph.String(), func(t *testing.T) { + testSign(t, &pub, &priv, msg, ctx, ph) + }) + } + }) + } +} + +func testKeys(t *testing.T, id slhdsa.ParamID) { + reader := sha3.NewShake128() + + reader.Reset() + pub0, priv0, err := slhdsa.GenerateKey(&reader, id) + test.CheckNoErr(t, err, "GenerateKey failed") + + reader.Reset() + pub1, priv1, err := slhdsa.GenerateKey(&reader, id) + test.CheckNoErr(t, err, "GenerateKey failed") + + test.CheckOk(priv0.Equal(&priv1), "private key not equal", t) + test.CheckOk(pub0.Equal(&pub1), "public key not equal", t) + + test.CheckMarshal(t, &priv0, &priv1) + test.CheckMarshal(t, &pub0, &pub1) + + seed := make([]byte, id.SeedSize()) + pub2, priv2 := id.DeriveKey(seed) + pub3, priv3 := id.DeriveKey(seed) + + test.CheckOk(priv2.Equal(priv3), "private key not equal", t) + test.CheckOk(pub2.Equal(pub3), "public key not equal", t) +} + +func testSign( + t *testing.T, + pk *slhdsa.PublicKey, + sk *slhdsa.PrivateKey, + msg, ctx []byte, + ph slhdsa.PreHashID, +) { + m, err := slhdsa.NewMessageWithPreHash(ph) + test.CheckNoErr(t, err, "NewMessageWithPreHash failed") + + _, err = m.Write(msg) + test.CheckNoErr(t, err, "Write message failed") + + sig, err := sk.SignRandomized(rand.Reader, &m, ctx) + test.CheckNoErr(t, err, "SignRandomized failed") + + valid := slhdsa.Verify(pk, &m, ctx, sig) + test.CheckOk(valid, "Verify failed", t) + + sig, err = sk.SignDeterministic(&m, ctx) + test.CheckNoErr(t, err, "SignDeterministic failed") + + valid = slhdsa.Verify(pk, &m, ctx, sig) + test.CheckOk(valid, "Verify failed", t) +} + +func BenchmarkSlhdsa(b *testing.B) { + for i := range supportedParameters { + id := supportedParameters[i] + + b.Run(id.Name(), func(b *testing.B) { + b.Run("GenerateKey", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _, _ = slhdsa.GenerateKey(rand.Reader, id) + } + }) + + for j := range supportedPrehashIDs { + ph := supportedPrehashIDs[j] + msg := []byte("Alice and Bob") + ctx := []byte("this is a context string") + pub, priv, err := slhdsa.GenerateKey(rand.Reader, id) + test.CheckNoErr(b, err, "GenerateKey failed") + + b.Run(ph.String(), func(b *testing.B) { + benchmarkSign(b, &pub, &priv, msg, ctx, ph) + }) + } + }) + } +} + +func benchmarkSign( + b *testing.B, + pk *slhdsa.PublicKey, + sk *slhdsa.PrivateKey, + msg, ctx []byte, + ph slhdsa.PreHashID, +) { + m, err := slhdsa.NewMessageWithPreHash(ph) + test.CheckNoErr(b, err, "NewMessageWithPreHash failed") + + _, err = m.Write(msg) + test.CheckNoErr(b, err, "Write message failed") + + sig, err := sk.SignRandomized(rand.Reader, &m, ctx) + test.CheckNoErr(b, err, "SignRandomized failed") + + b.Run("SignRandomized", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = sk.SignRandomized(rand.Reader, &m, ctx) + } + }) + b.Run("SignDeterministic", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = sk.SignDeterministic(&m, ctx) + } + }) + b.Run("Verify", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = slhdsa.Verify(pk, &m, ctx, sig) + } + }) +} diff --git a/sign/slhdsa/state.go b/sign/slhdsa/state.go new file mode 100644 index 000000000..fa1df4c8f --- /dev/null +++ b/sign/slhdsa/state.go @@ -0,0 +1,355 @@ +package slhdsa + +import ( + "crypto/sha256" + "crypto/sha512" + "hash" + "io" + + "github.com/cloudflare/circl/internal/sha3" +) + +// statePriv encapsulates common data for performing a private operation. +type statePriv struct { + state + PRF statePRF +} + +func (s *statePriv) Size(p *params) uint32 { + return s.state.Size(p) + s.PRF.Size(p) +} + +func (p *params) NewStatePriv(skSeed, pkSeed []byte) (s statePriv) { + c := cursor(make([]byte, s.Size(p))) + s.state.init(p, &c, pkSeed) + s.PRF.Init(p, &c, skSeed, pkSeed) + + return +} + +func (s *statePriv) Clear() { + s.PRF.Clear() + s.state.Clear() +} + +// state encapsulates common data for performing a public operation. +type state struct { + *params + + F stateF + H stateH + T stateT +} + +func (s *state) Size(p *params) uint32 { + return s.F.Size(p) + s.H.Size(p) + s.T.Size(p) +} + +func (p *params) NewStatePub(pkSeed []byte) (s state) { + c := cursor(make([]byte, s.Size(p))) + s.init(p, &c, pkSeed) + + return +} + +func (s *state) init(p *params, c *cursor, pkSeed []byte) { + s.params = p + s.F.Init(p, c, pkSeed) + s.H.Init(p, c, pkSeed) + s.T.Init(p, c, pkSeed) +} + +func (s *state) Clear() { + s.F.Clear() + s.H.Clear() + s.T.Clear() + s.params = nil +} + +func sha256sum(out, in []byte) { s := sha256.Sum256(in); copy(out, s[:]) } +func sha512sum(out, in []byte) { s := sha512.Sum512(in); copy(out, s[:]) } + +type baseHasher struct { + input, output []byte + address + hash func(out, in []byte) +} + +func (b *baseHasher) Size(p *params) uint32 { + return p.n + p.addressSize() +} + +func (b *baseHasher) Clear() { + clearSlice(&b.input) + clearSlice(&b.output) + b.address.Clear() +} + +func (b *baseHasher) Final() []byte { + b.hash(b.output, b.input) + return b.output +} + +type statePRF struct{ baseHasher } + +func (s *statePRF) Init(p *params, cur *cursor, skSeed, pkSeed []byte) { + c := cursor(cur.Next(s.Size(p))) + s.output = c.Next(p.n) + s.input = c.Rest() + copy(c.Next(p.n), pkSeed) + _ = c.Next(s.padSize(p)) + s.address.fromBytes(p, &c) + copy(c.Next(p.n), skSeed) + + if p.isSHA2 { + s.hash = sha256sum + } else { + s.hash = sha3.ShakeSum256 + } +} + +func (s *statePRF) Size(p *params) uint32 { + return 2*p.n + s.padSize(p) + s.baseHasher.Size(p) +} + +func (s *statePRF) padSize(p *params) uint32 { + if p.isSHA2 { + return 64 - p.n + } else { + return 0 + } +} + +type stateF struct { + baseHasher + msg []byte +} + +func (s *stateF) Init(p *params, cur *cursor, pkSeed []byte) { + c := cursor(cur.Next(s.Size(p))) + s.output = c.Next(p.n) + s.input = c.Rest() + copy(c.Next(p.n), pkSeed) + _ = c.Next(s.padSize(p)) + s.address.fromBytes(p, &c) + s.msg = c.Next(p.n) + + if p.isSHA2 { + s.hash = sha256sum + } else { + s.hash = sha3.ShakeSum256 + } +} + +func (s *stateF) SetMessage(msg []byte) { copy(s.msg, msg) } + +func (s *stateF) Clear() { + s.baseHasher.Clear() + clearSlice(&s.msg) +} + +func (s *stateF) Size(p *params) uint32 { + return 2*p.n + s.padSize(p) + s.baseHasher.Size(p) +} + +func (s *stateF) padSize(p *params) uint32 { + if p.isSHA2 { + return 64 - p.n + } else { + return 0 + } +} + +type stateH struct { + baseHasher + msg0, msg1 []byte +} + +func (s *stateH) Init(p *params, cur *cursor, pkSeed []byte) { + c := cursor(cur.Next(s.Size(p))) + s.output = c.Next(p.n) + s.input = c.Rest() + copy(c.Next(p.n), pkSeed) + _ = c.Next(s.padSize(p)) + s.address.fromBytes(p, &c) + s.msg0 = c.Next(p.n) + s.msg1 = c.Next(p.n) + + if p.isSHA2 { + if p.n == 16 { + s.hash = sha256sum + } else { + s.hash = sha512sum + } + } else { + s.hash = sha3.ShakeSum256 + } +} + +func (s *stateH) SetMsgs(m0, m1 []byte) { + copy(s.msg0, m0) + copy(s.msg1, m1) +} + +func (s *stateH) Clear() { + s.baseHasher.Clear() + clearSlice(&s.msg0) + clearSlice(&s.msg1) +} + +func (s *stateH) Size(p *params) uint32 { + return 3*p.n + s.padSize(p) + s.baseHasher.Size(p) +} + +func (s *stateH) padSize(p *params) uint32 { + if p.isSHA2 { + if p.n == 16 { + return 64 - p.n + } else { + return 128 - p.n + } + } else { + return 0 + } +} + +type stateT struct { + input, output []byte + address + hash interface { + io.Writer + Reset() + Final([]byte) + } +} + +func (s *stateT) Init(p *params, cur *cursor, pkSeed []byte) { + c := cursor(cur.Next(s.Size(p))) + s.output = c.Next(s.outputSize(p))[:p.n] + s.input = c.Rest() + copy(c.Next(p.n), pkSeed) + _ = c.Next(s.padSize(p)) + s.address.fromBytes(p, &c) + + if p.isSHA2 { + if p.n == 16 { + s.hash = &sha2rw{sha256.New()} + } else { + s.hash = &sha2rw{sha512.New()} + } + } else { + s.hash = &sha3rw{sha3.NewShake256()} + } +} + +func (s *stateT) Clear() { + clearSlice(&s.input) + clearSlice(&s.output) + s.address.Clear() + s.hash.Reset() +} + +func (s *stateT) Reset() { + s.hash.Reset() + _, _ = s.hash.Write(s.input) +} + +func (s *stateT) WriteMessage(msg []byte) { _, _ = s.hash.Write(msg) } + +func (s *stateT) Final() []byte { + s.hash.Final(s.output) + return s.output +} + +func (s *stateT) Size(p *params) uint32 { + return s.outputSize(p) + s.padSize(p) + p.n + p.addressSize() +} + +func (s *stateT) outputSize(p *params) uint32 { + if p.isSHA2 { + if p.n == 16 { + return sha256.Size + } else { + return sha512.Size + } + } else { + return p.n + } +} + +func (s *stateT) padSize(p *params) uint32 { + if p.isSHA2 { + if p.n == 16 { + return 64 - p.n + } else { + return 128 - p.n + } + } else { + return 0 + } +} + +type sha2rw struct{ hash.Hash } + +func (s *sha2rw) Final(out []byte) { s.Sum(out[:0]) } +func (s *sha2rw) SumIdempotent(out []byte) { s.Sum(out[:0]) } + +type sha3rw struct{ sha3.State } + +func (s *sha3rw) Final(out []byte) { _, _ = s.Read(out) } +func (s *sha3rw) SumIdempotent(out []byte) { _, _ = s.Clone().Read(out) } + +type ( + item struct { + z uint32 + node []byte + } + stackNode []item +) + +func (p *params) NewStack(z uint32) stackNode { + s := make([]item, z) + c := cursor(make([]byte, z*p.n)) + for i := range s { + s[i].node = c.Next(p.n) + } + + return s[:0] +} + +func (s stackNode) isEmpty() bool { return len(s) == 0 } +func (s stackNode) top() item { return s[len(s)-1] } +func (s *stackNode) push(v item) { + next := len(*s) + *s = (*s)[:next+1] + (*s)[next].z = v.z + copy((*s)[next].node, v.node) +} + +func (s *stackNode) pop() (v item) { + last := len(*s) - 1 + if last >= 0 { + v = (*s)[last] + *s = (*s)[:last] + } + return +} + +func (s *stackNode) Clear() { + *s = (*s)[:cap(*s)] + for i := range *s { + clearSlice(&(*s)[i].node) + } + clear((*s)[:]) +} + +type cursor []byte + +func (s *cursor) Rest() []byte { return (*s)[:] } +func (s *cursor) Next(n uint32) (out []byte) { + out = (*s)[:n] + *s = (*s)[n:] + return +} + +func clearSlice(s *[]byte) { clear(*s); *s = nil } diff --git a/sign/slhdsa/testdata/keygen.json.zip b/sign/slhdsa/testdata/keygen.json.zip new file mode 100644 index 000000000..732d1f3f4 Binary files /dev/null and b/sign/slhdsa/testdata/keygen.json.zip differ diff --git a/sign/slhdsa/testdata/sigGen_prompt.json.zip b/sign/slhdsa/testdata/sigGen_prompt.json.zip new file mode 100644 index 000000000..732d2c907 Binary files /dev/null and b/sign/slhdsa/testdata/sigGen_prompt.json.zip differ diff --git a/sign/slhdsa/testdata/sigGen_results.json.zip b/sign/slhdsa/testdata/sigGen_results.json.zip new file mode 100644 index 000000000..dbb8ec463 Binary files /dev/null and b/sign/slhdsa/testdata/sigGen_results.json.zip differ diff --git a/sign/slhdsa/testdata/verify_prompt.json.zip b/sign/slhdsa/testdata/verify_prompt.json.zip new file mode 100644 index 000000000..70f057f3d Binary files /dev/null and b/sign/slhdsa/testdata/verify_prompt.json.zip differ diff --git a/sign/slhdsa/testdata/verify_results.json.zip b/sign/slhdsa/testdata/verify_results.json.zip new file mode 100644 index 000000000..adcb69058 Binary files /dev/null and b/sign/slhdsa/testdata/verify_results.json.zip differ diff --git a/sign/slhdsa/wotsp.go b/sign/slhdsa/wotsp.go new file mode 100644 index 000000000..68f7ac569 --- /dev/null +++ b/sign/slhdsa/wotsp.go @@ -0,0 +1,125 @@ +package slhdsa + +// See FIPS 205 -- Section 5 +// Winternitz One-Time Signature Plus Scheme + +const ( + wotsW = 16 // wotsW is w = 2^lg_w, where lg_w = 4. + wotsLen2 = 3 // wotsLen2 is len_2 fixed to 3. +) + +type ( + wotsPublicKey []byte // n bytes + wotsSignature []byte // wotsLen()*n bytes +) + +func (p *params) wotsSigSize() uint32 { return p.wotsLen() * p.n } +func (p *params) wotsLen() uint32 { return p.wotsLen1() + wotsLen2 } +func (p *params) wotsLen1() uint32 { return 2 * p.n } + +func (ws *wotsSignature) fromBytes(p *params, c *cursor) { + *ws = c.Next(p.wotsSigSize()) +} + +// See FIPS 205 -- Section 5 -- Algorithm 5. +func (s *state) chain(x []byte, index, step uint32, addr address) (out []byte) { + out = x + s.F.address.Set(addr) + for j := index; j < index+step; j++ { + s.F.address.SetHashAddress(j) + s.F.SetMessage(out) + out = s.F.Final() + } + return +} + +// See FIPS 205 -- Section 5.1 -- Algorithm 6. +func (s *statePriv) wotsPkGen(addr address) wotsPublicKey { + s.PRF.address.Set(addr) + s.PRF.address.SetTypeAndClear(addressWotsPrf) + s.PRF.address.SetKeyPairAddress(addr.GetKeyPairAddress()) + + s.T.address.Set(addr) + s.T.address.SetTypeAndClear(addressWotsPk) + s.T.address.SetKeyPairAddress(addr.GetKeyPairAddress()) + + s.T.Reset() + wotsLen := s.wotsLen() + for i := uint32(0); i < wotsLen; i++ { + s.PRF.address.SetChainAddress(i) + sk := s.PRF.Final() + + addr.SetChainAddress(i) + tmpi := s.chain(sk, 0, wotsW-1, addr) + + s.T.WriteMessage(tmpi) + } + + return s.T.Final() +} + +// See FIPS 205 -- Section 5.2 -- Algorithm 7. +func (s *statePriv) wotsSign(sig wotsSignature, msg []byte, addr address) { + curSig := cursor(sig) + wotsLen1 := s.wotsLen1() + csum := wotsLen1 * (wotsW - 1) + + s.PRF.address.Set(addr) + s.PRF.address.SetTypeAndClear(addressWotsPrf) + s.PRF.address.SetKeyPairAddress(addr.GetKeyPairAddress()) + + for i := uint32(0); i < wotsLen1; i++ { + s.PRF.address.SetChainAddress(i) + sk := s.PRF.Final() + + addr.SetChainAddress(i) + msgi := uint32((msg[i/2] >> ((1 - (i & 1)) << 2)) & 0xF) + sigi := s.chain(sk, 0, msgi, addr) + copy(curSig.Next(s.n), sigi) + csum -= msgi + } + + for i := uint32(0); i < wotsLen2; i++ { + s.PRF.address.SetChainAddress(wotsLen1 + i) + sk := s.PRF.Final() + + addr.SetChainAddress(wotsLen1 + i) + csumi := (csum >> (8 - 4*i)) & 0xF + sigi := s.chain(sk, 0, csumi, addr) + copy(curSig.Next(s.n), sigi) + } +} + +// See FIPS 205 -- Section 5.3 -- Algorithm 8. +func (s *state) wotsPkFromSig( + sig wotsSignature, msg []byte, addr address, +) wotsPublicKey { + wotsLen1 := s.wotsLen1() + csum := wotsLen1 * (wotsW - 1) + + s.T.address.Set(addr) + s.T.address.SetTypeAndClear(addressWotsPk) + s.T.address.SetKeyPairAddress(addr.GetKeyPairAddress()) + + s.T.Reset() + curSig := cursor(sig) + + for i := uint32(0); i < wotsLen1; i++ { + addr.SetChainAddress(i) + msgi := uint32((msg[i/2] >> ((1 - (i & 1)) << 2)) & 0xF) + sigi := s.chain(curSig.Next(s.n), msgi, wotsW-1-msgi, addr) + + s.T.WriteMessage(sigi) + csum -= msgi + } + + for i := uint32(0); i < wotsLen2; i++ { + addr.SetChainAddress(wotsLen1 + i) + csumi := (csum >> (8 - 4*i)) & 0xF + sigi := s.chain(curSig.Next(s.n), csumi, wotsW-1-csumi, addr) + + s.T.WriteMessage(sigi) + } + + return s.T.Final() +} diff --git a/sign/slhdsa/wotsp_test.go b/sign/slhdsa/wotsp_test.go new file mode 100644 index 000000000..7e97443fb --- /dev/null +++ b/sign/slhdsa/wotsp_test.go @@ -0,0 +1,64 @@ +package slhdsa + +import ( + "bytes" + "testing" + + "github.com/cloudflare/circl/internal/test" +) + +func testWotsPlus(t *testing.T, p *params) { + skSeed := mustRead(t, p.n) + pkSeed := mustRead(t, p.n) + msg := mustRead(t, p.n) + + state := p.NewStatePriv(skSeed, pkSeed) + + addr := p.NewAddress() + addr.SetTypeAndClear(addressWotsHash) + + pk0 := state.wotsPkGen(addr) + + var sig wotsSignature + curSig := cursor(make([]byte, p.wotsSigSize())) + sig.fromBytes(p, &curSig) + state.wotsSign(sig, msg, addr) + + pk1 := state.wotsPkFromSig(sig, msg, addr) + + if !bytes.Equal(pk0, pk1) { + test.ReportError(t, pk0, pk1, skSeed, pkSeed, msg) + } +} + +func benchmarkWotsPlus(b *testing.B, p *params) { + skSeed := mustRead(b, p.n) + pkSeed := mustRead(b, p.n) + msg := mustRead(b, p.n) + + state := p.NewStatePriv(skSeed, pkSeed) + + addr := p.NewAddress() + addr.SetTypeAndClear(addressWotsHash) + + var sig wotsSignature + curSig := cursor(make([]byte, p.wotsSigSize())) + sig.fromBytes(p, &curSig) + state.wotsSign(sig, msg, addr) + + b.Run("PkGen", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = state.wotsPkGen(addr) + } + }) + b.Run("Sign", func(b *testing.B) { + for i := 0; i < b.N; i++ { + state.wotsSign(sig, msg, addr) + } + }) + b.Run("PkFromSig", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = state.wotsPkFromSig(sig, msg, addr) + } + }) +} diff --git a/sign/slhdsa/xmss.go b/sign/slhdsa/xmss.go new file mode 100644 index 000000000..074315282 --- /dev/null +++ b/sign/slhdsa/xmss.go @@ -0,0 +1,107 @@ +package slhdsa + +// See FIPS 205 -- Section 6 +// eXtended Merkle Signature Scheme (XMSS) extends the WOTS+ signature +// scheme into one that can sign multiple messages. + +type ( + xmssPublicKey []byte // n bytes + xmssSignature struct { + wotsSig wotsSignature // wotsSigSize() bytes + authPath []byte // hPrime*n bytes + } // wotsSigSize() + hPrime*n bytes +) + +func (p *params) xmssPkSize() uint32 { return p.n } +func (p *params) xmssAuthPathSize() uint32 { return p.hPrime * p.n } +func (p *params) xmssSigSize() uint32 { + return p.wotsSigSize() + p.xmssAuthPathSize() +} + +func (xs *xmssSignature) fromBytes(p *params, c *cursor) { + xs.wotsSig.fromBytes(p, c) + xs.authPath = c.Next(p.xmssAuthPathSize()) +} + +// See FIPS 205 -- Section 6.1 -- Algorithm 9 -- Iterative version. +func (s *statePriv) xmssNodeIter( + stack stackNode, root []byte, i, z uint32, addr address, +) { + if !(z <= s.hPrime && i < (1<<(s.hPrime-z))) { + panic(ErrTree) + } + + s.H.address.Set(addr) + s.H.address.SetTypeAndClear(addressTree) + + twoZ := uint32(1) << z + iTwoZ := i << z + for k := uint32(0); k < twoZ; k++ { + li := iTwoZ + k + lz := uint32(0) + + addr.SetTypeAndClear(addressWotsHash) + addr.SetKeyPairAddress(li) + node := s.wotsPkGen(addr) + + for !stack.isEmpty() && stack.top().z == lz { + left := stack.pop() + li = (li - 1) >> 1 + lz = lz + 1 + + s.H.address.SetTreeHeight(lz) + s.H.address.SetTreeIndex(li) + s.H.SetMsgs(left.node, node) + node = s.H.Final() + } + + stack.push(item{lz, node}) + } + + copy(root, stack.pop().node) +} + +// See FIPS 205 -- Section 6.2 -- Algorithm 10. +func (s *statePriv) xmssSign( + stack stackNode, sig xmssSignature, msg []byte, idx uint32, addr address, +) { + authPath := cursor(sig.authPath) + for j := uint32(0); j < s.hPrime; j++ { + k := (idx >> j) ^ 1 + s.xmssNodeIter(stack, authPath.Next(s.n), k, j, addr) + } + + addr.SetTypeAndClear(addressWotsHash) + addr.SetKeyPairAddress(idx) + s.wotsSign(sig.wotsSig, msg, addr) +} + +// See FIPS 205 -- Section 6.3 -- Algorithm 11. +func (s *state) xmssPkFromSig( + out xmssPublicKey, msg []byte, sig xmssSignature, idx uint32, addr address, +) { + addr.SetTypeAndClear(addressWotsHash) + addr.SetKeyPairAddress(idx) + pk := xmssPublicKey(s.wotsPkFromSig(sig.wotsSig, msg, addr)) + + treeIdx := idx + s.H.address.Set(addr) + s.H.address.SetTypeAndClear(addressTree) + + authPath := cursor(sig.authPath) + for k := uint32(0); k < s.hPrime; k++ { + if (idx>>k)&0x1 == 0 { + treeIdx = treeIdx >> 1 + s.H.SetMsgs(pk, authPath.Next(s.n)) + } else { + treeIdx = (treeIdx - 1) >> 1 + s.H.SetMsgs(authPath.Next(s.n), pk) + } + + s.H.address.SetTreeHeight(k + 1) + s.H.address.SetTreeIndex(treeIdx) + pk = s.H.Final() + } + + copy(out, pk) +} diff --git a/sign/slhdsa/xmss_test.go b/sign/slhdsa/xmss_test.go new file mode 100644 index 000000000..93ff1a907 --- /dev/null +++ b/sign/slhdsa/xmss_test.go @@ -0,0 +1,116 @@ +package slhdsa + +import ( + "bytes" + "fmt" + "testing" + + "github.com/cloudflare/circl/internal/test" +) + +// See FIPS 205 -- Section 6.1 -- Algorithm 9 -- Recursive version. +func (s *statePriv) xmssNodeRec(i, z uint32, addr address) (node []byte) { + if !(z <= s.hPrime && i < (1<<(s.hPrime-z))) { + panic(ErrTree) + } + + node = make([]byte, s.n) + if z == 0 { + addr.SetTypeAndClear(addressWotsHash) + addr.SetKeyPairAddress(i) + copy(node, s.wotsPkGen(addr)) + } else { + lnode := s.xmssNodeRec(2*i, z-1, addr) + rnode := s.xmssNodeRec(2*i+1, z-1, addr) + + s.H.address.Set(addr) + s.H.address.SetTypeAndClear(addressTree) + s.H.address.SetTreeHeight(z) + s.H.address.SetTreeIndex(i) + s.H.SetMsgs(lnode, rnode) + copy(node, s.H.Final()) + } + + return +} + +func testXmss(t *testing.T, p *params) { + skSeed := mustRead(t, p.n) + pkSeed := mustRead(t, p.n) + msg := mustRead(t, p.n) + + state := p.NewStatePriv(skSeed, pkSeed) + + addr := p.NewAddress() + addr.SetTypeAndClear(addressWotsHash) + idx := uint32(0) + + rootRec := state.xmssNodeRec(idx, p.hPrime, addr) + test.CheckOk( + len(rootRec) == int(p.n), + fmt.Sprintf("bad xmss rootRec length: %v", len(rootRec)), + t, + ) + + stack := p.NewStack(p.hPrime) + rootIter := make([]byte, p.n) + state.xmssNodeIter(stack, rootIter, idx, p.hPrime, addr) + + if !bytes.Equal(rootRec, rootIter) { + test.ReportError(t, rootRec, rootIter, skSeed, pkSeed, msg) + } + + var sig xmssSignature + curSig := cursor(make([]byte, p.xmssSigSize())) + sig.fromBytes(p, &curSig) + state.xmssSign(stack, sig, msg, idx, addr) + + node := make([]byte, p.xmssPkSize()) + state.xmssPkFromSig(node, msg, sig, idx, addr) + + if !bytes.Equal(rootRec, node) { + test.ReportError(t, rootRec, node, skSeed, pkSeed, msg) + } +} + +func benchmarkXmss(b *testing.B, p *params) { + skSeed := mustRead(b, p.n) + pkSeed := mustRead(b, p.n) + msg := mustRead(b, p.n) + + state := p.NewStatePriv(skSeed, pkSeed) + + addr := p.NewAddress() + addr.SetTypeAndClear(addressWotsHash) + idx := uint32(0) + + var sig xmssSignature + curSig := cursor(make([]byte, p.xmssSigSize())) + sig.fromBytes(p, &curSig) + state.xmssSign(state.NewStack(p.hPrime), sig, msg, idx, addr) + + b.Run("NodeRec", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = state.xmssNodeRec(idx, p.hPrime, addr) + } + }) + b.Run("NodeIter", func(b *testing.B) { + node := make([]byte, p.n) + stack := state.NewStack(p.hPrime) + for i := 0; i < b.N; i++ { + state.xmssNodeIter(stack, node, idx, p.hPrime, addr) + } + }) + b.Run("Sign", func(b *testing.B) { + stack := state.NewStack(p.hPrime) + for i := 0; i < b.N; i++ { + state.xmssSign(stack, sig, msg, idx, addr) + } + }) + b.Run("PkFromSig", func(b *testing.B) { + node := make([]byte, p.xmssPkSize()) + for i := 0; i < b.N; i++ { + state.xmssPkFromSig(node, msg, sig, idx, addr) + } + }) +}