diff --git a/go.mod b/go.mod index 55ece8c997..df547a4c1e 100644 --- a/go.mod +++ b/go.mod @@ -53,7 +53,7 @@ require ( github.com/spf13/cast v1.3.0 github.com/spf13/cobra v1.1.1 github.com/spf13/pflag v1.0.5 - github.com/stretchr/testify v1.9.0 + github.com/stretchr/testify v1.10.0 github.com/thoas/go-funk v0.8.0 github.com/vbauerster/mpb/v7 v7.1.5 github.com/vektah/gqlparser/v2 v2.5.16 @@ -75,7 +75,7 @@ require ( github.com/charmbracelet/bubbles v0.18.0 github.com/charmbracelet/bubbletea v0.25.0 github.com/charmbracelet/lipgloss v0.9.1 - github.com/go-git/go-git/v5 v5.12.0 + github.com/go-git/go-git/v5 v5.13.1 github.com/gowebpki/jcs v1.0.1 github.com/klauspost/compress v1.11.4 github.com/mholt/archiver/v3 v3.5.1 @@ -86,14 +86,14 @@ require ( dario.cat/mergo v1.0.0 // indirect github.com/ActiveState/pty v0.0.0-20230628221854-6fb90eb08a14 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect - github.com/ProtonMail/go-crypto v1.0.0 // indirect + github.com/ProtonMail/go-crypto v1.1.3 // indirect github.com/andybalholm/brotli v1.0.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/cloudflare/circl v1.3.7 // indirect github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 // indirect - github.com/cyphar/filepath-securejoin v0.2.4 // indirect + github.com/cyphar/filepath-securejoin v0.3.6 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect - github.com/go-git/go-billy/v5 v5.5.0 // indirect + github.com/go-git/go-billy/v5 v5.6.1 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/hinshun/vt10x v0.0.0-20220301184237-5011da428d02 // indirect @@ -109,7 +109,7 @@ require ( github.com/pierrec/lz4/v4 v4.1.2 // indirect github.com/pjbgf/sha1cd v0.3.0 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect - github.com/skeema/knownhosts v1.2.2 // indirect + github.com/skeema/knownhosts v1.3.0 // indirect github.com/sosodev/duration v1.3.1 // indirect golang.org/x/sync v0.10.0 // indirect ) diff --git a/go.sum b/go.sum index fbcccef181..ca20c3449c 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/ProtonMail/go-crypto v1.0.0 h1:LRuvITjQWX+WIfr930YHG2HNfjR1uOfyf5vE0kC2U78= github.com/ProtonMail/go-crypto v1.0.0/go.mod h1:EjAoLdwvbIOoOQr3ihjnSoLZRtE8azugULFRteWMNc0= +github.com/ProtonMail/go-crypto v1.1.3 h1:nRBOetoydLeUb4nHajyO2bKqMLfWQ/ZPwkXqXxPxCFk= +github.com/ProtonMail/go-crypto v1.1.3/go.mod h1:rA3QumHc/FZ8pAHreoekgiAbzpNsfQAosU5td4SnOrE= github.com/PuerkitoBio/goquery v1.9.3 h1:mpJr/ikUA9/GNJB/DBZcGeFDXUtosHRyRrwh7KGdTG0= github.com/PuerkitoBio/goquery v1.9.3/go.mod h1:1ndLHPdTz+DyQPICCWYlYQMPl0oXZj0G6D4LCYA6u4U= github.com/PuerkitoBio/purell v1.1.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= @@ -131,6 +133,8 @@ github.com/creack/pty v1.1.11 h1:07n33Z8lZxZ2qwegKbObQohDhXDQxiMMz1NOUGYlesw= github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/cyphar/filepath-securejoin v0.2.4 h1:Ugdm7cg7i6ZK6x3xDF1oEu1nfkyfH53EtKeQYTC3kyg= github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4= +github.com/cyphar/filepath-securejoin v0.3.6 h1:4d9N5ykBnSp5Xn2JkhocYDkOpURL/18CYMpo6xB9uWM= +github.com/cyphar/filepath-securejoin v0.3.6/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI= github.com/dave/jennifer v0.18.0 h1:fhwWYwRltL8wW567TWRHCstLaBCEsk5M5DE4rrMsi94= github.com/dave/jennifer v0.18.0/go.mod h1:fIb+770HOpJ2fmN9EPPKOqm1vMGhB+TwXKMZhrIygKg= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -150,6 +154,7 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/elazarl/goproxy v0.0.0-20230808193330-2592e75ae04a h1:mATvB/9r/3gvcejNsXKSkQ6lcIaNec2nyfOdlTBR2lU= github.com/elazarl/goproxy v0.0.0-20230808193330-2592e75ae04a/go.mod h1:Ro8st/ElPeALwNFlcTpWmkr6IoMFfkjXAvTHpevnDsM= +github.com/elazarl/goproxy v1.2.3 h1:xwIyKHbaP5yfT6O9KIeYJR5549MXRQkoQMRXGztz8YQ= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= @@ -168,16 +173,21 @@ github.com/gammazero/workerpool v1.1.1/go.mod h1:5BN0IJVRjSFAypo9QTJCaWdijjNz9Jj github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gliderlabs/ssh v0.3.7 h1:iV3Bqi942d9huXnzEF2Mt+CY9gLu8DNM4Obd+8bODRE= github.com/gliderlabs/ssh v0.3.7/go.mod h1:zpHEXBstFnQYtGnB8k8kQLol82umzn/2/snG7alWVD8= +github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c= github.com/globalsign/mgo v0.0.0-20180905125535-1ca0a4f7cbcb/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q= github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q= github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI= github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic= github.com/go-git/go-billy/v5 v5.5.0 h1:yEY4yhzCDuMGSv83oGxiBotRzhwhNr8VZyphhiu+mTU= github.com/go-git/go-billy/v5 v5.5.0/go.mod h1:hmexnoNsr2SJU1Ju67OaNz5ASJY3+sHgFRpCtpDCKow= +github.com/go-git/go-billy/v5 v5.6.1 h1:u+dcrgaguSSkbjzHwelEjc0Yj300NUevrrPphk/SoRA= +github.com/go-git/go-billy/v5 v5.6.1/go.mod h1:0AsLr1z2+Uksi4NlElmMblP5rPcDZNRCD8ujZCRR2BE= github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMje31YglSBqCdIqdhKBW8lokaMrL3uTkpGYlE2OOT4= github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII= github.com/go-git/go-git/v5 v5.12.0 h1:7Md+ndsjrzZxbddRDZjF14qK+NN56sy6wkqaVrjZtys= github.com/go-git/go-git/v5 v5.12.0/go.mod h1:FTM9VKtnI2m65hNI/TenDDDnUf2Q9FHnXYjuz9i5OEY= +github.com/go-git/go-git/v5 v5.13.1 h1:DAQ9APonnlvSWpvolXWIuV6Q6zXy2wHbN4cVlNR5Q+M= +github.com/go-git/go-git/v5 v5.13.1/go.mod h1:qryJB4cSBoq3FRoBRf5A77joojuBcmPJ0qu3XXXVixc= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= @@ -549,6 +559,7 @@ github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI= github.com/onsi/gomega v1.27.10/go.mod h1:RsS8tutOdbdgzbPtzzATp12yT7kM5I5aElG3evPbQ0M= +github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= @@ -597,6 +608,7 @@ github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rollbar/rollbar-go v1.1.0 h1:3ysiHp3ep8W50ykgBMCKXJGaK2Jdivru7SW9EYfAo+M= github.com/rollbar/rollbar-go v1.1.0/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtISeXco0L5PKQ= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -621,6 +633,8 @@ github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6Mwd github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/skeema/knownhosts v1.2.2 h1:Iug2P4fLmDw9f41PB6thxUkNUkJzB5i+1/exaj40L3A= github.com/skeema/knownhosts v1.2.2/go.mod h1:xYbVRSPxqBZFrdmDyMmsOs+uX1UZC3nTN3ThzgDxUwo= +github.com/skeema/knownhosts v1.3.0 h1:AM+y0rI04VksttfwjkSTNQorvGqmwATnvnAHpSgc0LY= +github.com/skeema/knownhosts v1.3.0/go.mod h1:sPINvnADmT/qYH1kfv+ePMmOBTH6Tbl7b5LvTDjFK7M= github.com/skratchdot/open-golang v0.0.0-20190104022628-a2dfa6d0dab6 h1:cGT4dcuEyBwwu/v6tosyqcDp2yoIo/LwjMGixUvg3nU= github.com/skratchdot/open-golang v0.0.0-20190104022628-a2dfa6d0dab6/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= @@ -654,6 +668,8 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/thoas/go-funk v0.8.0 h1:JP9tKSvnpFVclYgDM0Is7FD9M4fhPvqA0s0BsXmzSRQ= github.com/thoas/go-funk v0.8.0/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q= diff --git a/vendor/github.com/ProtonMail/go-crypto/ocb/ocb.go b/vendor/github.com/ProtonMail/go-crypto/ocb/ocb.go index 1a6f73502e..5022285b44 100644 --- a/vendor/github.com/ProtonMail/go-crypto/ocb/ocb.go +++ b/vendor/github.com/ProtonMail/go-crypto/ocb/ocb.go @@ -18,8 +18,9 @@ import ( "crypto/cipher" "crypto/subtle" "errors" - "github.com/ProtonMail/go-crypto/internal/byteutil" "math/bits" + + "github.com/ProtonMail/go-crypto/internal/byteutil" ) type ocb struct { @@ -153,7 +154,7 @@ func (o *ocb) crypt(instruction int, Y, nonce, adata, X []byte) []byte { truncatedNonce := make([]byte, len(nonce)) copy(truncatedNonce, nonce) truncatedNonce[len(truncatedNonce)-1] &= 192 - Ktop := make([]byte, blockSize) + var Ktop []byte if bytes.Equal(truncatedNonce, o.reusableKtop.noncePrefix) { Ktop = o.reusableKtop.Ktop } else { diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/armor/armor.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/armor/armor.go index d7af9141e3..e0a677f284 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/armor/armor.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/armor/armor.go @@ -23,7 +23,7 @@ import ( // Headers // // base64-encoded Bytes -// '=' base64 encoded checksum +// '=' base64 encoded checksum (optional) not checked anymore // -----END Type----- // // where Headers is a possibly empty sequence of Key: Value lines. @@ -40,36 +40,15 @@ type Block struct { var ArmorCorrupt error = errors.StructuralError("armor invalid") -const crc24Init = 0xb704ce -const crc24Poly = 0x1864cfb -const crc24Mask = 0xffffff - -// crc24 calculates the OpenPGP checksum as specified in RFC 4880, section 6.1 -func crc24(crc uint32, d []byte) uint32 { - for _, b := range d { - crc ^= uint32(b) << 16 - for i := 0; i < 8; i++ { - crc <<= 1 - if crc&0x1000000 != 0 { - crc ^= crc24Poly - } - } - } - return crc -} - var armorStart = []byte("-----BEGIN ") var armorEnd = []byte("-----END ") var armorEndOfLine = []byte("-----") -// lineReader wraps a line based reader. It watches for the end of an armor -// block and records the expected CRC value. +// lineReader wraps a line based reader. It watches for the end of an armor block type lineReader struct { - in *bufio.Reader - buf []byte - eof bool - crc uint32 - crcSet bool + in *bufio.Reader + buf []byte + eof bool } func (l *lineReader) Read(p []byte) (n int, err error) { @@ -98,26 +77,9 @@ func (l *lineReader) Read(p []byte) (n int, err error) { if len(line) == 5 && line[0] == '=' { // This is the checksum line - var expectedBytes [3]byte - var m int - m, err = base64.StdEncoding.Decode(expectedBytes[0:], line[1:]) - if m != 3 || err != nil { - return - } - l.crc = uint32(expectedBytes[0])<<16 | - uint32(expectedBytes[1])<<8 | - uint32(expectedBytes[2]) - - line, _, err = l.in.ReadLine() - if err != nil && err != io.EOF { - return - } - if !bytes.HasPrefix(line, armorEnd) { - return 0, ArmorCorrupt - } + // Don't check the checksum l.eof = true - l.crcSet = true return 0, io.EOF } @@ -138,23 +100,14 @@ func (l *lineReader) Read(p []byte) (n int, err error) { return } -// openpgpReader passes Read calls to the underlying base64 decoder, but keeps -// a running CRC of the resulting data and checks the CRC against the value -// found by the lineReader at EOF. +// openpgpReader passes Read calls to the underlying base64 decoder. type openpgpReader struct { - lReader *lineReader - b64Reader io.Reader - currentCRC uint32 + lReader *lineReader + b64Reader io.Reader } func (r *openpgpReader) Read(p []byte) (n int, err error) { n, err = r.b64Reader.Read(p) - r.currentCRC = crc24(r.currentCRC, p[:n]) - - if err == io.EOF && r.lReader.crcSet && r.lReader.crc != uint32(r.currentCRC&crc24Mask) { - return 0, ArmorCorrupt - } - return } @@ -222,7 +175,6 @@ TryNextBlock: } p.lReader.in = r - p.oReader.currentCRC = crc24Init p.oReader.lReader = &p.lReader p.oReader.b64Reader = base64.NewDecoder(base64.StdEncoding, &p.lReader) p.Body = &p.oReader diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/armor/encode.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/armor/encode.go index 5b6e16c19d..112f98b835 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/armor/encode.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/armor/encode.go @@ -14,6 +14,23 @@ var blockEnd = []byte("\n=") var newline = []byte("\n") var armorEndOfLineOut = []byte("-----\n") +const crc24Init = 0xb704ce +const crc24Poly = 0x1864cfb + +// crc24 calculates the OpenPGP checksum as specified in RFC 4880, section 6.1 +func crc24(crc uint32, d []byte) uint32 { + for _, b := range d { + crc ^= uint32(b) << 16 + for i := 0; i < 8; i++ { + crc <<= 1 + if crc&0x1000000 != 0 { + crc ^= crc24Poly + } + } + } + return crc +} + // writeSlices writes its arguments to the given Writer. func writeSlices(out io.Writer, slices ...[]byte) (err error) { for _, s := range slices { @@ -99,15 +116,18 @@ func (l *lineBreaker) Close() (err error) { // // encoding -> base64 encoder -> lineBreaker -> out type encoding struct { - out io.Writer - breaker *lineBreaker - b64 io.WriteCloser - crc uint32 - blockType []byte + out io.Writer + breaker *lineBreaker + b64 io.WriteCloser + crc uint32 + crcEnabled bool + blockType []byte } func (e *encoding) Write(data []byte) (n int, err error) { - e.crc = crc24(e.crc, data) + if e.crcEnabled { + e.crc = crc24(e.crc, data) + } return e.b64.Write(data) } @@ -118,20 +138,21 @@ func (e *encoding) Close() (err error) { } e.breaker.Close() - var checksumBytes [3]byte - checksumBytes[0] = byte(e.crc >> 16) - checksumBytes[1] = byte(e.crc >> 8) - checksumBytes[2] = byte(e.crc) + if e.crcEnabled { + var checksumBytes [3]byte + checksumBytes[0] = byte(e.crc >> 16) + checksumBytes[1] = byte(e.crc >> 8) + checksumBytes[2] = byte(e.crc) - var b64ChecksumBytes [4]byte - base64.StdEncoding.Encode(b64ChecksumBytes[:], checksumBytes[:]) + var b64ChecksumBytes [4]byte + base64.StdEncoding.Encode(b64ChecksumBytes[:], checksumBytes[:]) - return writeSlices(e.out, blockEnd, b64ChecksumBytes[:], newline, armorEnd, e.blockType, armorEndOfLine) + return writeSlices(e.out, blockEnd, b64ChecksumBytes[:], newline, armorEnd, e.blockType, armorEndOfLine) + } + return writeSlices(e.out, newline, armorEnd, e.blockType, armorEndOfLine) } -// Encode returns a WriteCloser which will encode the data written to it in -// OpenPGP armor. -func Encode(out io.Writer, blockType string, headers map[string]string) (w io.WriteCloser, err error) { +func encode(out io.Writer, blockType string, headers map[string]string, checksum bool) (w io.WriteCloser, err error) { bType := []byte(blockType) err = writeSlices(out, armorStart, bType, armorEndOfLineOut) if err != nil { @@ -151,11 +172,27 @@ func Encode(out io.Writer, blockType string, headers map[string]string) (w io.Wr } e := &encoding{ - out: out, - breaker: newLineBreaker(out, 64), - crc: crc24Init, - blockType: bType, + out: out, + breaker: newLineBreaker(out, 64), + blockType: bType, + crc: crc24Init, + crcEnabled: checksum, } e.b64 = base64.NewEncoder(base64.StdEncoding, e.breaker) return e, nil } + +// Encode returns a WriteCloser which will encode the data written to it in +// OpenPGP armor. +func Encode(out io.Writer, blockType string, headers map[string]string) (w io.WriteCloser, err error) { + return encode(out, blockType, headers, true) +} + +// EncodeWithChecksumOption returns a WriteCloser which will encode the data written to it in +// OpenPGP armor and provides the option to include a checksum. +// When forming ASCII Armor, the CRC24 footer SHOULD NOT be generated, +// unless interoperability with implementations that require the CRC24 footer +// to be present is a concern. +func EncodeWithChecksumOption(out io.Writer, blockType string, headers map[string]string, doChecksum bool) (w io.WriteCloser, err error) { + return encode(out, blockType, headers, doChecksum) +} diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/canonical_text.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/canonical_text.go index a94f6150c4..5b40e1375d 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/canonical_text.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/canonical_text.go @@ -30,8 +30,12 @@ func writeCanonical(cw io.Writer, buf []byte, s *int) (int, error) { if c == '\r' { *s = 1 } else if c == '\n' { - cw.Write(buf[start:i]) - cw.Write(newline) + if _, err := cw.Write(buf[start:i]); err != nil { + return 0, err + } + if _, err := cw.Write(newline); err != nil { + return 0, err + } start = i + 1 } case 1: @@ -39,7 +43,9 @@ func writeCanonical(cw io.Writer, buf []byte, s *int) (int, error) { } } - cw.Write(buf[start:]) + if _, err := cw.Write(buf[start:]); err != nil { + return 0, err + } return len(buf), nil } diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/ecdh/ecdh.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/ecdh/ecdh.go index c895bad6bb..db8fb163b6 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/ecdh/ecdh.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/ecdh/ecdh.go @@ -163,13 +163,9 @@ func buildKey(pub *PublicKey, zb []byte, curveOID, fingerprint []byte, stripLead if _, err := param.Write([]byte("Anonymous Sender ")); err != nil { return nil, err } - // For v5 keys, the 20 leftmost octets of the fingerprint are used. - if _, err := param.Write(fingerprint[:20]); err != nil { + if _, err := param.Write(fingerprint[:]); err != nil { return nil, err } - if param.Len()-len(curveOID) != 45 { - return nil, errors.New("ecdh: malformed KDF Param") - } // MB = Hash ( 00 || 00 || 00 || 01 || ZB || Param ); h := pub.KDF.Hash.New() diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/ed25519/ed25519.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/ed25519/ed25519.go new file mode 100644 index 0000000000..6abdf7c446 --- /dev/null +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/ed25519/ed25519.go @@ -0,0 +1,115 @@ +// Package ed25519 implements the ed25519 signature algorithm for OpenPGP +// as defined in the Open PGP crypto refresh. +package ed25519 + +import ( + "crypto/subtle" + "io" + + "github.com/ProtonMail/go-crypto/openpgp/errors" + ed25519lib "github.com/cloudflare/circl/sign/ed25519" +) + +const ( + // PublicKeySize is the size, in bytes, of public keys in this package. + PublicKeySize = ed25519lib.PublicKeySize + // SeedSize is the size, in bytes, of private key seeds. + // The private key representation used by RFC 8032. + SeedSize = ed25519lib.SeedSize + // SignatureSize is the size, in bytes, of signatures generated and verified by this package. + SignatureSize = ed25519lib.SignatureSize +) + +type PublicKey struct { + // Point represents the elliptic curve point of the public key. + Point []byte +} + +type PrivateKey struct { + PublicKey + // Key the private key representation by RFC 8032, + // encoded as seed | pub key point. + Key []byte +} + +// NewPublicKey creates a new empty ed25519 public key. +func NewPublicKey() *PublicKey { + return &PublicKey{} +} + +// NewPrivateKey creates a new empty private key referencing the public key. +func NewPrivateKey(key PublicKey) *PrivateKey { + return &PrivateKey{ + PublicKey: key, + } +} + +// Seed returns the ed25519 private key secret seed. +// The private key representation by RFC 8032. +func (pk *PrivateKey) Seed() []byte { + return pk.Key[:SeedSize] +} + +// MarshalByteSecret returns the underlying 32 byte seed of the private key. +func (pk *PrivateKey) MarshalByteSecret() []byte { + return pk.Seed() +} + +// UnmarshalByteSecret computes the private key from the secret seed +// and stores it in the private key object. +func (sk *PrivateKey) UnmarshalByteSecret(seed []byte) error { + sk.Key = ed25519lib.NewKeyFromSeed(seed) + return nil +} + +// GenerateKey generates a fresh private key with the provided randomness source. +func GenerateKey(rand io.Reader) (*PrivateKey, error) { + publicKey, privateKey, err := ed25519lib.GenerateKey(rand) + if err != nil { + return nil, err + } + privateKeyOut := new(PrivateKey) + privateKeyOut.PublicKey.Point = publicKey[:] + privateKeyOut.Key = privateKey[:] + return privateKeyOut, nil +} + +// Sign signs a message with the ed25519 algorithm. +// priv MUST be a valid key! Check this with Validate() before use. +func Sign(priv *PrivateKey, message []byte) ([]byte, error) { + return ed25519lib.Sign(priv.Key, message), nil +} + +// Verify verifies an ed25519 signature. +func Verify(pub *PublicKey, message []byte, signature []byte) bool { + return ed25519lib.Verify(pub.Point, message, signature) +} + +// Validate checks if the ed25519 private key is valid. +func Validate(priv *PrivateKey) error { + expectedPrivateKey := ed25519lib.NewKeyFromSeed(priv.Seed()) + if subtle.ConstantTimeCompare(priv.Key, expectedPrivateKey) == 0 { + return errors.KeyInvalidError("ed25519: invalid ed25519 secret") + } + if subtle.ConstantTimeCompare(priv.PublicKey.Point, expectedPrivateKey[SeedSize:]) == 0 { + return errors.KeyInvalidError("ed25519: invalid ed25519 public key") + } + return nil +} + +// ENCODING/DECODING signature: + +// WriteSignature encodes and writes an ed25519 signature to writer. +func WriteSignature(writer io.Writer, signature []byte) error { + _, err := writer.Write(signature) + return err +} + +// ReadSignature decodes an ed25519 signature from a reader. +func ReadSignature(reader io.Reader) ([]byte, error) { + signature := make([]byte, SignatureSize) + if _, err := io.ReadFull(reader, signature); err != nil { + return nil, err + } + return signature, nil +} diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/ed448/ed448.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/ed448/ed448.go new file mode 100644 index 0000000000..b11fb4fb17 --- /dev/null +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/ed448/ed448.go @@ -0,0 +1,119 @@ +// Package ed448 implements the ed448 signature algorithm for OpenPGP +// as defined in the Open PGP crypto refresh. +package ed448 + +import ( + "crypto/subtle" + "io" + + "github.com/ProtonMail/go-crypto/openpgp/errors" + ed448lib "github.com/cloudflare/circl/sign/ed448" +) + +const ( + // PublicKeySize is the size, in bytes, of public keys in this package. + PublicKeySize = ed448lib.PublicKeySize + // SeedSize is the size, in bytes, of private key seeds. + // The private key representation used by RFC 8032. + SeedSize = ed448lib.SeedSize + // SignatureSize is the size, in bytes, of signatures generated and verified by this package. + SignatureSize = ed448lib.SignatureSize +) + +type PublicKey struct { + // Point represents the elliptic curve point of the public key. + Point []byte +} + +type PrivateKey struct { + PublicKey + // Key the private key representation by RFC 8032, + // encoded as seed | public key point. + Key []byte +} + +// NewPublicKey creates a new empty ed448 public key. +func NewPublicKey() *PublicKey { + return &PublicKey{} +} + +// NewPrivateKey creates a new empty private key referencing the public key. +func NewPrivateKey(key PublicKey) *PrivateKey { + return &PrivateKey{ + PublicKey: key, + } +} + +// Seed returns the ed448 private key secret seed. +// The private key representation by RFC 8032. +func (pk *PrivateKey) Seed() []byte { + return pk.Key[:SeedSize] +} + +// MarshalByteSecret returns the underlying seed of the private key. +func (pk *PrivateKey) MarshalByteSecret() []byte { + return pk.Seed() +} + +// UnmarshalByteSecret computes the private key from the secret seed +// and stores it in the private key object. +func (sk *PrivateKey) UnmarshalByteSecret(seed []byte) error { + sk.Key = ed448lib.NewKeyFromSeed(seed) + return nil +} + +// GenerateKey generates a fresh private key with the provided randomness source. +func GenerateKey(rand io.Reader) (*PrivateKey, error) { + publicKey, privateKey, err := ed448lib.GenerateKey(rand) + if err != nil { + return nil, err + } + privateKeyOut := new(PrivateKey) + privateKeyOut.PublicKey.Point = publicKey[:] + privateKeyOut.Key = privateKey[:] + return privateKeyOut, nil +} + +// Sign signs a message with the ed448 algorithm. +// priv MUST be a valid key! Check this with Validate() before use. +func Sign(priv *PrivateKey, message []byte) ([]byte, error) { + // Ed448 is used with the empty string as a context string. + // See https://datatracker.ietf.org/doc/html/draft-ietf-openpgp-crypto-refresh-08#section-13.7 + return ed448lib.Sign(priv.Key, message, ""), nil +} + +// Verify verifies a ed448 signature +func Verify(pub *PublicKey, message []byte, signature []byte) bool { + // Ed448 is used with the empty string as a context string. + // See https://datatracker.ietf.org/doc/html/draft-ietf-openpgp-crypto-refresh-08#section-13.7 + return ed448lib.Verify(pub.Point, message, signature, "") +} + +// Validate checks if the ed448 private key is valid +func Validate(priv *PrivateKey) error { + expectedPrivateKey := ed448lib.NewKeyFromSeed(priv.Seed()) + if subtle.ConstantTimeCompare(priv.Key, expectedPrivateKey) == 0 { + return errors.KeyInvalidError("ed448: invalid ed448 secret") + } + if subtle.ConstantTimeCompare(priv.PublicKey.Point, expectedPrivateKey[SeedSize:]) == 0 { + return errors.KeyInvalidError("ed448: invalid ed448 public key") + } + return nil +} + +// ENCODING/DECODING signature: + +// WriteSignature encodes and writes an ed448 signature to writer. +func WriteSignature(writer io.Writer, signature []byte) error { + _, err := writer.Write(signature) + return err +} + +// ReadSignature decodes an ed448 signature from a reader. +func ReadSignature(reader io.Reader) ([]byte, error) { + signature := make([]byte, SignatureSize) + if _, err := io.ReadFull(reader, signature); err != nil { + return nil, err + } + return signature, nil +} diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/errors/errors.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/errors/errors.go index 17e2bcfed2..0eb3937b39 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/errors/errors.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/errors/errors.go @@ -9,6 +9,18 @@ import ( "strconv" ) +var ( + // ErrDecryptSessionKeyParsing is a generic error message for parsing errors in decrypted data + // to reduce the risk of oracle attacks. + ErrDecryptSessionKeyParsing = DecryptWithSessionKeyError("parsing error") + // ErrAEADTagVerification is returned if one of the tag verifications in SEIPDv2 fails + ErrAEADTagVerification error = DecryptWithSessionKeyError("AEAD tag verification failed") + // ErrMDCHashMismatch + ErrMDCHashMismatch error = SignatureError("MDC hash mismatch") + // ErrMDCMissing + ErrMDCMissing error = SignatureError("MDC packet not found") +) + // A StructuralError is returned when OpenPGP data is found to be syntactically // invalid. type StructuralError string @@ -17,6 +29,34 @@ func (s StructuralError) Error() string { return "openpgp: invalid data: " + string(s) } +// A DecryptWithSessionKeyError is returned when a failure occurs when reading from symmetrically decrypted data or +// an authentication tag verification fails. +// Such an error indicates that the supplied session key is likely wrong or the data got corrupted. +type DecryptWithSessionKeyError string + +func (s DecryptWithSessionKeyError) Error() string { + return "openpgp: decryption with session key failed: " + string(s) +} + +// HandleSensitiveParsingError handles parsing errors when reading data from potentially decrypted data. +// The function makes parsing errors generic to reduce the risk of oracle attacks in SEIPDv1. +func HandleSensitiveParsingError(err error, decrypted bool) error { + if !decrypted { + // Data was not encrypted so we return the inner error. + return err + } + // The data is read from a stream that decrypts using a session key; + // therefore, we need to handle parsing errors appropriately. + // This is essential to mitigate the risk of oracle attacks. + if decError, ok := err.(*DecryptWithSessionKeyError); ok { + return decError + } + if decError, ok := err.(DecryptWithSessionKeyError); ok { + return decError + } + return ErrDecryptSessionKeyParsing +} + // UnsupportedError indicates that, although the OpenPGP data is valid, it // makes use of currently unimplemented features. type UnsupportedError string @@ -41,9 +81,6 @@ func (b SignatureError) Error() string { return "openpgp: invalid signature: " + string(b) } -var ErrMDCHashMismatch error = SignatureError("MDC hash mismatch") -var ErrMDCMissing error = SignatureError("MDC packet not found") - type signatureExpiredError int func (se signatureExpiredError) Error() string { @@ -58,6 +95,14 @@ func (ke keyExpiredError) Error() string { return "openpgp: key expired" } +var ErrSignatureOlderThanKey error = signatureOlderThanKeyError(0) + +type signatureOlderThanKeyError int + +func (ske signatureOlderThanKeyError) Error() string { + return "openpgp: signature is older than the key" +} + var ErrKeyExpired error = keyExpiredError(0) type keyIncorrectError int @@ -92,12 +137,24 @@ func (keyRevokedError) Error() string { var ErrKeyRevoked error = keyRevokedError(0) +type WeakAlgorithmError string + +func (e WeakAlgorithmError) Error() string { + return "openpgp: weak algorithms are rejected: " + string(e) +} + type UnknownPacketTypeError uint8 func (upte UnknownPacketTypeError) Error() string { return "openpgp: unknown packet type: " + strconv.Itoa(int(upte)) } +type CriticalUnknownPacketTypeError uint8 + +func (upte CriticalUnknownPacketTypeError) Error() string { + return "openpgp: unknown critical packet type: " + strconv.Itoa(int(upte)) +} + // AEADError indicates that there is a problem when initializing or using a // AEAD instance, configuration struct, nonces or index values. type AEADError string @@ -114,3 +171,10 @@ type ErrDummyPrivateKey string func (dke ErrDummyPrivateKey) Error() string { return "openpgp: s2k GNU dummy key: " + string(dke) } + +// ErrMalformedMessage results when the packet sequence is incorrect +type ErrMalformedMessage string + +func (dke ErrMalformedMessage) Error() string { + return "openpgp: malformed message " + string(dke) +} diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/internal/algorithm/cipher.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/internal/algorithm/cipher.go index 5760cff80e..c76a75bcda 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/internal/algorithm/cipher.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/internal/algorithm/cipher.go @@ -51,24 +51,14 @@ func (sk CipherFunction) Id() uint8 { return uint8(sk) } -var keySizeByID = map[uint8]int{ - TripleDES.Id(): 24, - CAST5.Id(): cast5.KeySize, - AES128.Id(): 16, - AES192.Id(): 24, - AES256.Id(): 32, -} - // KeySize returns the key size, in bytes, of cipher. func (cipher CipherFunction) KeySize() int { switch cipher { - case TripleDES: - return 24 case CAST5: return cast5.KeySize case AES128: return 16 - case AES192: + case AES192, TripleDES: return 24 case AES256: return 32 diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/internal/ecc/curve_info.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/internal/ecc/curve_info.go index 35751034dd..0da2d0d852 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/internal/ecc/curve_info.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/internal/ecc/curve_info.go @@ -4,11 +4,14 @@ package ecc import ( "bytes" "crypto/elliptic" + "github.com/ProtonMail/go-crypto/bitcurves" "github.com/ProtonMail/go-crypto/brainpool" "github.com/ProtonMail/go-crypto/openpgp/internal/encoding" ) +const Curve25519GenName = "Curve25519" + type CurveInfo struct { GenName string Oid *encoding.OID @@ -42,19 +45,19 @@ var Curves = []CurveInfo{ }, { // Curve25519 - GenName: "Curve25519", + GenName: Curve25519GenName, Oid: encoding.NewOID([]byte{0x2B, 0x06, 0x01, 0x04, 0x01, 0x97, 0x55, 0x01, 0x05, 0x01}), Curve: NewCurve25519(), }, { - // X448 + // x448 GenName: "Curve448", Oid: encoding.NewOID([]byte{0x2B, 0x65, 0x6F}), Curve: NewX448(), }, { // Ed25519 - GenName: "Curve25519", + GenName: Curve25519GenName, Oid: encoding.NewOID([]byte{0x2B, 0x06, 0x01, 0x04, 0x01, 0xDA, 0x47, 0x0F, 0x01}), Curve: NewEd25519(), }, diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/internal/ecc/ed25519.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/internal/ecc/ed25519.go index 54a08a8a38..5a4c3a8596 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/internal/ecc/ed25519.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/internal/ecc/ed25519.go @@ -2,6 +2,7 @@ package ecc import ( + "bytes" "crypto/subtle" "io" @@ -90,7 +91,14 @@ func (c *ed25519) GenerateEdDSA(rand io.Reader) (pub, priv []byte, err error) { } func getEd25519Sk(publicKey, privateKey []byte) ed25519lib.PrivateKey { - return append(privateKey, publicKey...) + privateKeyCap, privateKeyLen, publicKeyLen := cap(privateKey), len(privateKey), len(publicKey) + + if privateKeyCap >= privateKeyLen+publicKeyLen && + bytes.Equal(privateKey[privateKeyLen:privateKeyLen+publicKeyLen], publicKey) { + return privateKey[:privateKeyLen+publicKeyLen] + } + + return append(privateKey[:privateKeyLen:privateKeyLen], publicKey...) } func (c *ed25519) Sign(publicKey, privateKey, message []byte) (sig []byte, err error) { diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/internal/ecc/ed448.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/internal/ecc/ed448.go index 18cd80434b..b6edda7480 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/internal/ecc/ed448.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/internal/ecc/ed448.go @@ -2,6 +2,7 @@ package ecc import ( + "bytes" "crypto/subtle" "io" @@ -84,7 +85,14 @@ func (c *ed448) GenerateEdDSA(rand io.Reader) (pub, priv []byte, err error) { } func getEd448Sk(publicKey, privateKey []byte) ed448lib.PrivateKey { - return append(privateKey, publicKey...) + privateKeyCap, privateKeyLen, publicKeyLen := cap(privateKey), len(privateKey), len(publicKey) + + if privateKeyCap >= privateKeyLen+publicKeyLen && + bytes.Equal(privateKey[privateKeyLen:privateKeyLen+publicKeyLen], publicKey) { + return privateKey[:privateKeyLen+publicKeyLen] + } + + return append(privateKey[:privateKeyLen:privateKeyLen], publicKey...) } func (c *ed448) Sign(publicKey, privateKey, message []byte) (sig []byte, err error) { diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/key_generation.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/key_generation.go index 0e71934cd9..77213f66be 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/key_generation.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/key_generation.go @@ -15,11 +15,15 @@ import ( "github.com/ProtonMail/go-crypto/openpgp/ecdh" "github.com/ProtonMail/go-crypto/openpgp/ecdsa" + "github.com/ProtonMail/go-crypto/openpgp/ed25519" + "github.com/ProtonMail/go-crypto/openpgp/ed448" "github.com/ProtonMail/go-crypto/openpgp/eddsa" "github.com/ProtonMail/go-crypto/openpgp/errors" "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm" "github.com/ProtonMail/go-crypto/openpgp/internal/ecc" "github.com/ProtonMail/go-crypto/openpgp/packet" + "github.com/ProtonMail/go-crypto/openpgp/x25519" + "github.com/ProtonMail/go-crypto/openpgp/x448" ) // NewEntity returns an Entity that contains a fresh RSA/RSA keypair with a @@ -36,8 +40,10 @@ func NewEntity(name, comment, email string, config *packet.Config) (*Entity, err return nil, err } primary := packet.NewSignerPrivateKey(creationTime, primaryPrivRaw) - if config != nil && config.V5Keys { - primary.UpgradeToV5() + if config.V6() { + if err := primary.UpgradeToV6(); err != nil { + return nil, err + } } e := &Entity{ @@ -45,9 +51,25 @@ func NewEntity(name, comment, email string, config *packet.Config) (*Entity, err PrivateKey: primary, Identities: make(map[string]*Identity), Subkeys: []Subkey{}, + Signatures: []*packet.Signature{}, + } + + if config.V6() { + // In v6 keys algorithm preferences should be stored in direct key signatures + selfSignature := createSignaturePacket(&primary.PublicKey, packet.SigTypeDirectSignature, config) + err = writeKeyProperties(selfSignature, creationTime, keyLifetimeSecs, config) + if err != nil { + return nil, err + } + err = selfSignature.SignDirectKeyBinding(&primary.PublicKey, primary, config) + if err != nil { + return nil, err + } + e.Signatures = append(e.Signatures, selfSignature) + e.SelfSignature = selfSignature } - err = e.addUserId(name, comment, email, config, creationTime, keyLifetimeSecs) + err = e.addUserId(name, comment, email, config, creationTime, keyLifetimeSecs, !config.V6()) if err != nil { return nil, err } @@ -65,32 +87,19 @@ func NewEntity(name, comment, email string, config *packet.Config) (*Entity, err func (t *Entity) AddUserId(name, comment, email string, config *packet.Config) error { creationTime := config.Now() keyLifetimeSecs := config.KeyLifetime() - return t.addUserId(name, comment, email, config, creationTime, keyLifetimeSecs) + return t.addUserId(name, comment, email, config, creationTime, keyLifetimeSecs, !config.V6()) } -func (t *Entity) addUserId(name, comment, email string, config *packet.Config, creationTime time.Time, keyLifetimeSecs uint32) error { - uid := packet.NewUserId(name, comment, email) - if uid == nil { - return errors.InvalidArgumentError("user id field contained invalid characters") - } - - if _, ok := t.Identities[uid.Id]; ok { - return errors.InvalidArgumentError("user id exist") - } - - primary := t.PrivateKey - - isPrimaryId := len(t.Identities) == 0 +func writeKeyProperties(selfSignature *packet.Signature, creationTime time.Time, keyLifetimeSecs uint32, config *packet.Config) error { + advertiseAead := config.AEAD() != nil - selfSignature := createSignaturePacket(&primary.PublicKey, packet.SigTypePositiveCert, config) selfSignature.CreationTime = creationTime selfSignature.KeyLifetimeSecs = &keyLifetimeSecs - selfSignature.IsPrimaryId = &isPrimaryId selfSignature.FlagsValid = true selfSignature.FlagSign = true selfSignature.FlagCertify = true selfSignature.SEIPDv1 = true // true by default, see 5.8 vs. 5.14 - selfSignature.SEIPDv2 = config.AEAD() != nil + selfSignature.SEIPDv2 = advertiseAead // Set the PreferredHash for the SelfSignature from the packet.Config. // If it is not the must-implement algorithm from rfc4880bis, append that. @@ -119,18 +128,44 @@ func (t *Entity) addUserId(name, comment, email string, config *packet.Config, c selfSignature.PreferredCompression = append(selfSignature.PreferredCompression, uint8(config.Compression())) } - // And for DefaultMode. - modes := []uint8{uint8(config.AEAD().Mode())} - if config.AEAD().Mode() != packet.AEADModeOCB { - modes = append(modes, uint8(packet.AEADModeOCB)) + if advertiseAead { + // Get the preferred AEAD mode from the packet.Config. + // If it is not the must-implement algorithm from rfc9580, append that. + modes := []uint8{uint8(config.AEAD().Mode())} + if config.AEAD().Mode() != packet.AEADModeOCB { + modes = append(modes, uint8(packet.AEADModeOCB)) + } + + // For preferred (AES256, GCM), we'll generate (AES256, GCM), (AES256, OCB), (AES128, GCM), (AES128, OCB) + for _, cipher := range selfSignature.PreferredSymmetric { + for _, mode := range modes { + selfSignature.PreferredCipherSuites = append(selfSignature.PreferredCipherSuites, [2]uint8{cipher, mode}) + } + } + } + return nil +} + +func (t *Entity) addUserId(name, comment, email string, config *packet.Config, creationTime time.Time, keyLifetimeSecs uint32, writeProperties bool) error { + uid := packet.NewUserId(name, comment, email) + if uid == nil { + return errors.InvalidArgumentError("user id field contained invalid characters") + } + + if _, ok := t.Identities[uid.Id]; ok { + return errors.InvalidArgumentError("user id exist") } - // For preferred (AES256, GCM), we'll generate (AES256, GCM), (AES256, OCB), (AES128, GCM), (AES128, OCB) - for _, cipher := range selfSignature.PreferredSymmetric { - for _, mode := range modes { - selfSignature.PreferredCipherSuites = append(selfSignature.PreferredCipherSuites, [2]uint8{cipher, mode}) + primary := t.PrivateKey + isPrimaryId := len(t.Identities) == 0 + selfSignature := createSignaturePacket(&primary.PublicKey, packet.SigTypePositiveCert, config) + if writeProperties { + err := writeKeyProperties(selfSignature, creationTime, keyLifetimeSecs, config) + if err != nil { + return err } } + selfSignature.IsPrimaryId = &isPrimaryId // User ID binding signature err := selfSignature.SignUserId(uid.Id, &primary.PublicKey, primary, config) @@ -158,8 +193,10 @@ func (e *Entity) AddSigningSubkey(config *packet.Config) error { } sub := packet.NewSignerPrivateKey(creationTime, subPrivRaw) sub.IsSubkey = true - if config != nil && config.V5Keys { - sub.UpgradeToV5() + if config.V6() { + if err := sub.UpgradeToV6(); err != nil { + return err + } } subkey := Subkey{ @@ -203,8 +240,10 @@ func (e *Entity) addEncryptionSubkey(config *packet.Config, creationTime time.Ti } sub := packet.NewDecrypterPrivateKey(creationTime, subPrivRaw) sub.IsSubkey = true - if config != nil && config.V5Keys { - sub.UpgradeToV5() + if config.V6() { + if err := sub.UpgradeToV6(); err != nil { + return err + } } subkey := Subkey{ @@ -242,6 +281,11 @@ func newSigner(config *packet.Config) (signer interface{}, err error) { } return rsa.GenerateKey(config.Random(), bits) case packet.PubKeyAlgoEdDSA: + if config.V6() { + // Implementations MUST NOT accept or generate v6 key material + // using the deprecated OIDs. + return nil, errors.InvalidArgumentError("EdDSALegacy cannot be used for v6 keys") + } curve := ecc.FindEdDSAByGenName(string(config.CurveName())) if curve == nil { return nil, errors.InvalidArgumentError("unsupported curve") @@ -263,6 +307,18 @@ func newSigner(config *packet.Config) (signer interface{}, err error) { return nil, err } return priv, nil + case packet.PubKeyAlgoEd25519: + priv, err := ed25519.GenerateKey(config.Random()) + if err != nil { + return nil, err + } + return priv, nil + case packet.PubKeyAlgoEd448: + priv, err := ed448.GenerateKey(config.Random()) + if err != nil { + return nil, err + } + return priv, nil default: return nil, errors.InvalidArgumentError("unsupported public key algorithm") } @@ -285,6 +341,13 @@ func newDecrypter(config *packet.Config) (decrypter interface{}, err error) { case packet.PubKeyAlgoEdDSA, packet.PubKeyAlgoECDSA: fallthrough // When passing EdDSA or ECDSA, we generate an ECDH subkey case packet.PubKeyAlgoECDH: + if config.V6() && + (config.CurveName() == packet.Curve25519 || + config.CurveName() == packet.Curve448) { + // Implementations MUST NOT accept or generate v6 key material + // using the deprecated OIDs. + return nil, errors.InvalidArgumentError("ECDH with Curve25519/448 legacy cannot be used for v6 keys") + } var kdf = ecdh.KDF{ Hash: algorithm.SHA512, Cipher: algorithm.AES256, @@ -294,6 +357,10 @@ func newDecrypter(config *packet.Config) (decrypter interface{}, err error) { return nil, errors.InvalidArgumentError("unsupported curve") } return ecdh.GenerateKey(config.Random(), curve, kdf) + case packet.PubKeyAlgoEd25519, packet.PubKeyAlgoX25519: // When passing Ed25519, we generate an x25519 subkey + return x25519.GenerateKey(config.Random()) + case packet.PubKeyAlgoEd448, packet.PubKeyAlgoX448: // When passing Ed448, we generate an x448 subkey + return x448.GenerateKey(config.Random()) default: return nil, errors.InvalidArgumentError("unsupported public key algorithm") } @@ -302,7 +369,7 @@ func newDecrypter(config *packet.Config) (decrypter interface{}, err error) { var bigOne = big.NewInt(1) // generateRSAKeyWithPrimes generates a multi-prime RSA keypair of the -// given bit size, using the given random source and prepopulated primes. +// given bit size, using the given random source and pre-populated primes. func generateRSAKeyWithPrimes(random io.Reader, nprimes int, bits int, prepopulatedPrimes []*big.Int) (*rsa.PrivateKey, error) { priv := new(rsa.PrivateKey) priv.E = 65537 diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/keys.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/keys.go index 2d7b0cf373..a071353e2e 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/keys.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/keys.go @@ -6,6 +6,7 @@ package openpgp import ( goerrors "errors" + "fmt" "io" "time" @@ -24,11 +25,13 @@ var PrivateKeyType = "PGP PRIVATE KEY BLOCK" // (which must be a signing key), one or more identities claimed by that key, // and zero or more subkeys, which may be encryption keys. type Entity struct { - PrimaryKey *packet.PublicKey - PrivateKey *packet.PrivateKey - Identities map[string]*Identity // indexed by Identity.Name - Revocations []*packet.Signature - Subkeys []Subkey + PrimaryKey *packet.PublicKey + PrivateKey *packet.PrivateKey + Identities map[string]*Identity // indexed by Identity.Name + Revocations []*packet.Signature + Subkeys []Subkey + SelfSignature *packet.Signature // Direct-key self signature of the PrimaryKey (contains primary key properties in v6) + Signatures []*packet.Signature // all (potentially unverified) self-signatures, revocations, and third-party signatures } // An Identity represents an identity claimed by an Entity and zero or more @@ -120,12 +123,12 @@ func shouldPreferIdentity(existingId, potentialNewId *Identity) bool { // given Entity. func (e *Entity) EncryptionKey(now time.Time) (Key, bool) { // Fail to find any encryption key if the... - i := e.PrimaryIdentity() - if e.PrimaryKey.KeyExpired(i.SelfSignature, now) || // primary key has expired - i.SelfSignature == nil || // user ID has no self-signature - i.SelfSignature.SigExpired(now) || // user ID self-signature has expired + primarySelfSignature, primaryIdentity := e.PrimarySelfSignature() + if primarySelfSignature == nil || // no self-signature found + e.PrimaryKey.KeyExpired(primarySelfSignature, now) || // primary key has expired e.Revoked(now) || // primary key has been revoked - i.Revoked(now) { // user ID has been revoked + primarySelfSignature.SigExpired(now) || // user ID or or direct self-signature has expired + (primaryIdentity != nil && primaryIdentity.Revoked(now)) { // user ID has been revoked (for v4 keys) return Key{}, false } @@ -152,9 +155,9 @@ func (e *Entity) EncryptionKey(now time.Time) (Key, bool) { // If we don't have any subkeys for encryption and the primary key // is marked as OK to encrypt with, then we can use it. - if i.SelfSignature.FlagsValid && i.SelfSignature.FlagEncryptCommunications && + if primarySelfSignature.FlagsValid && primarySelfSignature.FlagEncryptCommunications && e.PrimaryKey.PubKeyAlgo.CanEncrypt() { - return Key{e, e.PrimaryKey, e.PrivateKey, i.SelfSignature, e.Revocations}, true + return Key{e, e.PrimaryKey, e.PrivateKey, primarySelfSignature, e.Revocations}, true } return Key{}, false @@ -186,12 +189,12 @@ func (e *Entity) SigningKeyById(now time.Time, id uint64) (Key, bool) { func (e *Entity) signingKeyByIdUsage(now time.Time, id uint64, flags int) (Key, bool) { // Fail to find any signing key if the... - i := e.PrimaryIdentity() - if e.PrimaryKey.KeyExpired(i.SelfSignature, now) || // primary key has expired - i.SelfSignature == nil || // user ID has no self-signature - i.SelfSignature.SigExpired(now) || // user ID self-signature has expired + primarySelfSignature, primaryIdentity := e.PrimarySelfSignature() + if primarySelfSignature == nil || // no self-signature found + e.PrimaryKey.KeyExpired(primarySelfSignature, now) || // primary key has expired e.Revoked(now) || // primary key has been revoked - i.Revoked(now) { // user ID has been revoked + primarySelfSignature.SigExpired(now) || // user ID or direct self-signature has expired + (primaryIdentity != nil && primaryIdentity.Revoked(now)) { // user ID has been revoked (for v4 keys) return Key{}, false } @@ -220,12 +223,12 @@ func (e *Entity) signingKeyByIdUsage(now time.Time, id uint64, flags int) (Key, // If we don't have any subkeys for signing and the primary key // is marked as OK to sign with, then we can use it. - if i.SelfSignature.FlagsValid && - (flags&packet.KeyFlagCertify == 0 || i.SelfSignature.FlagCertify) && - (flags&packet.KeyFlagSign == 0 || i.SelfSignature.FlagSign) && + if primarySelfSignature.FlagsValid && + (flags&packet.KeyFlagCertify == 0 || primarySelfSignature.FlagCertify) && + (flags&packet.KeyFlagSign == 0 || primarySelfSignature.FlagSign) && e.PrimaryKey.PubKeyAlgo.CanSign() && (id == 0 || e.PrimaryKey.KeyId == id) { - return Key{e, e.PrimaryKey, e.PrivateKey, i.SelfSignature, e.Revocations}, true + return Key{e, e.PrimaryKey, e.PrivateKey, primarySelfSignature, e.Revocations}, true } // No keys with a valid Signing Flag or no keys matched the id passed in @@ -259,7 +262,7 @@ func (e *Entity) EncryptPrivateKeys(passphrase []byte, config *packet.Config) er var keysToEncrypt []*packet.PrivateKey // Add entity private key to encrypt. if e.PrivateKey != nil && !e.PrivateKey.Dummy() && !e.PrivateKey.Encrypted { - keysToEncrypt = append(keysToEncrypt, e.PrivateKey) + keysToEncrypt = append(keysToEncrypt, e.PrivateKey) } // Add subkeys to encrypt. @@ -271,7 +274,7 @@ func (e *Entity) EncryptPrivateKeys(passphrase []byte, config *packet.Config) er return packet.EncryptPrivateKeys(keysToEncrypt, passphrase, config) } -// DecryptPrivateKeys decrypts all encrypted keys in the entitiy with the given passphrase. +// DecryptPrivateKeys decrypts all encrypted keys in the entity with the given passphrase. // Avoids recomputation of similar s2k key derivations. Public keys and dummy keys are ignored, // and don't cause an error to be returned. func (e *Entity) DecryptPrivateKeys(passphrase []byte) error { @@ -284,7 +287,7 @@ func (e *Entity) DecryptPrivateKeys(passphrase []byte) error { // Add subkeys to decrypt. for _, sub := range e.Subkeys { if sub.PrivateKey != nil && !sub.PrivateKey.Dummy() && sub.PrivateKey.Encrypted { - keysToDecrypt = append(keysToDecrypt, sub.PrivateKey) + keysToDecrypt = append(keysToDecrypt, sub.PrivateKey) } } return packet.DecryptPrivateKeys(keysToDecrypt, passphrase) @@ -318,8 +321,7 @@ type EntityList []*Entity func (el EntityList) KeysById(id uint64) (keys []Key) { for _, e := range el { if e.PrimaryKey.KeyId == id { - ident := e.PrimaryIdentity() - selfSig := ident.SelfSignature + selfSig, _ := e.PrimarySelfSignature() keys = append(keys, Key{e, e.PrimaryKey, e.PrivateKey, selfSig, e.Revocations}) } @@ -441,7 +443,6 @@ func readToNextPublicKey(packets *packet.Reader) (err error) { return } else if err != nil { if _, ok := err.(errors.UnsupportedError); ok { - err = nil continue } return @@ -479,6 +480,7 @@ func ReadEntity(packets *packet.Reader) (*Entity, error) { } var revocations []*packet.Signature + var directSignatures []*packet.Signature EachPacket: for { p, err := packets.Next() @@ -497,9 +499,7 @@ EachPacket: if pkt.SigType == packet.SigTypeKeyRevocation { revocations = append(revocations, pkt) } else if pkt.SigType == packet.SigTypeDirectSignature { - // TODO: RFC4880 5.2.1 permits signatures - // directly on keys (eg. to bind additional - // revocation keys). + directSignatures = append(directSignatures, pkt) } // Else, ignoring the signature as it does not follow anything // we would know to attach it to. @@ -522,12 +522,39 @@ EachPacket: return nil, err } default: - // we ignore unknown packets + // we ignore unknown packets. } } - if len(e.Identities) == 0 { - return nil, errors.StructuralError("entity without any identities") + if len(e.Identities) == 0 && e.PrimaryKey.Version < 6 { + return nil, errors.StructuralError(fmt.Sprintf("v%d entity without any identities", e.PrimaryKey.Version)) + } + + // An implementation MUST ensure that a valid direct-key signature is present before using a v6 key. + if e.PrimaryKey.Version == 6 { + if len(directSignatures) == 0 { + return nil, errors.StructuralError("v6 entity without a valid direct-key signature") + } + // Select main direct key signature. + var mainDirectKeySelfSignature *packet.Signature + for _, directSignature := range directSignatures { + if directSignature.SigType == packet.SigTypeDirectSignature && + directSignature.CheckKeyIdOrFingerprint(e.PrimaryKey) && + (mainDirectKeySelfSignature == nil || + directSignature.CreationTime.After(mainDirectKeySelfSignature.CreationTime)) { + mainDirectKeySelfSignature = directSignature + } + } + if mainDirectKeySelfSignature == nil { + return nil, errors.StructuralError("no valid direct-key self-signature for v6 primary key found") + } + // Check that the main self-signature is valid. + err = e.PrimaryKey.VerifyDirectKeySignature(mainDirectKeySelfSignature) + if err != nil { + return nil, errors.StructuralError("invalid direct-key self-signature for v6 primary key") + } + e.SelfSignature = mainDirectKeySelfSignature + e.Signatures = directSignatures } for _, revocation := range revocations { @@ -672,6 +699,12 @@ func (e *Entity) serializePrivate(w io.Writer, config *packet.Config, reSign boo return err } } + for _, directSignature := range e.Signatures { + err := directSignature.Serialize(w) + if err != nil { + return err + } + } for _, ident := range e.Identities { err = ident.UserId.Serialize(w) if err != nil { @@ -738,6 +771,12 @@ func (e *Entity) Serialize(w io.Writer) error { return err } } + for _, directSignature := range e.Signatures { + err := directSignature.Serialize(w) + if err != nil { + return err + } + } for _, ident := range e.Identities { err = ident.UserId.Serialize(w) if err != nil { @@ -840,3 +879,23 @@ func (e *Entity) RevokeSubkey(sk *Subkey, reason packet.ReasonForRevocation, rea sk.Revocations = append(sk.Revocations, revSig) return nil } + +func (e *Entity) primaryDirectSignature() *packet.Signature { + return e.SelfSignature +} + +// PrimarySelfSignature searches the entity for the self-signature that stores key preferences. +// For V4 keys, returns the self-signature of the primary identity, and the identity. +// For V6 keys, returns the latest valid direct-key self-signature, and no identity (nil). +// This self-signature is to be used to check the key expiration, +// algorithm preferences, and so on. +func (e *Entity) PrimarySelfSignature() (*packet.Signature, *Identity) { + if e.PrimaryKey.Version == 6 { + return e.primaryDirectSignature(), nil + } + primaryIdentity := e.PrimaryIdentity() + if primaryIdentity == nil { + return nil, nil + } + return primaryIdentity.SelfSignature, primaryIdentity +} diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/aead_crypter.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/aead_crypter.go index cee83bdc7a..2eecd062f5 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/aead_crypter.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/aead_crypter.go @@ -88,17 +88,20 @@ func (ar *aeadDecrypter) Read(dst []byte) (n int, err error) { if errRead != nil && errRead != io.EOF { return 0, errRead } - decrypted, errChunk := ar.openChunk(cipherChunk) - if errChunk != nil { - return 0, errChunk - } - // Return decrypted bytes, buffering if necessary - if len(dst) < len(decrypted) { - n = copy(dst, decrypted[:len(dst)]) - ar.buffer.Write(decrypted[len(dst):]) - } else { - n = copy(dst, decrypted) + if len(cipherChunk) > 0 { + decrypted, errChunk := ar.openChunk(cipherChunk) + if errChunk != nil { + return 0, errChunk + } + + // Return decrypted bytes, buffering if necessary + if len(dst) < len(decrypted) { + n = copy(dst, decrypted[:len(dst)]) + ar.buffer.Write(decrypted[len(dst):]) + } else { + n = copy(dst, decrypted) + } } // Check final authentication tag @@ -116,6 +119,12 @@ func (ar *aeadDecrypter) Read(dst []byte) (n int, err error) { // checked in the last Read call. In the future, this function could be used to // wipe the reader and peeked, decrypted bytes, if necessary. func (ar *aeadDecrypter) Close() (err error) { + if !ar.eof { + errChunk := ar.validateFinalTag(ar.peekedBytes) + if errChunk != nil { + return errChunk + } + } return nil } @@ -138,7 +147,7 @@ func (ar *aeadDecrypter) openChunk(data []byte) ([]byte, error) { nonce := ar.computeNextNonce() plainChunk, err := ar.aead.Open(nil, nonce, chunk, adata) if err != nil { - return nil, err + return nil, errors.ErrAEADTagVerification } ar.bytesProcessed += len(plainChunk) if err = ar.aeadCrypter.incrementIndex(); err != nil { @@ -163,9 +172,8 @@ func (ar *aeadDecrypter) validateFinalTag(tag []byte) error { // ... and total number of encrypted octets adata = append(adata, amountBytes...) nonce := ar.computeNextNonce() - _, err := ar.aead.Open(nil, nonce, tag, adata) - if err != nil { - return err + if _, err := ar.aead.Open(nil, nonce, tag, adata); err != nil { + return errors.ErrAEADTagVerification } return nil } diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/compressed.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/compressed.go index 2f5cad71da..0bcb38caca 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/compressed.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/compressed.go @@ -8,9 +8,10 @@ import ( "compress/bzip2" "compress/flate" "compress/zlib" - "github.com/ProtonMail/go-crypto/openpgp/errors" "io" "strconv" + + "github.com/ProtonMail/go-crypto/openpgp/errors" ) // Compressed represents a compressed OpenPGP packet. The decompressed contents @@ -39,6 +40,37 @@ type CompressionConfig struct { Level int } +// decompressionReader ensures that the whole compression packet is read. +type decompressionReader struct { + compressed io.Reader + decompressed io.ReadCloser + readAll bool +} + +func newDecompressionReader(r io.Reader, decompressor io.ReadCloser) *decompressionReader { + return &decompressionReader{ + compressed: r, + decompressed: decompressor, + } +} + +func (dr *decompressionReader) Read(data []byte) (n int, err error) { + if dr.readAll { + return 0, io.EOF + } + n, err = dr.decompressed.Read(data) + if err == io.EOF { + dr.readAll = true + // Close the decompressor. + if errDec := dr.decompressed.Close(); errDec != nil { + return n, errDec + } + // Consume all remaining data from the compressed packet. + consumeAll(dr.compressed) + } + return n, err +} + func (c *Compressed) parse(r io.Reader) error { var buf [1]byte _, err := readFull(r, buf[:]) @@ -50,11 +82,15 @@ func (c *Compressed) parse(r io.Reader) error { case 0: c.Body = r case 1: - c.Body = flate.NewReader(r) + c.Body = newDecompressionReader(r, flate.NewReader(r)) case 2: - c.Body, err = zlib.NewReader(r) + decompressor, err := zlib.NewReader(r) + if err != nil { + return err + } + c.Body = newDecompressionReader(r, decompressor) case 3: - c.Body = bzip2.NewReader(r) + c.Body = newDecompressionReader(r, io.NopCloser(bzip2.NewReader(r))) default: err = errors.UnsupportedError("unknown compression algorithm: " + strconv.Itoa(int(buf[0]))) } diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/config.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/config.go index 04994bec97..8bf8e6e51f 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/config.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/config.go @@ -14,6 +14,34 @@ import ( "github.com/ProtonMail/go-crypto/openpgp/s2k" ) +var ( + defaultRejectPublicKeyAlgorithms = map[PublicKeyAlgorithm]bool{ + PubKeyAlgoElGamal: true, + PubKeyAlgoDSA: true, + } + defaultRejectHashAlgorithms = map[crypto.Hash]bool{ + crypto.MD5: true, + crypto.RIPEMD160: true, + } + defaultRejectMessageHashAlgorithms = map[crypto.Hash]bool{ + crypto.SHA1: true, + crypto.MD5: true, + crypto.RIPEMD160: true, + } + defaultRejectCurves = map[Curve]bool{ + CurveSecP256k1: true, + } +) + +// A global feature flag to indicate v5 support. +// Can be set via a build tag, e.g.: `go build -tags v5 ./...` +// If the build tag is missing config_v5.go will set it to true. +// +// Disables parsing of v5 keys and v5 signatures. +// These are non-standard entities, which in the crypto-refresh have been superseded +// by v6 keys, v6 signatures and SEIPDv2 encrypted data, respectively. +var V5Disabled = false + // Config collects a number of parameters along with sensible defaults. // A nil *Config is valid and results in all default values. type Config struct { @@ -73,9 +101,16 @@ type Config struct { // **Note: using this option may break compatibility with other OpenPGP // implementations, as well as future versions of this library.** AEADConfig *AEADConfig - // V5Keys configures version 5 key generation. If false, this package still - // supports version 5 keys, but produces version 4 keys. - V5Keys bool + // V6Keys configures version 6 key generation. If false, this package still + // supports version 6 keys, but produces version 4 keys. + V6Keys bool + // Minimum RSA key size allowed for key generation and message signing, verification and encryption. + MinRSABits uint16 + // Reject insecure algorithms, only works with v2 api + RejectPublicKeyAlgorithms map[PublicKeyAlgorithm]bool + RejectHashAlgorithms map[crypto.Hash]bool + RejectMessageHashAlgorithms map[crypto.Hash]bool + RejectCurves map[Curve]bool // "The validity period of the key. This is the number of seconds after // the key creation time that the key expires. If this is not present // or has a value of zero, the key never expires. This is found only on @@ -104,12 +139,40 @@ type Config struct { // might be no other way than to tolerate the missing MDC. Setting this flag, allows this // mode of operation. It should be considered a measure of last resort. InsecureAllowUnauthenticatedMessages bool + // InsecureAllowDecryptionWithSigningKeys allows decryption with keys marked as signing keys in the v2 API. + // This setting is potentially insecure, but it is needed as some libraries + // ignored key flags when selecting a key for encryption. + // Not relevant for the v1 API, as all keys were allowed in decryption. + InsecureAllowDecryptionWithSigningKeys bool // KnownNotations is a map of Notation Data names to bools, which controls // the notation names that are allowed to be present in critical Notation Data // signature subpackets. KnownNotations map[string]bool // SignatureNotations is a list of Notations to be added to any signatures. SignatureNotations []*Notation + // CheckIntendedRecipients controls, whether the OpenPGP Intended Recipient Fingerprint feature + // should be enabled for encryption and decryption. + // (See https://www.ietf.org/archive/id/draft-ietf-openpgp-crypto-refresh-12.html#name-intended-recipient-fingerpr). + // When the flag is set, encryption produces Intended Recipient Fingerprint signature sub-packets and decryption + // checks whether the key it was encrypted to is one of the included fingerprints in the signature. + // If the flag is disabled, no Intended Recipient Fingerprint sub-packets are created or checked. + // The default behavior, when the config or flag is nil, is to enable the feature. + CheckIntendedRecipients *bool + // CacheSessionKey controls if decryption should return the session key used for decryption. + // If the flag is set, the session key is cached in the message details struct. + CacheSessionKey bool + // CheckPacketSequence is a flag that controls if the pgp message reader should strictly check + // that the packet sequence conforms with the grammar mandated by rfc4880. + // The default behavior, when the config or flag is nil, is to check the packet sequence. + CheckPacketSequence *bool + // NonDeterministicSignaturesViaNotation is a flag to enable randomization of signatures. + // If true, a salt notation is used to randomize signatures generated by v4 and v5 keys + // (v6 signatures are always non-deterministic, by design). + // This protects EdDSA signatures from potentially leaking the secret key in case of faults (i.e. bitflips) which, in principle, could occur + // during the signing computation. It is added to signatures of any algo for simplicity, and as it may also serve as protection in case of + // weaknesses in the hash algo, potentially hindering e.g. some chosen-prefix attacks. + // The default behavior, when the config or flag is nil, is to enable the feature. + NonDeterministicSignaturesViaNotation *bool } func (c *Config) Random() io.Reader { @@ -197,7 +260,7 @@ func (c *Config) S2K() *s2k.Config { return nil } // for backwards compatibility - if c != nil && c.S2KCount > 0 && c.S2KConfig == nil { + if c.S2KCount > 0 && c.S2KConfig == nil { return &s2k.Config{ S2KCount: c.S2KCount, } @@ -233,6 +296,13 @@ func (c *Config) AllowUnauthenticatedMessages() bool { return c.InsecureAllowUnauthenticatedMessages } +func (c *Config) AllowDecryptionWithSigningKeys() bool { + if c == nil { + return false + } + return c.InsecureAllowDecryptionWithSigningKeys +} + func (c *Config) KnownNotation(notationName string) bool { if c == nil { return false @@ -246,3 +316,95 @@ func (c *Config) Notations() []*Notation { } return c.SignatureNotations } + +func (c *Config) V6() bool { + if c == nil { + return false + } + return c.V6Keys +} + +func (c *Config) IntendedRecipients() bool { + if c == nil || c.CheckIntendedRecipients == nil { + return true + } + return *c.CheckIntendedRecipients +} + +func (c *Config) RetrieveSessionKey() bool { + if c == nil { + return false + } + return c.CacheSessionKey +} + +func (c *Config) MinimumRSABits() uint16 { + if c == nil || c.MinRSABits == 0 { + return 2047 + } + return c.MinRSABits +} + +func (c *Config) RejectPublicKeyAlgorithm(alg PublicKeyAlgorithm) bool { + var rejectedAlgorithms map[PublicKeyAlgorithm]bool + if c == nil || c.RejectPublicKeyAlgorithms == nil { + // Default + rejectedAlgorithms = defaultRejectPublicKeyAlgorithms + } else { + rejectedAlgorithms = c.RejectPublicKeyAlgorithms + } + return rejectedAlgorithms[alg] +} + +func (c *Config) RejectHashAlgorithm(hash crypto.Hash) bool { + var rejectedAlgorithms map[crypto.Hash]bool + if c == nil || c.RejectHashAlgorithms == nil { + // Default + rejectedAlgorithms = defaultRejectHashAlgorithms + } else { + rejectedAlgorithms = c.RejectHashAlgorithms + } + return rejectedAlgorithms[hash] +} + +func (c *Config) RejectMessageHashAlgorithm(hash crypto.Hash) bool { + var rejectedAlgorithms map[crypto.Hash]bool + if c == nil || c.RejectMessageHashAlgorithms == nil { + // Default + rejectedAlgorithms = defaultRejectMessageHashAlgorithms + } else { + rejectedAlgorithms = c.RejectMessageHashAlgorithms + } + return rejectedAlgorithms[hash] +} + +func (c *Config) RejectCurve(curve Curve) bool { + var rejectedCurve map[Curve]bool + if c == nil || c.RejectCurves == nil { + // Default + rejectedCurve = defaultRejectCurves + } else { + rejectedCurve = c.RejectCurves + } + return rejectedCurve[curve] +} + +func (c *Config) StrictPacketSequence() bool { + if c == nil || c.CheckPacketSequence == nil { + return true + } + return *c.CheckPacketSequence +} + +func (c *Config) RandomizeSignaturesViaNotation() bool { + if c == nil || c.NonDeterministicSignaturesViaNotation == nil { + return true + } + return *c.NonDeterministicSignaturesViaNotation +} + +// BoolPointer is a helper function to set a boolean pointer in the Config. +// e.g., config.CheckPacketSequence = BoolPointer(true) +func BoolPointer(value bool) *bool { + return &value +} diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/config_v5.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/config_v5.go new file mode 100644 index 0000000000..f2415906b9 --- /dev/null +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/config_v5.go @@ -0,0 +1,7 @@ +//go:build !v5 + +package packet + +func init() { + V5Disabled = true +} diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/encrypted_key.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/encrypted_key.go index eeff2902c1..b90bb28911 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/encrypted_key.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/encrypted_key.go @@ -5,9 +5,11 @@ package packet import ( + "bytes" "crypto" "crypto/rsa" "encoding/binary" + "encoding/hex" "io" "math/big" "strconv" @@ -16,32 +18,85 @@ import ( "github.com/ProtonMail/go-crypto/openpgp/elgamal" "github.com/ProtonMail/go-crypto/openpgp/errors" "github.com/ProtonMail/go-crypto/openpgp/internal/encoding" + "github.com/ProtonMail/go-crypto/openpgp/x25519" + "github.com/ProtonMail/go-crypto/openpgp/x448" ) -const encryptedKeyVersion = 3 - // EncryptedKey represents a public-key encrypted session key. See RFC 4880, // section 5.1. type EncryptedKey struct { - KeyId uint64 - Algo PublicKeyAlgorithm - CipherFunc CipherFunction // only valid after a successful Decrypt for a v3 packet - Key []byte // only valid after a successful Decrypt + Version int + KeyId uint64 + KeyVersion int // v6 + KeyFingerprint []byte // v6 + Algo PublicKeyAlgorithm + CipherFunc CipherFunction // only valid after a successful Decrypt for a v3 packet + Key []byte // only valid after a successful Decrypt encryptedMPI1, encryptedMPI2 encoding.Field + ephemeralPublicX25519 *x25519.PublicKey // used for x25519 + ephemeralPublicX448 *x448.PublicKey // used for x448 + encryptedSession []byte // used for x25519 and x448 } func (e *EncryptedKey) parse(r io.Reader) (err error) { - var buf [10]byte - _, err = readFull(r, buf[:]) + var buf [8]byte + _, err = readFull(r, buf[:versionSize]) if err != nil { return } - if buf[0] != encryptedKeyVersion { + e.Version = int(buf[0]) + if e.Version != 3 && e.Version != 6 { return errors.UnsupportedError("unknown EncryptedKey version " + strconv.Itoa(int(buf[0]))) } - e.KeyId = binary.BigEndian.Uint64(buf[1:9]) - e.Algo = PublicKeyAlgorithm(buf[9]) + if e.Version == 6 { + //Read a one-octet size of the following two fields. + if _, err = readFull(r, buf[:1]); err != nil { + return + } + // The size may also be zero, and the key version and + // fingerprint omitted for an "anonymous recipient" + if buf[0] != 0 { + // non-anonymous case + _, err = readFull(r, buf[:versionSize]) + if err != nil { + return + } + e.KeyVersion = int(buf[0]) + if e.KeyVersion != 4 && e.KeyVersion != 6 { + return errors.UnsupportedError("unknown public key version " + strconv.Itoa(e.KeyVersion)) + } + var fingerprint []byte + if e.KeyVersion == 6 { + fingerprint = make([]byte, fingerprintSizeV6) + } else if e.KeyVersion == 4 { + fingerprint = make([]byte, fingerprintSize) + } + _, err = readFull(r, fingerprint) + if err != nil { + return + } + e.KeyFingerprint = fingerprint + if e.KeyVersion == 6 { + e.KeyId = binary.BigEndian.Uint64(e.KeyFingerprint[:keyIdSize]) + } else if e.KeyVersion == 4 { + e.KeyId = binary.BigEndian.Uint64(e.KeyFingerprint[fingerprintSize-keyIdSize : fingerprintSize]) + } + } + } else { + _, err = readFull(r, buf[:8]) + if err != nil { + return + } + e.KeyId = binary.BigEndian.Uint64(buf[:keyIdSize]) + } + + _, err = readFull(r, buf[:1]) + if err != nil { + return + } + e.Algo = PublicKeyAlgorithm(buf[0]) + var cipherFunction byte switch e.Algo { case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly: e.encryptedMPI1 = new(encoding.MPI) @@ -68,26 +123,39 @@ func (e *EncryptedKey) parse(r io.Reader) (err error) { if _, err = e.encryptedMPI2.ReadFrom(r); err != nil { return } + case PubKeyAlgoX25519: + e.ephemeralPublicX25519, e.encryptedSession, cipherFunction, err = x25519.DecodeFields(r, e.Version == 6) + if err != nil { + return + } + case PubKeyAlgoX448: + e.ephemeralPublicX448, e.encryptedSession, cipherFunction, err = x448.DecodeFields(r, e.Version == 6) + if err != nil { + return + } + } + if e.Version < 6 { + switch e.Algo { + case PubKeyAlgoX25519, PubKeyAlgoX448: + e.CipherFunc = CipherFunction(cipherFunction) + // Check for validiy is in the Decrypt method + } } + _, err = consumeAll(r) return } -func checksumKeyMaterial(key []byte) uint16 { - var checksum uint16 - for _, v := range key { - checksum += uint16(v) - } - return checksum -} - // Decrypt decrypts an encrypted session key with the given private key. The // private key must have been decrypted first. // If config is nil, sensible defaults will be used. func (e *EncryptedKey) Decrypt(priv *PrivateKey, config *Config) error { - if e.KeyId != 0 && e.KeyId != priv.KeyId { + if e.Version < 6 && e.KeyId != 0 && e.KeyId != priv.KeyId { return errors.InvalidArgumentError("cannot decrypt encrypted session key for key id " + strconv.FormatUint(e.KeyId, 16) + " with private key id " + strconv.FormatUint(priv.KeyId, 16)) } + if e.Version == 6 && e.KeyVersion != 0 && !bytes.Equal(e.KeyFingerprint, priv.Fingerprint) { + return errors.InvalidArgumentError("cannot decrypt encrypted session key for key fingerprint " + hex.EncodeToString(e.KeyFingerprint) + " with private key fingerprint " + hex.EncodeToString(priv.Fingerprint)) + } if e.Algo != priv.PubKeyAlgo { return errors.InvalidArgumentError("cannot decrypt encrypted session key of type " + strconv.Itoa(int(e.Algo)) + " with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo))) } @@ -113,52 +181,116 @@ func (e *EncryptedKey) Decrypt(priv *PrivateKey, config *Config) error { vsG := e.encryptedMPI1.Bytes() m := e.encryptedMPI2.Bytes() oid := priv.PublicKey.oid.EncodedBytes() - b, err = ecdh.Decrypt(priv.PrivateKey.(*ecdh.PrivateKey), vsG, m, oid, priv.PublicKey.Fingerprint[:]) + fp := priv.PublicKey.Fingerprint[:] + if priv.PublicKey.Version == 5 { + // For v5 the, the fingerprint must be restricted to 20 bytes + fp = fp[:20] + } + b, err = ecdh.Decrypt(priv.PrivateKey.(*ecdh.PrivateKey), vsG, m, oid, fp) + case PubKeyAlgoX25519: + b, err = x25519.Decrypt(priv.PrivateKey.(*x25519.PrivateKey), e.ephemeralPublicX25519, e.encryptedSession) + case PubKeyAlgoX448: + b, err = x448.Decrypt(priv.PrivateKey.(*x448.PrivateKey), e.ephemeralPublicX448, e.encryptedSession) default: err = errors.InvalidArgumentError("cannot decrypt encrypted session key with private key of type " + strconv.Itoa(int(priv.PubKeyAlgo))) } - if err != nil { return err } - e.CipherFunc = CipherFunction(b[0]) - if !e.CipherFunc.IsSupported() { - return errors.UnsupportedError("unsupported encryption function") - } - - e.Key = b[1 : len(b)-2] - expectedChecksum := uint16(b[len(b)-2])<<8 | uint16(b[len(b)-1]) - checksum := checksumKeyMaterial(e.Key) - if checksum != expectedChecksum { - return errors.StructuralError("EncryptedKey checksum incorrect") + var key []byte + switch priv.PubKeyAlgo { + case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoElGamal, PubKeyAlgoECDH: + keyOffset := 0 + if e.Version < 6 { + e.CipherFunc = CipherFunction(b[0]) + keyOffset = 1 + if !e.CipherFunc.IsSupported() { + return errors.UnsupportedError("unsupported encryption function") + } + } + key, err = decodeChecksumKey(b[keyOffset:]) + if err != nil { + return err + } + case PubKeyAlgoX25519, PubKeyAlgoX448: + if e.Version < 6 { + switch e.CipherFunc { + case CipherAES128, CipherAES192, CipherAES256: + break + default: + return errors.StructuralError("v3 PKESK mandates AES as cipher function for x25519 and x448") + } + } + key = b[:] + default: + return errors.UnsupportedError("unsupported algorithm for decryption") } - + e.Key = key return nil } // Serialize writes the encrypted key packet, e, to w. func (e *EncryptedKey) Serialize(w io.Writer) error { - var mpiLen int + var encodedLength int switch e.Algo { case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly: - mpiLen = int(e.encryptedMPI1.EncodedLength()) + encodedLength = int(e.encryptedMPI1.EncodedLength()) case PubKeyAlgoElGamal: - mpiLen = int(e.encryptedMPI1.EncodedLength()) + int(e.encryptedMPI2.EncodedLength()) + encodedLength = int(e.encryptedMPI1.EncodedLength()) + int(e.encryptedMPI2.EncodedLength()) case PubKeyAlgoECDH: - mpiLen = int(e.encryptedMPI1.EncodedLength()) + int(e.encryptedMPI2.EncodedLength()) + encodedLength = int(e.encryptedMPI1.EncodedLength()) + int(e.encryptedMPI2.EncodedLength()) + case PubKeyAlgoX25519: + encodedLength = x25519.EncodedFieldsLength(e.encryptedSession, e.Version == 6) + case PubKeyAlgoX448: + encodedLength = x448.EncodedFieldsLength(e.encryptedSession, e.Version == 6) default: return errors.InvalidArgumentError("don't know how to serialize encrypted key type " + strconv.Itoa(int(e.Algo))) } - err := serializeHeader(w, packetTypeEncryptedKey, 1 /* version */ +8 /* key id */ +1 /* algo */ +mpiLen) + packetLen := versionSize /* version */ + keyIdSize /* key id */ + algorithmSize /* algo */ + encodedLength + if e.Version == 6 { + packetLen = versionSize /* version */ + algorithmSize /* algo */ + encodedLength + keyVersionSize /* key version */ + if e.KeyVersion == 6 { + packetLen += fingerprintSizeV6 + } else if e.KeyVersion == 4 { + packetLen += fingerprintSize + } + } + + err := serializeHeader(w, packetTypeEncryptedKey, packetLen) if err != nil { return err } - w.Write([]byte{encryptedKeyVersion}) - binary.Write(w, binary.BigEndian, e.KeyId) - w.Write([]byte{byte(e.Algo)}) + _, err = w.Write([]byte{byte(e.Version)}) + if err != nil { + return err + } + if e.Version == 6 { + _, err = w.Write([]byte{byte(e.KeyVersion)}) + if err != nil { + return err + } + // The key version number may also be zero, + // and the fingerprint omitted + if e.KeyVersion != 0 { + _, err = w.Write(e.KeyFingerprint) + if err != nil { + return err + } + } + } else { + // Write KeyID + err = binary.Write(w, binary.BigEndian, e.KeyId) + if err != nil { + return err + } + } + _, err = w.Write([]byte{byte(e.Algo)}) + if err != nil { + return err + } switch e.Algo { case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly: @@ -176,34 +308,115 @@ func (e *EncryptedKey) Serialize(w io.Writer) error { } _, err := w.Write(e.encryptedMPI2.EncodedBytes()) return err + case PubKeyAlgoX25519: + err := x25519.EncodeFields(w, e.ephemeralPublicX25519, e.encryptedSession, byte(e.CipherFunc), e.Version == 6) + return err + case PubKeyAlgoX448: + err := x448.EncodeFields(w, e.ephemeralPublicX448, e.encryptedSession, byte(e.CipherFunc), e.Version == 6) + return err default: panic("internal error") } } -// SerializeEncryptedKey serializes an encrypted key packet to w that contains +// SerializeEncryptedKeyAEAD serializes an encrypted key packet to w that contains // key, encrypted to pub. +// If aeadSupported is set, PKESK v6 is used, otherwise v3. +// Note: aeadSupported MUST match the value passed to SerializeSymmetricallyEncrypted. // If config is nil, sensible defaults will be used. -func SerializeEncryptedKey(w io.Writer, pub *PublicKey, cipherFunc CipherFunction, key []byte, config *Config) error { - var buf [10]byte - buf[0] = encryptedKeyVersion - binary.BigEndian.PutUint64(buf[1:9], pub.KeyId) - buf[9] = byte(pub.PubKeyAlgo) - - keyBlock := make([]byte, 1 /* cipher type */ +len(key)+2 /* checksum */) - keyBlock[0] = byte(cipherFunc) - copy(keyBlock[1:], key) - checksum := checksumKeyMaterial(key) - keyBlock[1+len(key)] = byte(checksum >> 8) - keyBlock[1+len(key)+1] = byte(checksum) +func SerializeEncryptedKeyAEAD(w io.Writer, pub *PublicKey, cipherFunc CipherFunction, aeadSupported bool, key []byte, config *Config) error { + return SerializeEncryptedKeyAEADwithHiddenOption(w, pub, cipherFunc, aeadSupported, key, false, config) +} + +// SerializeEncryptedKeyAEADwithHiddenOption serializes an encrypted key packet to w that contains +// key, encrypted to pub. +// Offers the hidden flag option to indicated if the PKESK packet should include a wildcard KeyID. +// If aeadSupported is set, PKESK v6 is used, otherwise v3. +// Note: aeadSupported MUST match the value passed to SerializeSymmetricallyEncrypted. +// If config is nil, sensible defaults will be used. +func SerializeEncryptedKeyAEADwithHiddenOption(w io.Writer, pub *PublicKey, cipherFunc CipherFunction, aeadSupported bool, key []byte, hidden bool, config *Config) error { + var buf [36]byte // max possible header size is v6 + lenHeaderWritten := versionSize + version := 3 + + if aeadSupported { + version = 6 + } + // An implementation MUST NOT generate ElGamal v6 PKESKs. + if version == 6 && pub.PubKeyAlgo == PubKeyAlgoElGamal { + return errors.InvalidArgumentError("ElGamal v6 PKESK are not allowed") + } + // In v3 PKESKs, for x25519 and x448, mandate using AES + if version == 3 && (pub.PubKeyAlgo == PubKeyAlgoX25519 || pub.PubKeyAlgo == PubKeyAlgoX448) { + switch cipherFunc { + case CipherAES128, CipherAES192, CipherAES256: + break + default: + return errors.InvalidArgumentError("v3 PKESK mandates AES for x25519 and x448") + } + } + + buf[0] = byte(version) + + // If hidden is set, the key should be hidden + // An implementation MAY accept or use a Key ID of all zeros, + // or a key version of zero and no key fingerprint, to hide the intended decryption key. + // See Section 5.1.8. in the open pgp crypto refresh + if version == 6 { + if !hidden { + // A one-octet size of the following two fields. + buf[1] = byte(keyVersionSize + len(pub.Fingerprint)) + // A one octet key version number. + buf[2] = byte(pub.Version) + lenHeaderWritten += keyVersionSize + 1 + // The fingerprint of the public key + copy(buf[lenHeaderWritten:lenHeaderWritten+len(pub.Fingerprint)], pub.Fingerprint) + lenHeaderWritten += len(pub.Fingerprint) + } else { + // The size may also be zero, and the key version + // and fingerprint omitted for an "anonymous recipient" + buf[1] = 0 + lenHeaderWritten += 1 + } + } else { + if !hidden { + binary.BigEndian.PutUint64(buf[versionSize:(versionSize+keyIdSize)], pub.KeyId) + } + lenHeaderWritten += keyIdSize + } + buf[lenHeaderWritten] = byte(pub.PubKeyAlgo) + lenHeaderWritten += algorithmSize + + var keyBlock []byte + switch pub.PubKeyAlgo { + case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoElGamal, PubKeyAlgoECDH: + lenKeyBlock := len(key) + 2 + if version < 6 { + lenKeyBlock += 1 // cipher type included + } + keyBlock = make([]byte, lenKeyBlock) + keyOffset := 0 + if version < 6 { + keyBlock[0] = byte(cipherFunc) + keyOffset = 1 + } + encodeChecksumKey(keyBlock[keyOffset:], key) + case PubKeyAlgoX25519, PubKeyAlgoX448: + // algorithm is added in plaintext below + keyBlock = key + } switch pub.PubKeyAlgo { case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly: - return serializeEncryptedKeyRSA(w, config.Random(), buf, pub.PublicKey.(*rsa.PublicKey), keyBlock) + return serializeEncryptedKeyRSA(w, config.Random(), buf[:lenHeaderWritten], pub.PublicKey.(*rsa.PublicKey), keyBlock) case PubKeyAlgoElGamal: - return serializeEncryptedKeyElGamal(w, config.Random(), buf, pub.PublicKey.(*elgamal.PublicKey), keyBlock) + return serializeEncryptedKeyElGamal(w, config.Random(), buf[:lenHeaderWritten], pub.PublicKey.(*elgamal.PublicKey), keyBlock) case PubKeyAlgoECDH: - return serializeEncryptedKeyECDH(w, config.Random(), buf, pub.PublicKey.(*ecdh.PublicKey), keyBlock, pub.oid, pub.Fingerprint) + return serializeEncryptedKeyECDH(w, config.Random(), buf[:lenHeaderWritten], pub.PublicKey.(*ecdh.PublicKey), keyBlock, pub.oid, pub.Fingerprint) + case PubKeyAlgoX25519: + return serializeEncryptedKeyX25519(w, config.Random(), buf[:lenHeaderWritten], pub.PublicKey.(*x25519.PublicKey), keyBlock, byte(cipherFunc), version) + case PubKeyAlgoX448: + return serializeEncryptedKeyX448(w, config.Random(), buf[:lenHeaderWritten], pub.PublicKey.(*x448.PublicKey), keyBlock, byte(cipherFunc), version) case PubKeyAlgoDSA, PubKeyAlgoRSASignOnly: return errors.InvalidArgumentError("cannot encrypt to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo))) } @@ -211,14 +424,32 @@ func SerializeEncryptedKey(w io.Writer, pub *PublicKey, cipherFunc CipherFunctio return errors.UnsupportedError("encrypting a key to public key of type " + strconv.Itoa(int(pub.PubKeyAlgo))) } -func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header [10]byte, pub *rsa.PublicKey, keyBlock []byte) error { +// SerializeEncryptedKey serializes an encrypted key packet to w that contains +// key, encrypted to pub. +// PKESKv6 is used if config.AEAD() is not nil. +// If config is nil, sensible defaults will be used. +// Deprecated: Use SerializeEncryptedKeyAEAD instead. +func SerializeEncryptedKey(w io.Writer, pub *PublicKey, cipherFunc CipherFunction, key []byte, config *Config) error { + return SerializeEncryptedKeyAEAD(w, pub, cipherFunc, config.AEAD() != nil, key, config) +} + +// SerializeEncryptedKeyWithHiddenOption serializes an encrypted key packet to w that contains +// key, encrypted to pub. PKESKv6 is used if config.AEAD() is not nil. +// The hidden option controls if the packet should be anonymous, i.e., omit key metadata. +// If config is nil, sensible defaults will be used. +// Deprecated: Use SerializeEncryptedKeyAEADwithHiddenOption instead. +func SerializeEncryptedKeyWithHiddenOption(w io.Writer, pub *PublicKey, cipherFunc CipherFunction, key []byte, hidden bool, config *Config) error { + return SerializeEncryptedKeyAEADwithHiddenOption(w, pub, cipherFunc, config.AEAD() != nil, key, hidden, config) +} + +func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header []byte, pub *rsa.PublicKey, keyBlock []byte) error { cipherText, err := rsa.EncryptPKCS1v15(rand, pub, keyBlock) if err != nil { return errors.InvalidArgumentError("RSA encryption failed: " + err.Error()) } cipherMPI := encoding.NewMPI(cipherText) - packetLen := 10 /* header length */ + int(cipherMPI.EncodedLength()) + packetLen := len(header) /* header length */ + int(cipherMPI.EncodedLength()) err = serializeHeader(w, packetTypeEncryptedKey, packetLen) if err != nil { @@ -232,13 +463,13 @@ func serializeEncryptedKeyRSA(w io.Writer, rand io.Reader, header [10]byte, pub return err } -func serializeEncryptedKeyElGamal(w io.Writer, rand io.Reader, header [10]byte, pub *elgamal.PublicKey, keyBlock []byte) error { +func serializeEncryptedKeyElGamal(w io.Writer, rand io.Reader, header []byte, pub *elgamal.PublicKey, keyBlock []byte) error { c1, c2, err := elgamal.Encrypt(rand, pub, keyBlock) if err != nil { return errors.InvalidArgumentError("ElGamal encryption failed: " + err.Error()) } - packetLen := 10 /* header length */ + packetLen := len(header) /* header length */ packetLen += 2 /* mpi size */ + (c1.BitLen()+7)/8 packetLen += 2 /* mpi size */ + (c2.BitLen()+7)/8 @@ -257,7 +488,7 @@ func serializeEncryptedKeyElGamal(w io.Writer, rand io.Reader, header [10]byte, return err } -func serializeEncryptedKeyECDH(w io.Writer, rand io.Reader, header [10]byte, pub *ecdh.PublicKey, keyBlock []byte, oid encoding.Field, fingerprint []byte) error { +func serializeEncryptedKeyECDH(w io.Writer, rand io.Reader, header []byte, pub *ecdh.PublicKey, keyBlock []byte, oid encoding.Field, fingerprint []byte) error { vsG, c, err := ecdh.Encrypt(rand, pub, keyBlock, oid.EncodedBytes(), fingerprint) if err != nil { return errors.InvalidArgumentError("ECDH encryption failed: " + err.Error()) @@ -266,7 +497,7 @@ func serializeEncryptedKeyECDH(w io.Writer, rand io.Reader, header [10]byte, pub g := encoding.NewMPI(vsG) m := encoding.NewOID(c) - packetLen := 10 /* header length */ + packetLen := len(header) /* header length */ packetLen += int(g.EncodedLength()) + int(m.EncodedLength()) err = serializeHeader(w, packetTypeEncryptedKey, packetLen) @@ -284,3 +515,70 @@ func serializeEncryptedKeyECDH(w io.Writer, rand io.Reader, header [10]byte, pub _, err = w.Write(m.EncodedBytes()) return err } + +func serializeEncryptedKeyX25519(w io.Writer, rand io.Reader, header []byte, pub *x25519.PublicKey, keyBlock []byte, cipherFunc byte, version int) error { + ephemeralPublicX25519, ciphertext, err := x25519.Encrypt(rand, pub, keyBlock) + if err != nil { + return errors.InvalidArgumentError("x25519 encryption failed: " + err.Error()) + } + + packetLen := len(header) /* header length */ + packetLen += x25519.EncodedFieldsLength(ciphertext, version == 6) + + err = serializeHeader(w, packetTypeEncryptedKey, packetLen) + if err != nil { + return err + } + + _, err = w.Write(header[:]) + if err != nil { + return err + } + return x25519.EncodeFields(w, ephemeralPublicX25519, ciphertext, cipherFunc, version == 6) +} + +func serializeEncryptedKeyX448(w io.Writer, rand io.Reader, header []byte, pub *x448.PublicKey, keyBlock []byte, cipherFunc byte, version int) error { + ephemeralPublicX448, ciphertext, err := x448.Encrypt(rand, pub, keyBlock) + if err != nil { + return errors.InvalidArgumentError("x448 encryption failed: " + err.Error()) + } + + packetLen := len(header) /* header length */ + packetLen += x448.EncodedFieldsLength(ciphertext, version == 6) + + err = serializeHeader(w, packetTypeEncryptedKey, packetLen) + if err != nil { + return err + } + + _, err = w.Write(header[:]) + if err != nil { + return err + } + return x448.EncodeFields(w, ephemeralPublicX448, ciphertext, cipherFunc, version == 6) +} + +func checksumKeyMaterial(key []byte) uint16 { + var checksum uint16 + for _, v := range key { + checksum += uint16(v) + } + return checksum +} + +func decodeChecksumKey(msg []byte) (key []byte, err error) { + key = msg[:len(msg)-2] + expectedChecksum := uint16(msg[len(msg)-2])<<8 | uint16(msg[len(msg)-1]) + checksum := checksumKeyMaterial(key) + if checksum != expectedChecksum { + err = errors.StructuralError("session key checksum is incorrect") + } + return +} + +func encodeChecksumKey(buffer []byte, key []byte) { + copy(buffer, key) + checksum := checksumKeyMaterial(key) + buffer[len(key)] = byte(checksum >> 8) + buffer[len(key)+1] = byte(checksum) +} diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/literal.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/literal.go index 4be987609b..8a028c8a17 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/literal.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/literal.go @@ -58,9 +58,9 @@ func (l *LiteralData) parse(r io.Reader) (err error) { // on completion. The fileName is truncated to 255 bytes. func SerializeLiteral(w io.WriteCloser, isBinary bool, fileName string, time uint32) (plaintext io.WriteCloser, err error) { var buf [4]byte - buf[0] = 't' - if isBinary { - buf[0] = 'b' + buf[0] = 'b' + if !isBinary { + buf[0] = 'u' } if len(fileName) > 255 { fileName = fileName[:255] diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/marker.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/marker.go new file mode 100644 index 0000000000..1ee378ba3c --- /dev/null +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/marker.go @@ -0,0 +1,33 @@ +package packet + +import ( + "io" + + "github.com/ProtonMail/go-crypto/openpgp/errors" +) + +type Marker struct{} + +const markerString = "PGP" + +// parse just checks if the packet contains "PGP". +func (m *Marker) parse(reader io.Reader) error { + var buffer [3]byte + if _, err := io.ReadFull(reader, buffer[:]); err != nil { + return err + } + if string(buffer[:]) != markerString { + return errors.StructuralError("invalid marker packet") + } + return nil +} + +// SerializeMarker writes a marker packet to writer. +func SerializeMarker(writer io.Writer) error { + err := serializeHeader(writer, packetTypeMarker, len(markerString)) + if err != nil { + return err + } + _, err = writer.Write([]byte(markerString)) + return err +} diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/one_pass_signature.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/one_pass_signature.go index 033fb2d7e8..f393c4063b 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/one_pass_signature.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/one_pass_signature.go @@ -7,34 +7,37 @@ package packet import ( "crypto" "encoding/binary" - "github.com/ProtonMail/go-crypto/openpgp/errors" - "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm" "io" "strconv" + + "github.com/ProtonMail/go-crypto/openpgp/errors" + "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm" ) // OnePassSignature represents a one-pass signature packet. See RFC 4880, // section 5.4. type OnePassSignature struct { - SigType SignatureType - Hash crypto.Hash - PubKeyAlgo PublicKeyAlgorithm - KeyId uint64 - IsLast bool + Version int + SigType SignatureType + Hash crypto.Hash + PubKeyAlgo PublicKeyAlgorithm + KeyId uint64 + IsLast bool + Salt []byte // v6 only + KeyFingerprint []byte // v6 only } -const onePassSignatureVersion = 3 - func (ops *OnePassSignature) parse(r io.Reader) (err error) { - var buf [13]byte - - _, err = readFull(r, buf[:]) + var buf [8]byte + // Read: version | signature type | hash algorithm | public-key algorithm + _, err = readFull(r, buf[:4]) if err != nil { return } - if buf[0] != onePassSignatureVersion { - err = errors.UnsupportedError("one-pass-signature packet version " + strconv.Itoa(int(buf[0]))) + if buf[0] != 3 && buf[0] != 6 { + return errors.UnsupportedError("one-pass-signature packet version " + strconv.Itoa(int(buf[0]))) } + ops.Version = int(buf[0]) var ok bool ops.Hash, ok = algorithm.HashIdToHashWithSha1(buf[2]) @@ -44,15 +47,69 @@ func (ops *OnePassSignature) parse(r io.Reader) (err error) { ops.SigType = SignatureType(buf[1]) ops.PubKeyAlgo = PublicKeyAlgorithm(buf[3]) - ops.KeyId = binary.BigEndian.Uint64(buf[4:12]) - ops.IsLast = buf[12] != 0 + + if ops.Version == 6 { + // Only for v6, a variable-length field containing the salt + _, err = readFull(r, buf[:1]) + if err != nil { + return + } + saltLength := int(buf[0]) + var expectedSaltLength int + expectedSaltLength, err = SaltLengthForHash(ops.Hash) + if err != nil { + return + } + if saltLength != expectedSaltLength { + err = errors.StructuralError("unexpected salt size for the given hash algorithm") + return + } + salt := make([]byte, expectedSaltLength) + _, err = readFull(r, salt) + if err != nil { + return + } + ops.Salt = salt + + // Only for v6 packets, 32 octets of the fingerprint of the signing key. + fingerprint := make([]byte, 32) + _, err = readFull(r, fingerprint) + if err != nil { + return + } + ops.KeyFingerprint = fingerprint + ops.KeyId = binary.BigEndian.Uint64(ops.KeyFingerprint[:8]) + } else { + _, err = readFull(r, buf[:8]) + if err != nil { + return + } + ops.KeyId = binary.BigEndian.Uint64(buf[:8]) + } + + _, err = readFull(r, buf[:1]) + if err != nil { + return + } + ops.IsLast = buf[0] != 0 return } // Serialize marshals the given OnePassSignature to w. func (ops *OnePassSignature) Serialize(w io.Writer) error { - var buf [13]byte - buf[0] = onePassSignatureVersion + //v3 length 1+1+1+1+8+1 = + packetLength := 13 + if ops.Version == 6 { + // v6 length 1+1+1+1+1+len(salt)+32+1 = + packetLength = 38 + len(ops.Salt) + } + + if err := serializeHeader(w, packetTypeOnePassSignature, packetLength); err != nil { + return err + } + + var buf [8]byte + buf[0] = byte(ops.Version) buf[1] = uint8(ops.SigType) var ok bool buf[2], ok = algorithm.HashToHashIdWithSha1(ops.Hash) @@ -60,14 +117,41 @@ func (ops *OnePassSignature) Serialize(w io.Writer) error { return errors.UnsupportedError("hash type: " + strconv.Itoa(int(ops.Hash))) } buf[3] = uint8(ops.PubKeyAlgo) - binary.BigEndian.PutUint64(buf[4:12], ops.KeyId) - if ops.IsLast { - buf[12] = 1 - } - if err := serializeHeader(w, packetTypeOnePassSignature, len(buf)); err != nil { + _, err := w.Write(buf[:4]) + if err != nil { return err } - _, err := w.Write(buf[:]) + + if ops.Version == 6 { + // write salt for v6 signatures + _, err := w.Write([]byte{uint8(len(ops.Salt))}) + if err != nil { + return err + } + _, err = w.Write(ops.Salt) + if err != nil { + return err + } + + // write fingerprint v6 signatures + _, err = w.Write(ops.KeyFingerprint) + if err != nil { + return err + } + } else { + binary.BigEndian.PutUint64(buf[:8], ops.KeyId) + _, err := w.Write(buf[:8]) + if err != nil { + return err + } + } + + isLast := []byte{byte(0)} + if ops.IsLast { + isLast[0] = 1 + } + + _, err = w.Write(isLast) return err } diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/opaque.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/opaque.go index 4f8204079f..cef7c661d3 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/opaque.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/opaque.go @@ -7,7 +7,6 @@ package packet import ( "bytes" "io" - "io/ioutil" "github.com/ProtonMail/go-crypto/openpgp/errors" ) @@ -26,7 +25,7 @@ type OpaquePacket struct { } func (op *OpaquePacket) parse(r io.Reader) (err error) { - op.Contents, err = ioutil.ReadAll(r) + op.Contents, err = io.ReadAll(r) return } diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/packet.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/packet.go index 4d86a7da82..1e92e22c97 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/packet.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/packet.go @@ -311,12 +311,15 @@ const ( packetTypePrivateSubkey packetType = 7 packetTypeCompressed packetType = 8 packetTypeSymmetricallyEncrypted packetType = 9 + packetTypeMarker packetType = 10 packetTypeLiteralData packetType = 11 + packetTypeTrust packetType = 12 packetTypeUserId packetType = 13 packetTypePublicSubkey packetType = 14 packetTypeUserAttribute packetType = 17 packetTypeSymmetricallyEncryptedIntegrityProtected packetType = 18 packetTypeAEADEncrypted packetType = 20 + packetPadding packetType = 21 ) // EncryptedDataPacket holds encrypted data. It is currently implemented by @@ -328,7 +331,7 @@ type EncryptedDataPacket interface { // Read reads a single OpenPGP packet from the given io.Reader. If there is an // error parsing a packet, the whole packet is consumed from the input. func Read(r io.Reader) (p Packet, err error) { - tag, _, contents, err := readHeader(r) + tag, len, contents, err := readHeader(r) if err != nil { return } @@ -367,8 +370,93 @@ func Read(r io.Reader) (p Packet, err error) { p = se case packetTypeAEADEncrypted: p = new(AEADEncrypted) + case packetPadding: + p = Padding(len) + case packetTypeMarker: + p = new(Marker) + case packetTypeTrust: + // Not implemented, just consume + err = errors.UnknownPacketTypeError(tag) default: + // Packet Tags from 0 to 39 are critical. + // Packet Tags from 40 to 63 are non-critical. + if tag < 40 { + err = errors.CriticalUnknownPacketTypeError(tag) + } else { + err = errors.UnknownPacketTypeError(tag) + } + } + if p != nil { + err = p.parse(contents) + } + if err != nil { + consumeAll(contents) + } + return +} + +// ReadWithCheck reads a single OpenPGP message packet from the given io.Reader. If there is an +// error parsing a packet, the whole packet is consumed from the input. +// ReadWithCheck additionally checks if the OpenPGP message packet sequence adheres +// to the packet composition rules in rfc4880, if not throws an error. +func ReadWithCheck(r io.Reader, sequence *SequenceVerifier) (p Packet, msgErr error, err error) { + tag, len, contents, err := readHeader(r) + if err != nil { + return + } + switch tag { + case packetTypeEncryptedKey: + msgErr = sequence.Next(ESKSymbol) + p = new(EncryptedKey) + case packetTypeSignature: + msgErr = sequence.Next(SigSymbol) + p = new(Signature) + case packetTypeSymmetricKeyEncrypted: + msgErr = sequence.Next(ESKSymbol) + p = new(SymmetricKeyEncrypted) + case packetTypeOnePassSignature: + msgErr = sequence.Next(OPSSymbol) + p = new(OnePassSignature) + case packetTypeCompressed: + msgErr = sequence.Next(CompSymbol) + p = new(Compressed) + case packetTypeSymmetricallyEncrypted: + msgErr = sequence.Next(EncSymbol) + p = new(SymmetricallyEncrypted) + case packetTypeLiteralData: + msgErr = sequence.Next(LDSymbol) + p = new(LiteralData) + case packetTypeSymmetricallyEncryptedIntegrityProtected: + msgErr = sequence.Next(EncSymbol) + se := new(SymmetricallyEncrypted) + se.IntegrityProtected = true + p = se + case packetTypeAEADEncrypted: + msgErr = sequence.Next(EncSymbol) + p = new(AEADEncrypted) + case packetPadding: + p = Padding(len) + case packetTypeMarker: + p = new(Marker) + case packetTypeTrust: + // Not implemented, just consume err = errors.UnknownPacketTypeError(tag) + case packetTypePrivateKey, + packetTypePrivateSubkey, + packetTypePublicKey, + packetTypePublicSubkey, + packetTypeUserId, + packetTypeUserAttribute: + msgErr = sequence.Next(UnknownSymbol) + consumeAll(contents) + default: + // Packet Tags from 0 to 39 are critical. + // Packet Tags from 40 to 63 are non-critical. + if tag < 40 { + err = errors.CriticalUnknownPacketTypeError(tag) + } else { + err = errors.UnknownPacketTypeError(tag) + } } if p != nil { err = p.parse(contents) @@ -385,17 +473,17 @@ type SignatureType uint8 const ( SigTypeBinary SignatureType = 0x00 - SigTypeText = 0x01 - SigTypeGenericCert = 0x10 - SigTypePersonaCert = 0x11 - SigTypeCasualCert = 0x12 - SigTypePositiveCert = 0x13 - SigTypeSubkeyBinding = 0x18 - SigTypePrimaryKeyBinding = 0x19 - SigTypeDirectSignature = 0x1F - SigTypeKeyRevocation = 0x20 - SigTypeSubkeyRevocation = 0x28 - SigTypeCertificationRevocation = 0x30 + SigTypeText SignatureType = 0x01 + SigTypeGenericCert SignatureType = 0x10 + SigTypePersonaCert SignatureType = 0x11 + SigTypeCasualCert SignatureType = 0x12 + SigTypePositiveCert SignatureType = 0x13 + SigTypeSubkeyBinding SignatureType = 0x18 + SigTypePrimaryKeyBinding SignatureType = 0x19 + SigTypeDirectSignature SignatureType = 0x1F + SigTypeKeyRevocation SignatureType = 0x20 + SigTypeSubkeyRevocation SignatureType = 0x28 + SigTypeCertificationRevocation SignatureType = 0x30 ) // PublicKeyAlgorithm represents the different public key system specified for @@ -412,6 +500,11 @@ const ( PubKeyAlgoECDSA PublicKeyAlgorithm = 19 // https://www.ietf.org/archive/id/draft-koch-eddsa-for-openpgp-04.txt PubKeyAlgoEdDSA PublicKeyAlgorithm = 22 + // https://datatracker.ietf.org/doc/html/draft-ietf-openpgp-crypto-refresh + PubKeyAlgoX25519 PublicKeyAlgorithm = 25 + PubKeyAlgoX448 PublicKeyAlgorithm = 26 + PubKeyAlgoEd25519 PublicKeyAlgorithm = 27 + PubKeyAlgoEd448 PublicKeyAlgorithm = 28 // Deprecated in RFC 4880, Section 13.5. Use key flags instead. PubKeyAlgoRSAEncryptOnly PublicKeyAlgorithm = 2 @@ -422,7 +515,7 @@ const ( // key of the given type. func (pka PublicKeyAlgorithm) CanEncrypt() bool { switch pka { - case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoElGamal, PubKeyAlgoECDH: + case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoElGamal, PubKeyAlgoECDH, PubKeyAlgoX25519, PubKeyAlgoX448: return true } return false @@ -432,7 +525,7 @@ func (pka PublicKeyAlgorithm) CanEncrypt() bool { // sign a message. func (pka PublicKeyAlgorithm) CanSign() bool { switch pka { - case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA, PubKeyAlgoECDSA, PubKeyAlgoEdDSA: + case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA, PubKeyAlgoECDSA, PubKeyAlgoEdDSA, PubKeyAlgoEd25519, PubKeyAlgoEd448: return true } return false @@ -512,6 +605,11 @@ func (mode AEADMode) TagLength() int { return algorithm.AEADMode(mode).TagLength() } +// IsSupported returns true if the aead mode is supported from the library +func (mode AEADMode) IsSupported() bool { + return algorithm.AEADMode(mode).TagLength() > 0 +} + // new returns a fresh instance of the given mode. func (mode AEADMode) new(block cipher.Block) cipher.AEAD { return algorithm.AEADMode(mode).New(block) @@ -526,8 +624,17 @@ const ( KeySuperseded ReasonForRevocation = 1 KeyCompromised ReasonForRevocation = 2 KeyRetired ReasonForRevocation = 3 + UserIDNotValid ReasonForRevocation = 32 + Unknown ReasonForRevocation = 200 ) +func NewReasonForRevocation(value byte) ReasonForRevocation { + if value < 4 || value == 32 { + return ReasonForRevocation(value) + } + return Unknown +} + // Curve is a mapping to supported ECC curves for key generation. // See https://www.ietf.org/archive/id/draft-ietf-openpgp-crypto-refresh-06.html#name-curve-specific-wire-formats type Curve string @@ -549,3 +656,20 @@ type TrustLevel uint8 // TrustAmount represents a trust amount per RFC4880 5.2.3.13 type TrustAmount uint8 + +const ( + // versionSize is the length in bytes of the version value. + versionSize = 1 + // algorithmSize is the length in bytes of the key algorithm value. + algorithmSize = 1 + // keyVersionSize is the length in bytes of the key version value + keyVersionSize = 1 + // keyIdSize is the length in bytes of the key identifier value. + keyIdSize = 8 + // timestampSize is the length in bytes of encoded timestamps. + timestampSize = 4 + // fingerprintSizeV6 is the length in bytes of the key fingerprint in v6. + fingerprintSizeV6 = 32 + // fingerprintSize is the length in bytes of the key fingerprint. + fingerprintSize = 20 +) diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/packet_sequence.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/packet_sequence.go new file mode 100644 index 0000000000..55a8a56c2d --- /dev/null +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/packet_sequence.go @@ -0,0 +1,222 @@ +package packet + +// This file implements the pushdown automata (PDA) from PGPainless (Paul Schaub) +// to verify pgp packet sequences. See Paul's blogpost for more details: +// https://blog.jabberhead.tk/2022/10/26/implementing-packet-sequence-validation-using-pushdown-automata/ +import ( + "fmt" + + "github.com/ProtonMail/go-crypto/openpgp/errors" +) + +func NewErrMalformedMessage(from State, input InputSymbol, stackSymbol StackSymbol) errors.ErrMalformedMessage { + return errors.ErrMalformedMessage(fmt.Sprintf("state %d, input symbol %d, stack symbol %d ", from, input, stackSymbol)) +} + +// InputSymbol defines the input alphabet of the PDA +type InputSymbol uint8 + +const ( + LDSymbol InputSymbol = iota + SigSymbol + OPSSymbol + CompSymbol + ESKSymbol + EncSymbol + EOSSymbol + UnknownSymbol +) + +// StackSymbol defines the stack alphabet of the PDA +type StackSymbol int8 + +const ( + MsgStackSymbol StackSymbol = iota + OpsStackSymbol + KeyStackSymbol + EndStackSymbol + EmptyStackSymbol +) + +// State defines the states of the PDA +type State int8 + +const ( + OpenPGPMessage State = iota + ESKMessage + LiteralMessage + CompressedMessage + EncryptedMessage + ValidMessage +) + +// transition represents a state transition in the PDA +type transition func(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) + +// SequenceVerifier is a pushdown automata to verify +// PGP messages packet sequences according to rfc4880. +type SequenceVerifier struct { + stack []StackSymbol + state State +} + +// Next performs a state transition with the given input symbol. +// If the transition fails a ErrMalformedMessage is returned. +func (sv *SequenceVerifier) Next(input InputSymbol) error { + for { + stackSymbol := sv.popStack() + transitionFunc := getTransition(sv.state) + nextState, newStackSymbols, redo, err := transitionFunc(input, stackSymbol) + if err != nil { + return err + } + if redo { + sv.pushStack(stackSymbol) + } + for _, newStackSymbol := range newStackSymbols { + sv.pushStack(newStackSymbol) + } + sv.state = nextState + if !redo { + break + } + } + return nil +} + +// Valid returns true if RDA is in a valid state. +func (sv *SequenceVerifier) Valid() bool { + return sv.state == ValidMessage && len(sv.stack) == 0 +} + +func (sv *SequenceVerifier) AssertValid() error { + if !sv.Valid() { + return errors.ErrMalformedMessage("invalid message") + } + return nil +} + +func NewSequenceVerifier() *SequenceVerifier { + return &SequenceVerifier{ + stack: []StackSymbol{EndStackSymbol, MsgStackSymbol}, + state: OpenPGPMessage, + } +} + +func (sv *SequenceVerifier) popStack() StackSymbol { + if len(sv.stack) == 0 { + return EmptyStackSymbol + } + elemIndex := len(sv.stack) - 1 + stackSymbol := sv.stack[elemIndex] + sv.stack = sv.stack[:elemIndex] + return stackSymbol +} + +func (sv *SequenceVerifier) pushStack(stackSymbol StackSymbol) { + sv.stack = append(sv.stack, stackSymbol) +} + +func getTransition(from State) transition { + switch from { + case OpenPGPMessage: + return fromOpenPGPMessage + case LiteralMessage: + return fromLiteralMessage + case CompressedMessage: + return fromCompressedMessage + case EncryptedMessage: + return fromEncryptedMessage + case ESKMessage: + return fromESKMessage + case ValidMessage: + return fromValidMessage + } + return nil +} + +// fromOpenPGPMessage is the transition for the state OpenPGPMessage. +func fromOpenPGPMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) { + if stackSymbol != MsgStackSymbol { + return 0, nil, false, NewErrMalformedMessage(OpenPGPMessage, input, stackSymbol) + } + switch input { + case LDSymbol: + return LiteralMessage, nil, false, nil + case SigSymbol: + return OpenPGPMessage, []StackSymbol{MsgStackSymbol}, false, nil + case OPSSymbol: + return OpenPGPMessage, []StackSymbol{OpsStackSymbol, MsgStackSymbol}, false, nil + case CompSymbol: + return CompressedMessage, nil, false, nil + case ESKSymbol: + return ESKMessage, []StackSymbol{KeyStackSymbol}, false, nil + case EncSymbol: + return EncryptedMessage, nil, false, nil + } + return 0, nil, false, NewErrMalformedMessage(OpenPGPMessage, input, stackSymbol) +} + +// fromESKMessage is the transition for the state ESKMessage. +func fromESKMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) { + if stackSymbol != KeyStackSymbol { + return 0, nil, false, NewErrMalformedMessage(ESKMessage, input, stackSymbol) + } + switch input { + case ESKSymbol: + return ESKMessage, []StackSymbol{KeyStackSymbol}, false, nil + case EncSymbol: + return EncryptedMessage, nil, false, nil + } + return 0, nil, false, NewErrMalformedMessage(ESKMessage, input, stackSymbol) +} + +// fromLiteralMessage is the transition for the state LiteralMessage. +func fromLiteralMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) { + switch input { + case SigSymbol: + if stackSymbol == OpsStackSymbol { + return LiteralMessage, nil, false, nil + } + case EOSSymbol: + if stackSymbol == EndStackSymbol { + return ValidMessage, nil, false, nil + } + } + return 0, nil, false, NewErrMalformedMessage(LiteralMessage, input, stackSymbol) +} + +// fromLiteralMessage is the transition for the state CompressedMessage. +func fromCompressedMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) { + switch input { + case SigSymbol: + if stackSymbol == OpsStackSymbol { + return CompressedMessage, nil, false, nil + } + case EOSSymbol: + if stackSymbol == EndStackSymbol { + return ValidMessage, nil, false, nil + } + } + return OpenPGPMessage, []StackSymbol{MsgStackSymbol}, true, nil +} + +// fromEncryptedMessage is the transition for the state EncryptedMessage. +func fromEncryptedMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) { + switch input { + case SigSymbol: + if stackSymbol == OpsStackSymbol { + return EncryptedMessage, nil, false, nil + } + case EOSSymbol: + if stackSymbol == EndStackSymbol { + return ValidMessage, nil, false, nil + } + } + return OpenPGPMessage, []StackSymbol{MsgStackSymbol}, true, nil +} + +// fromValidMessage is the transition for the state ValidMessage. +func fromValidMessage(input InputSymbol, stackSymbol StackSymbol) (State, []StackSymbol, bool, error) { + return 0, nil, false, NewErrMalformedMessage(ValidMessage, input, stackSymbol) +} diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/packet_unsupported.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/packet_unsupported.go new file mode 100644 index 0000000000..2d714723cf --- /dev/null +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/packet_unsupported.go @@ -0,0 +1,24 @@ +package packet + +import ( + "io" + + "github.com/ProtonMail/go-crypto/openpgp/errors" +) + +// UnsupportedPackage represents a OpenPGP packet with a known packet type +// but with unsupported content. +type UnsupportedPacket struct { + IncompletePacket Packet + Error errors.UnsupportedError +} + +// Implements the Packet interface +func (up *UnsupportedPacket) parse(read io.Reader) error { + err := up.IncompletePacket.parse(read) + if castedErr, ok := err.(errors.UnsupportedError); ok { + up.Error = castedErr + return nil + } + return err +} diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/padding.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/padding.go new file mode 100644 index 0000000000..3b6a7045d1 --- /dev/null +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/padding.go @@ -0,0 +1,26 @@ +package packet + +import ( + "io" +) + +// Padding type represents a Padding Packet (Tag 21). +// The padding type is represented by the length of its padding. +// see https://datatracker.ietf.org/doc/html/draft-ietf-openpgp-crypto-refresh#name-padding-packet-tag-21 +type Padding int + +// parse just ignores the padding content. +func (pad Padding) parse(reader io.Reader) error { + _, err := io.CopyN(io.Discard, reader, int64(pad)) + return err +} + +// SerializePadding writes the padding to writer. +func (pad Padding) SerializePadding(writer io.Writer, rand io.Reader) error { + err := serializeHeader(writer, packetPadding, int(pad)) + if err != nil { + return err + } + _, err = io.CopyN(writer, rand, int64(pad)) + return err +} diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/private_key.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/private_key.go index 2fc4386437..f04e6c6b87 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/private_key.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/private_key.go @@ -9,22 +9,28 @@ import ( "crypto" "crypto/cipher" "crypto/dsa" - "crypto/rand" "crypto/rsa" "crypto/sha1" + "crypto/sha256" + "crypto/subtle" + "fmt" "io" - "io/ioutil" "math/big" "strconv" "time" "github.com/ProtonMail/go-crypto/openpgp/ecdh" "github.com/ProtonMail/go-crypto/openpgp/ecdsa" + "github.com/ProtonMail/go-crypto/openpgp/ed25519" + "github.com/ProtonMail/go-crypto/openpgp/ed448" "github.com/ProtonMail/go-crypto/openpgp/eddsa" "github.com/ProtonMail/go-crypto/openpgp/elgamal" "github.com/ProtonMail/go-crypto/openpgp/errors" "github.com/ProtonMail/go-crypto/openpgp/internal/encoding" "github.com/ProtonMail/go-crypto/openpgp/s2k" + "github.com/ProtonMail/go-crypto/openpgp/x25519" + "github.com/ProtonMail/go-crypto/openpgp/x448" + "golang.org/x/crypto/hkdf" ) // PrivateKey represents a possibly encrypted private key. See RFC 4880, @@ -35,14 +41,14 @@ type PrivateKey struct { encryptedData []byte cipher CipherFunction s2k func(out, in []byte) - // An *{rsa|dsa|elgamal|ecdh|ecdsa|ed25519}.PrivateKey or + aead AEADMode // only relevant if S2KAEAD is enabled + // An *{rsa|dsa|elgamal|ecdh|ecdsa|ed25519|ed448}.PrivateKey or // crypto.Signer/crypto.Decrypter (Decryptor RSA only). - PrivateKey interface{} - sha1Checksum bool - iv []byte + PrivateKey interface{} + iv []byte // Type of encryption of the S2K packet - // Allowed values are 0 (Not encrypted), 254 (SHA1), or + // Allowed values are 0 (Not encrypted), 253 (AEAD), 254 (SHA1), or // 255 (2-byte checksum) s2kType S2KType // Full parameters of the S2K packet @@ -55,6 +61,8 @@ type S2KType uint8 const ( // S2KNON unencrypt S2KNON S2KType = 0 + // S2KAEAD use authenticated encryption + S2KAEAD S2KType = 253 // S2KSHA1 sha1 sum check S2KSHA1 S2KType = 254 // S2KCHECKSUM sum check @@ -103,6 +111,34 @@ func NewECDHPrivateKey(creationTime time.Time, priv *ecdh.PrivateKey) *PrivateKe return pk } +func NewX25519PrivateKey(creationTime time.Time, priv *x25519.PrivateKey) *PrivateKey { + pk := new(PrivateKey) + pk.PublicKey = *NewX25519PublicKey(creationTime, &priv.PublicKey) + pk.PrivateKey = priv + return pk +} + +func NewX448PrivateKey(creationTime time.Time, priv *x448.PrivateKey) *PrivateKey { + pk := new(PrivateKey) + pk.PublicKey = *NewX448PublicKey(creationTime, &priv.PublicKey) + pk.PrivateKey = priv + return pk +} + +func NewEd25519PrivateKey(creationTime time.Time, priv *ed25519.PrivateKey) *PrivateKey { + pk := new(PrivateKey) + pk.PublicKey = *NewEd25519PublicKey(creationTime, &priv.PublicKey) + pk.PrivateKey = priv + return pk +} + +func NewEd448PrivateKey(creationTime time.Time, priv *ed448.PrivateKey) *PrivateKey { + pk := new(PrivateKey) + pk.PublicKey = *NewEd448PublicKey(creationTime, &priv.PublicKey) + pk.PrivateKey = priv + return pk +} + // NewSignerPrivateKey creates a PrivateKey from a crypto.Signer that // implements RSA, ECDSA or EdDSA. func NewSignerPrivateKey(creationTime time.Time, signer interface{}) *PrivateKey { @@ -122,6 +158,14 @@ func NewSignerPrivateKey(creationTime time.Time, signer interface{}) *PrivateKey pk.PublicKey = *NewEdDSAPublicKey(creationTime, &pubkey.PublicKey) case eddsa.PrivateKey: pk.PublicKey = *NewEdDSAPublicKey(creationTime, &pubkey.PublicKey) + case *ed25519.PrivateKey: + pk.PublicKey = *NewEd25519PublicKey(creationTime, &pubkey.PublicKey) + case ed25519.PrivateKey: + pk.PublicKey = *NewEd25519PublicKey(creationTime, &pubkey.PublicKey) + case *ed448.PrivateKey: + pk.PublicKey = *NewEd448PublicKey(creationTime, &pubkey.PublicKey) + case ed448.PrivateKey: + pk.PublicKey = *NewEd448PublicKey(creationTime, &pubkey.PublicKey) default: panic("openpgp: unknown signer type in NewSignerPrivateKey") } @@ -129,7 +173,7 @@ func NewSignerPrivateKey(creationTime time.Time, signer interface{}) *PrivateKey return pk } -// NewDecrypterPrivateKey creates a PrivateKey from a *{rsa|elgamal|ecdh}.PrivateKey. +// NewDecrypterPrivateKey creates a PrivateKey from a *{rsa|elgamal|ecdh|x25519|x448}.PrivateKey. func NewDecrypterPrivateKey(creationTime time.Time, decrypter interface{}) *PrivateKey { pk := new(PrivateKey) switch priv := decrypter.(type) { @@ -139,6 +183,10 @@ func NewDecrypterPrivateKey(creationTime time.Time, decrypter interface{}) *Priv pk.PublicKey = *NewElGamalPublicKey(creationTime, &priv.PublicKey) case *ecdh.PrivateKey: pk.PublicKey = *NewECDHPublicKey(creationTime, &priv.PublicKey) + case *x25519.PrivateKey: + pk.PublicKey = *NewX25519PublicKey(creationTime, &priv.PublicKey) + case *x448.PrivateKey: + pk.PublicKey = *NewX448PublicKey(creationTime, &priv.PublicKey) default: panic("openpgp: unknown decrypter type in NewDecrypterPrivateKey") } @@ -152,6 +200,11 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) { return } v5 := pk.PublicKey.Version == 5 + v6 := pk.PublicKey.Version == 6 + + if V5Disabled && v5 { + return errors.UnsupportedError("support for parsing v5 entities is disabled; build with `-tags v5` if needed") + } var buf [1]byte _, err = readFull(r, buf[:]) @@ -160,7 +213,7 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) { } pk.s2kType = S2KType(buf[0]) var optCount [1]byte - if v5 { + if v5 || (v6 && pk.s2kType != S2KNON) { if _, err = readFull(r, optCount[:]); err != nil { return } @@ -170,9 +223,9 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) { case S2KNON: pk.s2k = nil pk.Encrypted = false - case S2KSHA1, S2KCHECKSUM: - if v5 && pk.s2kType == S2KCHECKSUM { - return errors.StructuralError("wrong s2k identifier for version 5") + case S2KSHA1, S2KCHECKSUM, S2KAEAD: + if (v5 || v6) && pk.s2kType == S2KCHECKSUM { + return errors.StructuralError(fmt.Sprintf("wrong s2k identifier for version %d", pk.Version)) } _, err = readFull(r, buf[:]) if err != nil { @@ -182,6 +235,29 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) { if pk.cipher != 0 && !pk.cipher.IsSupported() { return errors.UnsupportedError("unsupported cipher function in private key") } + // [Optional] If string-to-key usage octet was 253, + // a one-octet AEAD algorithm. + if pk.s2kType == S2KAEAD { + _, err = readFull(r, buf[:]) + if err != nil { + return + } + pk.aead = AEADMode(buf[0]) + if !pk.aead.IsSupported() { + return errors.UnsupportedError("unsupported aead mode in private key") + } + } + + // [Optional] Only for a version 6 packet, + // and if string-to-key usage octet was 255, 254, or 253, + // an one-octet count of the following field. + if v6 { + _, err = readFull(r, buf[:]) + if err != nil { + return + } + } + pk.s2kParams, err = s2k.ParseIntoParams(r) if err != nil { return @@ -189,28 +265,43 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) { if pk.s2kParams.Dummy() { return } + if pk.s2kParams.Mode() == s2k.Argon2S2K && pk.s2kType != S2KAEAD { + return errors.StructuralError("using Argon2 S2K without AEAD is not allowed") + } + if pk.s2kParams.Mode() == s2k.SimpleS2K && pk.Version == 6 { + return errors.StructuralError("using Simple S2K with version 6 keys is not allowed") + } pk.s2k, err = pk.s2kParams.Function() if err != nil { return } pk.Encrypted = true - if pk.s2kType == S2KSHA1 { - pk.sha1Checksum = true - } default: return errors.UnsupportedError("deprecated s2k function in private key") } if pk.Encrypted { - blockSize := pk.cipher.blockSize() - if blockSize == 0 { + var ivSize int + // If the S2K usage octet was 253, the IV is of the size expected by the AEAD mode, + // unless it's a version 5 key, in which case it's the size of the symmetric cipher's block size. + // For all other S2K modes, it's always the block size. + if !v5 && pk.s2kType == S2KAEAD { + ivSize = pk.aead.IvLength() + } else { + ivSize = pk.cipher.blockSize() + } + + if ivSize == 0 { return errors.UnsupportedError("unsupported cipher in private key: " + strconv.Itoa(int(pk.cipher))) } - pk.iv = make([]byte, blockSize) + pk.iv = make([]byte, ivSize) _, err = readFull(r, pk.iv) if err != nil { return } + if v5 && pk.s2kType == S2KAEAD { + pk.iv = pk.iv[:pk.aead.IvLength()] + } } var privateKeyData []byte @@ -230,7 +321,7 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) { return } } else { - privateKeyData, err = ioutil.ReadAll(r) + privateKeyData, err = io.ReadAll(r) if err != nil { return } @@ -239,16 +330,22 @@ func (pk *PrivateKey) parse(r io.Reader) (err error) { if len(privateKeyData) < 2 { return errors.StructuralError("truncated private key data") } - var sum uint16 - for i := 0; i < len(privateKeyData)-2; i++ { - sum += uint16(privateKeyData[i]) - } - if privateKeyData[len(privateKeyData)-2] != uint8(sum>>8) || - privateKeyData[len(privateKeyData)-1] != uint8(sum) { - return errors.StructuralError("private key checksum failure") + if pk.Version != 6 { + // checksum + var sum uint16 + for i := 0; i < len(privateKeyData)-2; i++ { + sum += uint16(privateKeyData[i]) + } + if privateKeyData[len(privateKeyData)-2] != uint8(sum>>8) || + privateKeyData[len(privateKeyData)-1] != uint8(sum) { + return errors.StructuralError("private key checksum failure") + } + privateKeyData = privateKeyData[:len(privateKeyData)-2] + return pk.parsePrivateKey(privateKeyData) + } else { + // No checksum + return pk.parsePrivateKey(privateKeyData) } - privateKeyData = privateKeyData[:len(privateKeyData)-2] - return pk.parsePrivateKey(privateKeyData) } pk.encryptedData = privateKeyData @@ -280,18 +377,59 @@ func (pk *PrivateKey) Serialize(w io.Writer) (err error) { optional := bytes.NewBuffer(nil) if pk.Encrypted || pk.Dummy() { - optional.Write([]byte{uint8(pk.cipher)}) - if err := pk.s2kParams.Serialize(optional); err != nil { + // [Optional] If string-to-key usage octet was 255, 254, or 253, + // a one-octet symmetric encryption algorithm. + if _, err = optional.Write([]byte{uint8(pk.cipher)}); err != nil { + return + } + // [Optional] If string-to-key usage octet was 253, + // a one-octet AEAD algorithm. + if pk.s2kType == S2KAEAD { + if _, err = optional.Write([]byte{uint8(pk.aead)}); err != nil { + return + } + } + + s2kBuffer := bytes.NewBuffer(nil) + if err := pk.s2kParams.Serialize(s2kBuffer); err != nil { return err } + // [Optional] Only for a version 6 packet, and if string-to-key + // usage octet was 255, 254, or 253, an one-octet + // count of the following field. + if pk.Version == 6 { + if _, err = optional.Write([]byte{uint8(s2kBuffer.Len())}); err != nil { + return + } + } + // [Optional] If string-to-key usage octet was 255, 254, or 253, + // a string-to-key (S2K) specifier. The length of the string-to-key specifier + // depends on its type + if _, err = io.Copy(optional, s2kBuffer); err != nil { + return + } + + // IV if pk.Encrypted { - optional.Write(pk.iv) + if _, err = optional.Write(pk.iv); err != nil { + return + } + if pk.Version == 5 && pk.s2kType == S2KAEAD { + // Add padding for version 5 + padding := make([]byte, pk.cipher.blockSize()-len(pk.iv)) + if _, err = optional.Write(padding); err != nil { + return + } + } } } - if pk.Version == 5 { + if pk.Version == 5 || (pk.Version == 6 && pk.s2kType != S2KNON) { contents.Write([]byte{uint8(optional.Len())}) } - io.Copy(contents, optional) + + if _, err := io.Copy(contents, optional); err != nil { + return err + } if !pk.Dummy() { l := 0 @@ -303,8 +441,10 @@ func (pk *PrivateKey) Serialize(w io.Writer) (err error) { return err } l = buf.Len() - checksum := mod64kHash(buf.Bytes()) - buf.Write([]byte{byte(checksum >> 8), byte(checksum)}) + if pk.Version != 6 { + checksum := mod64kHash(buf.Bytes()) + buf.Write([]byte{byte(checksum >> 8), byte(checksum)}) + } priv = buf.Bytes() } else { priv, l = pk.encryptedData, len(pk.encryptedData) @@ -370,6 +510,26 @@ func serializeECDHPrivateKey(w io.Writer, priv *ecdh.PrivateKey) error { return err } +func serializeX25519PrivateKey(w io.Writer, priv *x25519.PrivateKey) error { + _, err := w.Write(priv.Secret) + return err +} + +func serializeX448PrivateKey(w io.Writer, priv *x448.PrivateKey) error { + _, err := w.Write(priv.Secret) + return err +} + +func serializeEd25519PrivateKey(w io.Writer, priv *ed25519.PrivateKey) error { + _, err := w.Write(priv.MarshalByteSecret()) + return err +} + +func serializeEd448PrivateKey(w io.Writer, priv *ed448.PrivateKey) error { + _, err := w.Write(priv.MarshalByteSecret()) + return err +} + // decrypt decrypts an encrypted private key using a decryption key. func (pk *PrivateKey) decrypt(decryptionKey []byte) error { if pk.Dummy() { @@ -378,37 +538,51 @@ func (pk *PrivateKey) decrypt(decryptionKey []byte) error { if !pk.Encrypted { return nil } - block := pk.cipher.new(decryptionKey) - cfb := cipher.NewCFBDecrypter(block, pk.iv) - - data := make([]byte, len(pk.encryptedData)) - cfb.XORKeyStream(data, pk.encryptedData) - - if pk.sha1Checksum { - if len(data) < sha1.Size { - return errors.StructuralError("truncated private key data") - } - h := sha1.New() - h.Write(data[:len(data)-sha1.Size]) - sum := h.Sum(nil) - if !bytes.Equal(sum, data[len(data)-sha1.Size:]) { - return errors.StructuralError("private key checksum failure") - } - data = data[:len(data)-sha1.Size] - } else { - if len(data) < 2 { - return errors.StructuralError("truncated private key data") + var data []byte + switch pk.s2kType { + case S2KAEAD: + aead := pk.aead.new(block) + additionalData, err := pk.additionalData() + if err != nil { + return err } - var sum uint16 - for i := 0; i < len(data)-2; i++ { - sum += uint16(data[i]) + // Decrypt the encrypted key material with aead + data, err = aead.Open(nil, pk.iv, pk.encryptedData, additionalData) + if err != nil { + return err } - if data[len(data)-2] != uint8(sum>>8) || - data[len(data)-1] != uint8(sum) { - return errors.StructuralError("private key checksum failure") + case S2KSHA1, S2KCHECKSUM: + cfb := cipher.NewCFBDecrypter(block, pk.iv) + data = make([]byte, len(pk.encryptedData)) + cfb.XORKeyStream(data, pk.encryptedData) + if pk.s2kType == S2KSHA1 { + if len(data) < sha1.Size { + return errors.StructuralError("truncated private key data") + } + h := sha1.New() + h.Write(data[:len(data)-sha1.Size]) + sum := h.Sum(nil) + if !bytes.Equal(sum, data[len(data)-sha1.Size:]) { + return errors.StructuralError("private key checksum failure") + } + data = data[:len(data)-sha1.Size] + } else { + if len(data) < 2 { + return errors.StructuralError("truncated private key data") + } + var sum uint16 + for i := 0; i < len(data)-2; i++ { + sum += uint16(data[i]) + } + if data[len(data)-2] != uint8(sum>>8) || + data[len(data)-1] != uint8(sum) { + return errors.StructuralError("private key checksum failure") + } + data = data[:len(data)-2] } - data = data[:len(data)-2] + default: + return errors.InvalidArgumentError("invalid s2k type") } err := pk.parsePrivateKey(data) @@ -424,7 +598,6 @@ func (pk *PrivateKey) decrypt(decryptionKey []byte) error { pk.s2k = nil pk.Encrypted = false pk.encryptedData = nil - return nil } @@ -440,6 +613,9 @@ func (pk *PrivateKey) decryptWithCache(passphrase []byte, keyCache *s2k.Cache) e if err != nil { return err } + if pk.s2kType == S2KAEAD { + key = pk.applyHKDF(key) + } return pk.decrypt(key) } @@ -454,11 +630,14 @@ func (pk *PrivateKey) Decrypt(passphrase []byte) error { key := make([]byte, pk.cipher.KeySize()) pk.s2k(key, passphrase) + if pk.s2kType == S2KAEAD { + key = pk.applyHKDF(key) + } return pk.decrypt(key) } // DecryptPrivateKeys decrypts all encrypted keys with the given config and passphrase. -// Avoids recomputation of similar s2k key derivations. +// Avoids recomputation of similar s2k key derivations. func DecryptPrivateKeys(keys []*PrivateKey, passphrase []byte) error { // Create a cache to avoid recomputation of key derviations for the same passphrase. s2kCache := &s2k.Cache{} @@ -474,7 +653,7 @@ func DecryptPrivateKeys(keys []*PrivateKey, passphrase []byte) error { } // encrypt encrypts an unencrypted private key. -func (pk *PrivateKey) encrypt(key []byte, params *s2k.Params, cipherFunction CipherFunction) error { +func (pk *PrivateKey) encrypt(key []byte, params *s2k.Params, s2kType S2KType, cipherFunction CipherFunction, rand io.Reader) error { if pk.Dummy() { return errors.ErrDummyPrivateKey("dummy key found") } @@ -485,7 +664,15 @@ func (pk *PrivateKey) encrypt(key []byte, params *s2k.Params, cipherFunction Cip if len(key) != cipherFunction.KeySize() { return errors.InvalidArgumentError("supplied encryption key has the wrong size") } - + + if params.Mode() == s2k.Argon2S2K && s2kType != S2KAEAD { + return errors.InvalidArgumentError("using Argon2 S2K without AEAD is not allowed") + } + if params.Mode() != s2k.Argon2S2K && params.Mode() != s2k.IteratedSaltedS2K && + params.Mode() != s2k.SaltedS2K { // only allowed for high-entropy passphrases + return errors.InvalidArgumentError("insecure S2K mode") + } + priv := bytes.NewBuffer(nil) err := pk.serializePrivateKey(priv) if err != nil { @@ -497,35 +684,53 @@ func (pk *PrivateKey) encrypt(key []byte, params *s2k.Params, cipherFunction Cip pk.s2k, err = pk.s2kParams.Function() if err != nil { return err - } + } privateKeyBytes := priv.Bytes() - pk.sha1Checksum = true + pk.s2kType = s2kType block := pk.cipher.new(key) - pk.iv = make([]byte, pk.cipher.blockSize()) - _, err = rand.Read(pk.iv) - if err != nil { - return err - } - cfb := cipher.NewCFBEncrypter(block, pk.iv) - - if pk.sha1Checksum { - pk.s2kType = S2KSHA1 - h := sha1.New() - h.Write(privateKeyBytes) - sum := h.Sum(nil) - privateKeyBytes = append(privateKeyBytes, sum...) - } else { - pk.s2kType = S2KCHECKSUM - var sum uint16 - for _, b := range privateKeyBytes { - sum += uint16(b) + switch s2kType { + case S2KAEAD: + if pk.aead == 0 { + return errors.StructuralError("aead mode is not set on key") } - priv.Write([]byte{uint8(sum >> 8), uint8(sum)}) + aead := pk.aead.new(block) + additionalData, err := pk.additionalData() + if err != nil { + return err + } + pk.iv = make([]byte, aead.NonceSize()) + _, err = io.ReadFull(rand, pk.iv) + if err != nil { + return err + } + // Decrypt the encrypted key material with aead + pk.encryptedData = aead.Seal(nil, pk.iv, privateKeyBytes, additionalData) + case S2KSHA1, S2KCHECKSUM: + pk.iv = make([]byte, pk.cipher.blockSize()) + _, err = io.ReadFull(rand, pk.iv) + if err != nil { + return err + } + cfb := cipher.NewCFBEncrypter(block, pk.iv) + if s2kType == S2KSHA1 { + h := sha1.New() + h.Write(privateKeyBytes) + sum := h.Sum(nil) + privateKeyBytes = append(privateKeyBytes, sum...) + } else { + var sum uint16 + for _, b := range privateKeyBytes { + sum += uint16(b) + } + privateKeyBytes = append(privateKeyBytes, []byte{uint8(sum >> 8), uint8(sum)}...) + } + pk.encryptedData = make([]byte, len(privateKeyBytes)) + cfb.XORKeyStream(pk.encryptedData, privateKeyBytes) + default: + return errors.InvalidArgumentError("invalid s2k type for encryption") } - pk.encryptedData = make([]byte, len(privateKeyBytes)) - cfb.XORKeyStream(pk.encryptedData, privateKeyBytes) pk.Encrypted = true pk.PrivateKey = nil return err @@ -544,8 +749,15 @@ func (pk *PrivateKey) EncryptWithConfig(passphrase []byte, config *Config) error return err } s2k(key, passphrase) + s2kType := S2KSHA1 + if config.AEAD() != nil { + s2kType = S2KAEAD + pk.aead = config.AEAD().Mode() + pk.cipher = config.Cipher() + key = pk.applyHKDF(key) + } // Encrypt the private key with the derived encryption key. - return pk.encrypt(key, params, config.Cipher()) + return pk.encrypt(key, params, s2kType, config.Cipher(), config.Random()) } // EncryptPrivateKeys encrypts all unencrypted keys with the given config and passphrase. @@ -564,7 +776,16 @@ func EncryptPrivateKeys(keys []*PrivateKey, passphrase []byte, config *Config) e s2k(encryptionKey, passphrase) for _, key := range keys { if key != nil && !key.Dummy() && !key.Encrypted { - err = key.encrypt(encryptionKey, params, config.Cipher()) + s2kType := S2KSHA1 + if config.AEAD() != nil { + s2kType = S2KAEAD + key.aead = config.AEAD().Mode() + key.cipher = config.Cipher() + derivedKey := key.applyHKDF(encryptionKey) + err = key.encrypt(derivedKey, params, s2kType, config.Cipher(), config.Random()) + } else { + err = key.encrypt(encryptionKey, params, s2kType, config.Cipher(), config.Random()) + } if err != nil { return err } @@ -581,7 +802,7 @@ func (pk *PrivateKey) Encrypt(passphrase []byte) error { S2KMode: s2k.IteratedSaltedS2K, S2KCount: 65536, Hash: crypto.SHA256, - } , + }, DefaultCipher: CipherAES256, } return pk.EncryptWithConfig(passphrase, config) @@ -601,6 +822,14 @@ func (pk *PrivateKey) serializePrivateKey(w io.Writer) (err error) { err = serializeEdDSAPrivateKey(w, priv) case *ecdh.PrivateKey: err = serializeECDHPrivateKey(w, priv) + case *x25519.PrivateKey: + err = serializeX25519PrivateKey(w, priv) + case *x448.PrivateKey: + err = serializeX448PrivateKey(w, priv) + case *ed25519.PrivateKey: + err = serializeEd25519PrivateKey(w, priv) + case *ed448.PrivateKey: + err = serializeEd448PrivateKey(w, priv) default: err = errors.InvalidArgumentError("unknown private key type") } @@ -621,8 +850,18 @@ func (pk *PrivateKey) parsePrivateKey(data []byte) (err error) { return pk.parseECDHPrivateKey(data) case PubKeyAlgoEdDSA: return pk.parseEdDSAPrivateKey(data) + case PubKeyAlgoX25519: + return pk.parseX25519PrivateKey(data) + case PubKeyAlgoX448: + return pk.parseX448PrivateKey(data) + case PubKeyAlgoEd25519: + return pk.parseEd25519PrivateKey(data) + case PubKeyAlgoEd448: + return pk.parseEd448PrivateKey(data) + default: + err = errors.StructuralError("unknown private key type") + return } - panic("impossible") } func (pk *PrivateKey) parseRSAPrivateKey(data []byte) (err error) { @@ -743,6 +982,86 @@ func (pk *PrivateKey) parseECDHPrivateKey(data []byte) (err error) { return nil } +func (pk *PrivateKey) parseX25519PrivateKey(data []byte) (err error) { + publicKey := pk.PublicKey.PublicKey.(*x25519.PublicKey) + privateKey := x25519.NewPrivateKey(*publicKey) + privateKey.PublicKey = *publicKey + + privateKey.Secret = make([]byte, x25519.KeySize) + + if len(data) != x25519.KeySize { + err = errors.StructuralError("wrong x25519 key size") + return err + } + subtle.ConstantTimeCopy(1, privateKey.Secret, data) + if err = x25519.Validate(privateKey); err != nil { + return err + } + pk.PrivateKey = privateKey + return nil +} + +func (pk *PrivateKey) parseX448PrivateKey(data []byte) (err error) { + publicKey := pk.PublicKey.PublicKey.(*x448.PublicKey) + privateKey := x448.NewPrivateKey(*publicKey) + privateKey.PublicKey = *publicKey + + privateKey.Secret = make([]byte, x448.KeySize) + + if len(data) != x448.KeySize { + err = errors.StructuralError("wrong x448 key size") + return err + } + subtle.ConstantTimeCopy(1, privateKey.Secret, data) + if err = x448.Validate(privateKey); err != nil { + return err + } + pk.PrivateKey = privateKey + return nil +} + +func (pk *PrivateKey) parseEd25519PrivateKey(data []byte) (err error) { + publicKey := pk.PublicKey.PublicKey.(*ed25519.PublicKey) + privateKey := ed25519.NewPrivateKey(*publicKey) + privateKey.PublicKey = *publicKey + + if len(data) != ed25519.SeedSize { + err = errors.StructuralError("wrong ed25519 key size") + return err + } + err = privateKey.UnmarshalByteSecret(data) + if err != nil { + return err + } + err = ed25519.Validate(privateKey) + if err != nil { + return err + } + pk.PrivateKey = privateKey + return nil +} + +func (pk *PrivateKey) parseEd448PrivateKey(data []byte) (err error) { + publicKey := pk.PublicKey.PublicKey.(*ed448.PublicKey) + privateKey := ed448.NewPrivateKey(*publicKey) + privateKey.PublicKey = *publicKey + + if len(data) != ed448.SeedSize { + err = errors.StructuralError("wrong ed448 key size") + return err + } + err = privateKey.UnmarshalByteSecret(data) + if err != nil { + return err + } + err = ed448.Validate(privateKey) + if err != nil { + return err + } + pk.PrivateKey = privateKey + return nil +} + func (pk *PrivateKey) parseEdDSAPrivateKey(data []byte) (err error) { eddsaPub := pk.PublicKey.PublicKey.(*eddsa.PublicKey) eddsaPriv := eddsa.NewPrivateKey(*eddsaPub) @@ -767,6 +1086,41 @@ func (pk *PrivateKey) parseEdDSAPrivateKey(data []byte) (err error) { return nil } +func (pk *PrivateKey) additionalData() ([]byte, error) { + additionalData := bytes.NewBuffer(nil) + // Write additional data prefix based on packet type + var packetByte byte + if pk.PublicKey.IsSubkey { + packetByte = 0xc7 + } else { + packetByte = 0xc5 + } + // Write public key to additional data + _, err := additionalData.Write([]byte{packetByte}) + if err != nil { + return nil, err + } + err = pk.PublicKey.serializeWithoutHeaders(additionalData) + if err != nil { + return nil, err + } + return additionalData.Bytes(), nil +} + +func (pk *PrivateKey) applyHKDF(inputKey []byte) []byte { + var packetByte byte + if pk.PublicKey.IsSubkey { + packetByte = 0xc7 + } else { + packetByte = 0xc5 + } + associatedData := []byte{packetByte, byte(pk.Version), byte(pk.cipher), byte(pk.aead)} + hkdfReader := hkdf.New(sha256.New, inputKey, []byte{}, associatedData) + encryptionKey := make([]byte, pk.cipher.KeySize()) + _, _ = readFull(hkdfReader, encryptionKey) + return encryptionKey +} + func validateDSAParameters(priv *dsa.PrivateKey) error { p := priv.P // group prime q := priv.Q // subgroup order diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/public_key.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/public_key.go index 3402b8c140..f8da781bbe 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/public_key.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/public_key.go @@ -5,7 +5,6 @@ package packet import ( - "crypto" "crypto/dsa" "crypto/rsa" "crypto/sha1" @@ -21,23 +20,24 @@ import ( "github.com/ProtonMail/go-crypto/openpgp/ecdh" "github.com/ProtonMail/go-crypto/openpgp/ecdsa" + "github.com/ProtonMail/go-crypto/openpgp/ed25519" + "github.com/ProtonMail/go-crypto/openpgp/ed448" "github.com/ProtonMail/go-crypto/openpgp/eddsa" "github.com/ProtonMail/go-crypto/openpgp/elgamal" "github.com/ProtonMail/go-crypto/openpgp/errors" "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm" "github.com/ProtonMail/go-crypto/openpgp/internal/ecc" "github.com/ProtonMail/go-crypto/openpgp/internal/encoding" + "github.com/ProtonMail/go-crypto/openpgp/x25519" + "github.com/ProtonMail/go-crypto/openpgp/x448" ) -type kdfHashFunction byte -type kdfAlgorithm byte - // PublicKey represents an OpenPGP public key. See RFC 4880, section 5.5.2. type PublicKey struct { Version int CreationTime time.Time PubKeyAlgo PublicKeyAlgorithm - PublicKey interface{} // *rsa.PublicKey, *dsa.PublicKey, *ecdsa.PublicKey or *eddsa.PublicKey + PublicKey interface{} // *rsa.PublicKey, *dsa.PublicKey, *ecdsa.PublicKey or *eddsa.PublicKey, *x25519.PublicKey, *x448.PublicKey, *ed25519.PublicKey, *ed448.PublicKey Fingerprint []byte KeyId uint64 IsSubkey bool @@ -61,11 +61,19 @@ func (pk *PublicKey) UpgradeToV5() { pk.setFingerprintAndKeyId() } +// UpgradeToV6 updates the version of the key to v6, and updates all necessary +// fields. +func (pk *PublicKey) UpgradeToV6() error { + pk.Version = 6 + pk.setFingerprintAndKeyId() + return pk.checkV6Compatibility() +} + // signingKey provides a convenient abstraction over signature verification // for v3 and v4 public keys. type signingKey interface { SerializeForHash(io.Writer) error - SerializeSignaturePrefix(io.Writer) + SerializeSignaturePrefix(io.Writer) error serializeWithoutHeaders(io.Writer) error } @@ -174,6 +182,54 @@ func NewEdDSAPublicKey(creationTime time.Time, pub *eddsa.PublicKey) *PublicKey return pk } +func NewX25519PublicKey(creationTime time.Time, pub *x25519.PublicKey) *PublicKey { + pk := &PublicKey{ + Version: 4, + CreationTime: creationTime, + PubKeyAlgo: PubKeyAlgoX25519, + PublicKey: pub, + } + + pk.setFingerprintAndKeyId() + return pk +} + +func NewX448PublicKey(creationTime time.Time, pub *x448.PublicKey) *PublicKey { + pk := &PublicKey{ + Version: 4, + CreationTime: creationTime, + PubKeyAlgo: PubKeyAlgoX448, + PublicKey: pub, + } + + pk.setFingerprintAndKeyId() + return pk +} + +func NewEd25519PublicKey(creationTime time.Time, pub *ed25519.PublicKey) *PublicKey { + pk := &PublicKey{ + Version: 4, + CreationTime: creationTime, + PubKeyAlgo: PubKeyAlgoEd25519, + PublicKey: pub, + } + + pk.setFingerprintAndKeyId() + return pk +} + +func NewEd448PublicKey(creationTime time.Time, pub *ed448.PublicKey) *PublicKey { + pk := &PublicKey{ + Version: 4, + CreationTime: creationTime, + PubKeyAlgo: PubKeyAlgoEd448, + PublicKey: pub, + } + + pk.setFingerprintAndKeyId() + return pk +} + func (pk *PublicKey) parse(r io.Reader) (err error) { // RFC 4880, section 5.5.2 var buf [6]byte @@ -181,12 +237,19 @@ func (pk *PublicKey) parse(r io.Reader) (err error) { if err != nil { return } - if buf[0] != 4 && buf[0] != 5 { + + pk.Version = int(buf[0]) + if pk.Version != 4 && pk.Version != 5 && pk.Version != 6 { return errors.UnsupportedError("public key version " + strconv.Itoa(int(buf[0]))) } - pk.Version = int(buf[0]) - if pk.Version == 5 { + if V5Disabled && pk.Version == 5 { + return errors.UnsupportedError("support for parsing v5 entities is disabled; build with `-tags v5` if needed") + } + + if pk.Version >= 5 { + // Read the four-octet scalar octet count + // The count is not used in this implementation var n [4]byte _, err = readFull(r, n[:]) if err != nil { @@ -195,6 +258,7 @@ func (pk *PublicKey) parse(r io.Reader) (err error) { } pk.CreationTime = time.Unix(int64(uint32(buf[1])<<24|uint32(buf[2])<<16|uint32(buf[3])<<8|uint32(buf[4])), 0) pk.PubKeyAlgo = PublicKeyAlgorithm(buf[5]) + // Ignore four-ocet length switch pk.PubKeyAlgo { case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly: err = pk.parseRSA(r) @@ -208,6 +272,14 @@ func (pk *PublicKey) parse(r io.Reader) (err error) { err = pk.parseECDH(r) case PubKeyAlgoEdDSA: err = pk.parseEdDSA(r) + case PubKeyAlgoX25519: + err = pk.parseX25519(r) + case PubKeyAlgoX448: + err = pk.parseX448(r) + case PubKeyAlgoEd25519: + err = pk.parseEd25519(r) + case PubKeyAlgoEd448: + err = pk.parseEd448(r) default: err = errors.UnsupportedError("public key type: " + strconv.Itoa(int(pk.PubKeyAlgo))) } @@ -221,21 +293,44 @@ func (pk *PublicKey) parse(r io.Reader) (err error) { func (pk *PublicKey) setFingerprintAndKeyId() { // RFC 4880, section 12.2 - if pk.Version == 5 { + if pk.Version >= 5 { fingerprint := sha256.New() - pk.SerializeForHash(fingerprint) + if err := pk.SerializeForHash(fingerprint); err != nil { + // Should not happen for a hash. + panic(err) + } pk.Fingerprint = make([]byte, 32) copy(pk.Fingerprint, fingerprint.Sum(nil)) pk.KeyId = binary.BigEndian.Uint64(pk.Fingerprint[:8]) } else { fingerprint := sha1.New() - pk.SerializeForHash(fingerprint) + if err := pk.SerializeForHash(fingerprint); err != nil { + // Should not happen for a hash. + panic(err) + } pk.Fingerprint = make([]byte, 20) copy(pk.Fingerprint, fingerprint.Sum(nil)) pk.KeyId = binary.BigEndian.Uint64(pk.Fingerprint[12:20]) } } +func (pk *PublicKey) checkV6Compatibility() error { + // Implementations MUST NOT accept or generate version 6 key material using the deprecated OIDs. + switch pk.PubKeyAlgo { + case PubKeyAlgoECDH: + curveInfo := ecc.FindByOid(pk.oid) + if curveInfo == nil { + return errors.UnsupportedError(fmt.Sprintf("unknown oid: %x", pk.oid)) + } + if curveInfo.GenName == ecc.Curve25519GenName { + return errors.StructuralError("cannot generate v6 key with deprecated OID: Curve25519Legacy") + } + case PubKeyAlgoEdDSA: + return errors.StructuralError("cannot generate v6 key with deprecated algorithm: EdDSALegacy") + } + return nil +} + // parseRSA parses RSA public key material from the given Reader. See RFC 4880, // section 5.5.2. func (pk *PublicKey) parseRSA(r io.Reader) (err error) { @@ -324,16 +419,17 @@ func (pk *PublicKey) parseECDSA(r io.Reader) (err error) { if _, err = pk.oid.ReadFrom(r); err != nil { return } - pk.p = new(encoding.MPI) - if _, err = pk.p.ReadFrom(r); err != nil { - return - } curveInfo := ecc.FindByOid(pk.oid) if curveInfo == nil { return errors.UnsupportedError(fmt.Sprintf("unknown oid: %x", pk.oid)) } + pk.p = new(encoding.MPI) + if _, err = pk.p.ReadFrom(r); err != nil { + return + } + c, ok := curveInfo.Curve.(ecc.ECDSACurve) if !ok { return errors.UnsupportedError(fmt.Sprintf("unsupported oid: %x", pk.oid)) @@ -353,6 +449,17 @@ func (pk *PublicKey) parseECDH(r io.Reader) (err error) { if _, err = pk.oid.ReadFrom(r); err != nil { return } + + curveInfo := ecc.FindByOid(pk.oid) + if curveInfo == nil { + return errors.UnsupportedError(fmt.Sprintf("unknown oid: %x", pk.oid)) + } + + if pk.Version == 6 && curveInfo.GenName == ecc.Curve25519GenName { + // Implementations MUST NOT accept or generate version 6 key material using the deprecated OIDs. + return errors.StructuralError("cannot read v6 key with deprecated OID: Curve25519Legacy") + } + pk.p = new(encoding.MPI) if _, err = pk.p.ReadFrom(r); err != nil { return @@ -362,12 +469,6 @@ func (pk *PublicKey) parseECDH(r io.Reader) (err error) { return } - curveInfo := ecc.FindByOid(pk.oid) - - if curveInfo == nil { - return errors.UnsupportedError(fmt.Sprintf("unknown oid: %x", pk.oid)) - } - c, ok := curveInfo.Curve.(ecc.ECDHCurve) if !ok { return errors.UnsupportedError(fmt.Sprintf("unsupported oid: %x", pk.oid)) @@ -396,10 +497,16 @@ func (pk *PublicKey) parseECDH(r io.Reader) (err error) { } func (pk *PublicKey) parseEdDSA(r io.Reader) (err error) { + if pk.Version == 6 { + // Implementations MUST NOT accept or generate version 6 key material using the deprecated OIDs. + return errors.StructuralError("cannot generate v6 key with deprecated algorithm: EdDSALegacy") + } + pk.oid = new(encoding.OID) if _, err = pk.oid.ReadFrom(r); err != nil { return } + curveInfo := ecc.FindByOid(pk.oid) if curveInfo == nil { return errors.UnsupportedError(fmt.Sprintf("unknown oid: %x", pk.oid)) @@ -435,75 +542,145 @@ func (pk *PublicKey) parseEdDSA(r io.Reader) (err error) { return } +func (pk *PublicKey) parseX25519(r io.Reader) (err error) { + point := make([]byte, x25519.KeySize) + _, err = io.ReadFull(r, point) + if err != nil { + return + } + pub := &x25519.PublicKey{ + Point: point, + } + pk.PublicKey = pub + return +} + +func (pk *PublicKey) parseX448(r io.Reader) (err error) { + point := make([]byte, x448.KeySize) + _, err = io.ReadFull(r, point) + if err != nil { + return + } + pub := &x448.PublicKey{ + Point: point, + } + pk.PublicKey = pub + return +} + +func (pk *PublicKey) parseEd25519(r io.Reader) (err error) { + point := make([]byte, ed25519.PublicKeySize) + _, err = io.ReadFull(r, point) + if err != nil { + return + } + pub := &ed25519.PublicKey{ + Point: point, + } + pk.PublicKey = pub + return +} + +func (pk *PublicKey) parseEd448(r io.Reader) (err error) { + point := make([]byte, ed448.PublicKeySize) + _, err = io.ReadFull(r, point) + if err != nil { + return + } + pub := &ed448.PublicKey{ + Point: point, + } + pk.PublicKey = pub + return +} + // SerializeForHash serializes the PublicKey to w with the special packet // header format needed for hashing. func (pk *PublicKey) SerializeForHash(w io.Writer) error { - pk.SerializeSignaturePrefix(w) + if err := pk.SerializeSignaturePrefix(w); err != nil { + return err + } return pk.serializeWithoutHeaders(w) } // SerializeSignaturePrefix writes the prefix for this public key to the given Writer. // The prefix is used when calculating a signature over this public key. See // RFC 4880, section 5.2.4. -func (pk *PublicKey) SerializeSignaturePrefix(w io.Writer) { +func (pk *PublicKey) SerializeSignaturePrefix(w io.Writer) error { var pLength = pk.algorithmSpecificByteCount() - if pk.Version == 5 { - pLength += 10 // version, timestamp (4), algorithm, key octet count (4). - w.Write([]byte{ - 0x9A, + // version, timestamp, algorithm + pLength += versionSize + timestampSize + algorithmSize + if pk.Version >= 5 { + // key octet count (4). + pLength += 4 + _, err := w.Write([]byte{ + // When a v4 signature is made over a key, the hash data starts with the octet 0x99, followed by a two-octet length + // of the key, and then the body of the key packet. When a v6 signature is made over a key, the hash data starts + // with the salt, then octet 0x9B, followed by a four-octet length of the key, and then the body of the key packet. + 0x95 + byte(pk.Version), byte(pLength >> 24), byte(pLength >> 16), byte(pLength >> 8), byte(pLength), }) - return + return err } - pLength += 6 - w.Write([]byte{0x99, byte(pLength >> 8), byte(pLength)}) + if _, err := w.Write([]byte{0x99, byte(pLength >> 8), byte(pLength)}); err != nil { + return err + } + return nil } func (pk *PublicKey) Serialize(w io.Writer) (err error) { - length := 6 // 6 byte header + length := uint32(versionSize + timestampSize + algorithmSize) // 6 byte header length += pk.algorithmSpecificByteCount() - if pk.Version == 5 { + if pk.Version >= 5 { length += 4 // octet key count } packetType := packetTypePublicKey if pk.IsSubkey { packetType = packetTypePublicSubkey } - err = serializeHeader(w, packetType, length) + err = serializeHeader(w, packetType, int(length)) if err != nil { return } return pk.serializeWithoutHeaders(w) } -func (pk *PublicKey) algorithmSpecificByteCount() int { - length := 0 +func (pk *PublicKey) algorithmSpecificByteCount() uint32 { + length := uint32(0) switch pk.PubKeyAlgo { case PubKeyAlgoRSA, PubKeyAlgoRSAEncryptOnly, PubKeyAlgoRSASignOnly: - length += int(pk.n.EncodedLength()) - length += int(pk.e.EncodedLength()) + length += uint32(pk.n.EncodedLength()) + length += uint32(pk.e.EncodedLength()) case PubKeyAlgoDSA: - length += int(pk.p.EncodedLength()) - length += int(pk.q.EncodedLength()) - length += int(pk.g.EncodedLength()) - length += int(pk.y.EncodedLength()) + length += uint32(pk.p.EncodedLength()) + length += uint32(pk.q.EncodedLength()) + length += uint32(pk.g.EncodedLength()) + length += uint32(pk.y.EncodedLength()) case PubKeyAlgoElGamal: - length += int(pk.p.EncodedLength()) - length += int(pk.g.EncodedLength()) - length += int(pk.y.EncodedLength()) + length += uint32(pk.p.EncodedLength()) + length += uint32(pk.g.EncodedLength()) + length += uint32(pk.y.EncodedLength()) case PubKeyAlgoECDSA: - length += int(pk.oid.EncodedLength()) - length += int(pk.p.EncodedLength()) + length += uint32(pk.oid.EncodedLength()) + length += uint32(pk.p.EncodedLength()) case PubKeyAlgoECDH: - length += int(pk.oid.EncodedLength()) - length += int(pk.p.EncodedLength()) - length += int(pk.kdf.EncodedLength()) + length += uint32(pk.oid.EncodedLength()) + length += uint32(pk.p.EncodedLength()) + length += uint32(pk.kdf.EncodedLength()) case PubKeyAlgoEdDSA: - length += int(pk.oid.EncodedLength()) - length += int(pk.p.EncodedLength()) + length += uint32(pk.oid.EncodedLength()) + length += uint32(pk.p.EncodedLength()) + case PubKeyAlgoX25519: + length += x25519.KeySize + case PubKeyAlgoX448: + length += x448.KeySize + case PubKeyAlgoEd25519: + length += ed25519.PublicKeySize + case PubKeyAlgoEd448: + length += ed448.PublicKeySize default: panic("unknown public key algorithm") } @@ -522,7 +699,7 @@ func (pk *PublicKey) serializeWithoutHeaders(w io.Writer) (err error) { return } - if pk.Version == 5 { + if pk.Version >= 5 { n := pk.algorithmSpecificByteCount() if _, err = w.Write([]byte{ byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n), @@ -580,6 +757,22 @@ func (pk *PublicKey) serializeWithoutHeaders(w io.Writer) (err error) { } _, err = w.Write(pk.p.EncodedBytes()) return + case PubKeyAlgoX25519: + publicKey := pk.PublicKey.(*x25519.PublicKey) + _, err = w.Write(publicKey.Point) + return + case PubKeyAlgoX448: + publicKey := pk.PublicKey.(*x448.PublicKey) + _, err = w.Write(publicKey.Point) + return + case PubKeyAlgoEd25519: + publicKey := pk.PublicKey.(*ed25519.PublicKey) + _, err = w.Write(publicKey.Point) + return + case PubKeyAlgoEd448: + publicKey := pk.PublicKey.(*ed448.PublicKey) + _, err = w.Write(publicKey.Point) + return } return errors.InvalidArgumentError("bad public-key algorithm") } @@ -589,6 +782,20 @@ func (pk *PublicKey) CanSign() bool { return pk.PubKeyAlgo != PubKeyAlgoRSAEncryptOnly && pk.PubKeyAlgo != PubKeyAlgoElGamal && pk.PubKeyAlgo != PubKeyAlgoECDH } +// VerifyHashTag returns nil iff sig appears to be a plausible signature of the data +// hashed into signed, based solely on its HashTag. signed is mutated by this call. +func VerifyHashTag(signed hash.Hash, sig *Signature) (err error) { + if sig.Version == 5 && (sig.SigType == 0x00 || sig.SigType == 0x01) { + sig.AddMetadataToHashSuffix() + } + signed.Write(sig.HashSuffix) + hashBytes := signed.Sum(nil) + if hashBytes[0] != sig.HashTag[0] || hashBytes[1] != sig.HashTag[1] { + return errors.SignatureError("hash tag doesn't match") + } + return nil +} + // VerifySignature returns nil iff sig is a valid signature, made by this // public key, of the data hashed into signed. signed is mutated by this call. func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err error) { @@ -600,7 +807,8 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro } signed.Write(sig.HashSuffix) hashBytes := signed.Sum(nil) - if sig.Version == 5 && (hashBytes[0] != sig.HashTag[0] || hashBytes[1] != sig.HashTag[1]) { + // see discussion https://github.com/ProtonMail/go-crypto/issues/107 + if sig.Version >= 5 && (hashBytes[0] != sig.HashTag[0] || hashBytes[1] != sig.HashTag[1]) { return errors.SignatureError("hash tag doesn't match") } @@ -639,6 +847,18 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro return errors.SignatureError("EdDSA verification failure") } return nil + case PubKeyAlgoEd25519: + ed25519PublicKey := pk.PublicKey.(*ed25519.PublicKey) + if !ed25519.Verify(ed25519PublicKey, hashBytes, sig.EdSig) { + return errors.SignatureError("Ed25519 verification failure") + } + return nil + case PubKeyAlgoEd448: + ed448PublicKey := pk.PublicKey.(*ed448.PublicKey) + if !ed448.Verify(ed448PublicKey, hashBytes, sig.EdSig) { + return errors.SignatureError("ed448 verification failure") + } + return nil default: return errors.SignatureError("Unsupported public key algorithm used in signature") } @@ -646,11 +866,8 @@ func (pk *PublicKey) VerifySignature(signed hash.Hash, sig *Signature) (err erro // keySignatureHash returns a Hash of the message that needs to be signed for // pk to assert a subkey relationship to signed. -func keySignatureHash(pk, signed signingKey, hashFunc crypto.Hash) (h hash.Hash, err error) { - if !hashFunc.Available() { - return nil, errors.UnsupportedError("hash function") - } - h = hashFunc.New() +func keySignatureHash(pk, signed signingKey, hashFunc hash.Hash) (h hash.Hash, err error) { + h = hashFunc // RFC 4880, section 5.2.4 err = pk.SerializeForHash(h) @@ -662,10 +879,28 @@ func keySignatureHash(pk, signed signingKey, hashFunc crypto.Hash) (h hash.Hash, return } +// VerifyKeyHashTag returns nil iff sig appears to be a plausible signature over this +// primary key and subkey, based solely on its HashTag. +func (pk *PublicKey) VerifyKeyHashTag(signed *PublicKey, sig *Signature) error { + preparedHash, err := sig.PrepareVerify() + if err != nil { + return err + } + h, err := keySignatureHash(pk, signed, preparedHash) + if err != nil { + return err + } + return VerifyHashTag(h, sig) +} + // VerifyKeySignature returns nil iff sig is a valid signature, made by this // public key, of signed. func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) error { - h, err := keySignatureHash(pk, signed, sig.Hash) + preparedHash, err := sig.PrepareVerify() + if err != nil { + return err + } + h, err := keySignatureHash(pk, signed, preparedHash) if err != nil { return err } @@ -679,10 +914,14 @@ func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) error if sig.EmbeddedSignature == nil { return errors.StructuralError("signing subkey is missing cross-signature") } + preparedHashEmbedded, err := sig.EmbeddedSignature.PrepareVerify() + if err != nil { + return err + } // Verify the cross-signature. This is calculated over the same // data as the main signature, so we cannot just recursively // call signed.VerifyKeySignature(...) - if h, err = keySignatureHash(pk, signed, sig.EmbeddedSignature.Hash); err != nil { + if h, err = keySignatureHash(pk, signed, preparedHashEmbedded); err != nil { return errors.StructuralError("error while hashing for cross-signature: " + err.Error()) } if err := signed.VerifySignature(h, sig.EmbeddedSignature); err != nil { @@ -693,32 +932,44 @@ func (pk *PublicKey) VerifyKeySignature(signed *PublicKey, sig *Signature) error return nil } -func keyRevocationHash(pk signingKey, hashFunc crypto.Hash) (h hash.Hash, err error) { - if !hashFunc.Available() { - return nil, errors.UnsupportedError("hash function") - } - h = hashFunc.New() - - // RFC 4880, section 5.2.4 - err = pk.SerializeForHash(h) +func keyRevocationHash(pk signingKey, hashFunc hash.Hash) (err error) { + return pk.SerializeForHash(hashFunc) +} - return +// VerifyRevocationHashTag returns nil iff sig appears to be a plausible signature +// over this public key, based solely on its HashTag. +func (pk *PublicKey) VerifyRevocationHashTag(sig *Signature) (err error) { + preparedHash, err := sig.PrepareVerify() + if err != nil { + return err + } + if err = keyRevocationHash(pk, preparedHash); err != nil { + return err + } + return VerifyHashTag(preparedHash, sig) } // VerifyRevocationSignature returns nil iff sig is a valid signature, made by this // public key. func (pk *PublicKey) VerifyRevocationSignature(sig *Signature) (err error) { - h, err := keyRevocationHash(pk, sig.Hash) + preparedHash, err := sig.PrepareVerify() if err != nil { return err } - return pk.VerifySignature(h, sig) + if err = keyRevocationHash(pk, preparedHash); err != nil { + return err + } + return pk.VerifySignature(preparedHash, sig) } // VerifySubkeyRevocationSignature returns nil iff sig is a valid subkey revocation signature, // made by this public key, of signed. func (pk *PublicKey) VerifySubkeyRevocationSignature(sig *Signature, signed *PublicKey) (err error) { - h, err := keySignatureHash(pk, signed, sig.Hash) + preparedHash, err := sig.PrepareVerify() + if err != nil { + return err + } + h, err := keySignatureHash(pk, signed, preparedHash) if err != nil { return err } @@ -727,15 +978,15 @@ func (pk *PublicKey) VerifySubkeyRevocationSignature(sig *Signature, signed *Pub // userIdSignatureHash returns a Hash of the message that needs to be signed // to assert that pk is a valid key for id. -func userIdSignatureHash(id string, pk *PublicKey, hashFunc crypto.Hash) (h hash.Hash, err error) { - if !hashFunc.Available() { - return nil, errors.UnsupportedError("hash function") - } - h = hashFunc.New() +func userIdSignatureHash(id string, pk *PublicKey, h hash.Hash) (err error) { // RFC 4880, section 5.2.4 - pk.SerializeSignaturePrefix(h) - pk.serializeWithoutHeaders(h) + if err := pk.SerializeSignaturePrefix(h); err != nil { + return err + } + if err := pk.serializeWithoutHeaders(h); err != nil { + return err + } var buf [5]byte buf[0] = 0xb4 @@ -746,16 +997,51 @@ func userIdSignatureHash(id string, pk *PublicKey, hashFunc crypto.Hash) (h hash h.Write(buf[:]) h.Write([]byte(id)) - return + return nil +} + +// directKeySignatureHash returns a Hash of the message that needs to be signed. +func directKeySignatureHash(pk *PublicKey, h hash.Hash) (err error) { + return pk.SerializeForHash(h) +} + +// VerifyUserIdHashTag returns nil iff sig appears to be a plausible signature over this +// public key and UserId, based solely on its HashTag +func (pk *PublicKey) VerifyUserIdHashTag(id string, sig *Signature) (err error) { + preparedHash, err := sig.PrepareVerify() + if err != nil { + return err + } + err = userIdSignatureHash(id, pk, preparedHash) + if err != nil { + return err + } + return VerifyHashTag(preparedHash, sig) } // VerifyUserIdSignature returns nil iff sig is a valid signature, made by this // public key, that id is the identity of pub. func (pk *PublicKey) VerifyUserIdSignature(id string, pub *PublicKey, sig *Signature) (err error) { - h, err := userIdSignatureHash(id, pub, sig.Hash) + h, err := sig.PrepareVerify() + if err != nil { + return err + } + if err := userIdSignatureHash(id, pub, h); err != nil { + return err + } + return pk.VerifySignature(h, sig) +} + +// VerifyDirectKeySignature returns nil iff sig is a valid signature, made by this +// public key. +func (pk *PublicKey) VerifyDirectKeySignature(sig *Signature) (err error) { + h, err := sig.PrepareVerify() if err != nil { return err } + if err := directKeySignatureHash(pk, h); err != nil { + return err + } return pk.VerifySignature(h, sig) } @@ -786,21 +1072,49 @@ func (pk *PublicKey) BitLength() (bitLength uint16, err error) { bitLength = pk.p.BitLength() case PubKeyAlgoEdDSA: bitLength = pk.p.BitLength() + case PubKeyAlgoX25519: + bitLength = x25519.KeySize * 8 + case PubKeyAlgoX448: + bitLength = x448.KeySize * 8 + case PubKeyAlgoEd25519: + bitLength = ed25519.PublicKeySize * 8 + case PubKeyAlgoEd448: + bitLength = ed448.PublicKeySize * 8 default: err = errors.InvalidArgumentError("bad public-key algorithm") } return } +// Curve returns the used elliptic curve of this public key. +// Returns an error if no elliptic curve is used. +func (pk *PublicKey) Curve() (curve Curve, err error) { + switch pk.PubKeyAlgo { + case PubKeyAlgoECDSA, PubKeyAlgoECDH, PubKeyAlgoEdDSA: + curveInfo := ecc.FindByOid(pk.oid) + if curveInfo == nil { + return "", errors.UnsupportedError(fmt.Sprintf("unknown oid: %x", pk.oid)) + } + curve = Curve(curveInfo.GenName) + case PubKeyAlgoEd25519, PubKeyAlgoX25519: + curve = Curve25519 + case PubKeyAlgoEd448, PubKeyAlgoX448: + curve = Curve448 + default: + err = errors.InvalidArgumentError("public key does not operate with an elliptic curve") + } + return +} + // KeyExpired returns whether sig is a self-signature of a key that has // expired or is created in the future. func (pk *PublicKey) KeyExpired(sig *Signature, currentTime time.Time) bool { - if pk.CreationTime.After(currentTime) { + if pk.CreationTime.Unix() > currentTime.Unix() { return true } if sig.KeyLifetimeSecs == nil || *sig.KeyLifetimeSecs == 0 { return false } expiry := pk.CreationTime.Add(time.Duration(*sig.KeyLifetimeSecs) * time.Second) - return currentTime.After(expiry) + return currentTime.Unix() > expiry.Unix() } diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/reader.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/reader.go index 10215fe5f2..dd84092392 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/reader.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/reader.go @@ -10,6 +10,12 @@ import ( "github.com/ProtonMail/go-crypto/openpgp/errors" ) +type PacketReader interface { + Next() (p Packet, err error) + Push(reader io.Reader) (err error) + Unread(p Packet) +} + // Reader reads packets from an io.Reader and allows packets to be 'unread' so // that they result from the next call to Next. type Reader struct { @@ -26,37 +32,81 @@ type Reader struct { const maxReaders = 32 // Next returns the most recently unread Packet, or reads another packet from -// the top-most io.Reader. Unknown packet types are skipped. +// the top-most io.Reader. Unknown/unsupported/Marker packet types are skipped. func (r *Reader) Next() (p Packet, err error) { + for { + p, err := r.read() + if err == io.EOF { + break + } else if err != nil { + if _, ok := err.(errors.UnknownPacketTypeError); ok { + continue + } + if _, ok := err.(errors.UnsupportedError); ok { + switch p.(type) { + case *SymmetricallyEncrypted, *AEADEncrypted, *Compressed, *LiteralData: + return nil, err + } + continue + } + return nil, err + } else { + //A marker packet MUST be ignored when received + switch p.(type) { + case *Marker: + continue + } + return p, nil + } + } + return nil, io.EOF +} + +// Next returns the most recently unread Packet, or reads another packet from +// the top-most io.Reader. Unknown/Marker packet types are skipped while unsupported +// packets are returned as UnsupportedPacket type. +func (r *Reader) NextWithUnsupported() (p Packet, err error) { + for { + p, err = r.read() + if err == io.EOF { + break + } else if err != nil { + if _, ok := err.(errors.UnknownPacketTypeError); ok { + continue + } + if casteErr, ok := err.(errors.UnsupportedError); ok { + return &UnsupportedPacket{ + IncompletePacket: p, + Error: casteErr, + }, nil + } + return + } else { + //A marker packet MUST be ignored when received + switch p.(type) { + case *Marker: + continue + } + return + } + } + return nil, io.EOF +} + +func (r *Reader) read() (p Packet, err error) { if len(r.q) > 0 { p = r.q[len(r.q)-1] r.q = r.q[:len(r.q)-1] return } - for len(r.readers) > 0 { p, err = Read(r.readers[len(r.readers)-1]) - if err == nil { - return - } if err == io.EOF { r.readers = r.readers[:len(r.readers)-1] continue } - // TODO: Add strict mode that rejects unknown packets, instead of ignoring them. - if _, ok := err.(errors.UnknownPacketTypeError); ok { - continue - } - if _, ok := err.(errors.UnsupportedError); ok { - switch p.(type) { - case *SymmetricallyEncrypted, *AEADEncrypted, *Compressed, *LiteralData: - return nil, err - } - continue - } - return nil, err + return p, err } - return nil, io.EOF } @@ -84,3 +134,76 @@ func NewReader(r io.Reader) *Reader { readers: []io.Reader{r}, } } + +// CheckReader is similar to Reader but additionally +// uses the pushdown automata to verify the read packet sequence. +type CheckReader struct { + Reader + verifier *SequenceVerifier + fullyRead bool +} + +// Next returns the most recently unread Packet, or reads another packet from +// the top-most io.Reader. Unknown packet types are skipped. +// If the read packet sequence does not conform to the packet composition +// rules in rfc4880, it returns an error. +func (r *CheckReader) Next() (p Packet, err error) { + if r.fullyRead { + return nil, io.EOF + } + if len(r.q) > 0 { + p = r.q[len(r.q)-1] + r.q = r.q[:len(r.q)-1] + return + } + var errMsg error + for len(r.readers) > 0 { + p, errMsg, err = ReadWithCheck(r.readers[len(r.readers)-1], r.verifier) + if errMsg != nil { + err = errMsg + return + } + if err == nil { + return + } + if err == io.EOF { + r.readers = r.readers[:len(r.readers)-1] + continue + } + //A marker packet MUST be ignored when received + switch p.(type) { + case *Marker: + continue + } + if _, ok := err.(errors.UnknownPacketTypeError); ok { + continue + } + if _, ok := err.(errors.UnsupportedError); ok { + switch p.(type) { + case *SymmetricallyEncrypted, *AEADEncrypted, *Compressed, *LiteralData: + return nil, err + } + continue + } + return nil, err + } + if errMsg = r.verifier.Next(EOSSymbol); errMsg != nil { + return nil, errMsg + } + if errMsg = r.verifier.AssertValid(); errMsg != nil { + return nil, errMsg + } + r.fullyRead = true + return nil, io.EOF +} + +func NewCheckReader(r io.Reader) *CheckReader { + return &CheckReader{ + Reader: Reader{ + q: nil, + readers: []io.Reader{r}, + }, + verifier: NewSequenceVerifier(), + fullyRead: false, + } +} diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/recipient.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/recipient.go new file mode 100644 index 0000000000..fb2e362e4a --- /dev/null +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/recipient.go @@ -0,0 +1,15 @@ +package packet + +// Recipient type represents a Intended Recipient Fingerprint subpacket +// See https://datatracker.ietf.org/doc/html/draft-ietf-openpgp-crypto-refresh#name-intended-recipient-fingerpr +type Recipient struct { + KeyVersion int + Fingerprint []byte +} + +func (r *Recipient) Serialize() []byte { + packet := make([]byte, len(r.Fingerprint)+1) + packet[0] = byte(r.KeyVersion) + copy(packet[1:], r.Fingerprint) + return packet +} diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/signature.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/signature.go index 80d0bb98e0..3a4b366d87 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/signature.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/signature.go @@ -8,13 +8,17 @@ import ( "bytes" "crypto" "crypto/dsa" + "encoding/asn1" "encoding/binary" "hash" "io" + "math/big" "strconv" "time" "github.com/ProtonMail/go-crypto/openpgp/ecdsa" + "github.com/ProtonMail/go-crypto/openpgp/ed25519" + "github.com/ProtonMail/go-crypto/openpgp/ed448" "github.com/ProtonMail/go-crypto/openpgp/eddsa" "github.com/ProtonMail/go-crypto/openpgp/errors" "github.com/ProtonMail/go-crypto/openpgp/internal/algorithm" @@ -22,7 +26,8 @@ import ( ) const ( - // See RFC 4880, section 5.2.3.21 for details. + // First octet of key flags. + // See RFC 9580, section 5.2.3.29 for details. KeyFlagCertify = 1 << iota KeyFlagSign KeyFlagEncryptCommunications @@ -33,12 +38,30 @@ const ( KeyFlagGroupKey ) -// Signature represents a signature. See RFC 4880, section 5.2. +const ( + // First octet of keyserver preference flags. + // See RFC 9580, section 5.2.3.25 for details. + _ = 1 << iota + _ + _ + _ + _ + _ + _ + KeyserverPrefNoModify +) + +const SaltNotationName = "salt@notations.openpgpjs.org" + +// Signature represents a signature. See RFC 9580, section 5.2. type Signature struct { Version int SigType SignatureType PubKeyAlgo PublicKeyAlgorithm Hash crypto.Hash + // salt contains a random salt value for v6 signatures + // See RFC 9580 Section 5.2.4. + salt []byte // HashSuffix is extra data that is hashed in after the signed data. HashSuffix []byte @@ -57,6 +80,7 @@ type Signature struct { DSASigR, DSASigS encoding.Field ECDSASigR, ECDSASigS encoding.Field EdDSASigR, EdDSASigS encoding.Field + EdSig []byte // rawSubpackets contains the unparsed subpackets, in order. rawSubpackets []outputSubpacket @@ -72,31 +96,42 @@ type Signature struct { SignerUserId *string IsPrimaryId *bool Notations []*Notation + IntendedRecipients []*Recipient // TrustLevel and TrustAmount can be set by the signer to assert that // the key is not only valid but also trustworthy at the specified // level. - // See RFC 4880, section 5.2.3.13 for details. + // See RFC 9580, section 5.2.3.21 for details. TrustLevel TrustLevel TrustAmount TrustAmount // TrustRegularExpression can be used in conjunction with trust Signature // packets to limit the scope of the trust that is extended. - // See RFC 4880, section 5.2.3.14 for details. + // See RFC 9580, section 5.2.3.22 for details. TrustRegularExpression *string + // KeyserverPrefsValid is set if any keyserver preferences were given. See RFC 9580, section + // 5.2.3.25 for details. + KeyserverPrefsValid bool + KeyserverPrefNoModify bool + + // PreferredKeyserver can be set to a URI where the latest version of the + // key that this signature is made over can be found. See RFC 9580, section + // 5.2.3.26 for details. + PreferredKeyserver string + // PolicyURI can be set to the URI of a document that describes the - // policy under which the signature was issued. See RFC 4880, section - // 5.2.3.20 for details. + // policy under which the signature was issued. See RFC 9580, section + // 5.2.3.28 for details. PolicyURI string - // FlagsValid is set if any flags were given. See RFC 4880, section - // 5.2.3.21 for details. + // FlagsValid is set if any flags were given. See RFC 9580, section + // 5.2.3.29 for details. FlagsValid bool FlagCertify, FlagSign, FlagEncryptCommunications, FlagEncryptStorage, FlagSplitKey, FlagAuthenticate, FlagGroupKey bool // RevocationReason is set if this signature has been revoked. - // See RFC 4880, section 5.2.3.23 for details. + // See RFC 9580, section 5.2.3.31 for details. RevocationReason *ReasonForRevocation RevocationReasonText string @@ -113,26 +148,57 @@ type Signature struct { outSubpackets []outputSubpacket } +// VerifiableSignature internally keeps state if the +// the signature has been verified before. +type VerifiableSignature struct { + Valid *bool // nil if it has not been verified yet + Packet *Signature +} + +// NewVerifiableSig returns a struct of type VerifiableSignature referencing the input signature. +func NewVerifiableSig(signature *Signature) *VerifiableSignature { + return &VerifiableSignature{ + Packet: signature, + } +} + +// Salt returns the signature salt for v6 signatures. +func (sig *Signature) Salt() []byte { + if sig == nil { + return nil + } + return sig.salt +} + func (sig *Signature) parse(r io.Reader) (err error) { - // RFC 4880, section 5.2.3 - var buf [5]byte + // RFC 9580, section 5.2.3 + var buf [7]byte _, err = readFull(r, buf[:1]) if err != nil { return } - if buf[0] != 4 && buf[0] != 5 { + sig.Version = int(buf[0]) + if sig.Version != 4 && sig.Version != 5 && sig.Version != 6 { err = errors.UnsupportedError("signature packet version " + strconv.Itoa(int(buf[0]))) return } - sig.Version = int(buf[0]) - _, err = readFull(r, buf[:5]) + + if V5Disabled && sig.Version == 5 { + return errors.UnsupportedError("support for parsing v5 entities is disabled; build with `-tags v5` if needed") + } + + if sig.Version == 6 { + _, err = readFull(r, buf[:7]) + } else { + _, err = readFull(r, buf[:5]) + } if err != nil { return } sig.SigType = SignatureType(buf[0]) sig.PubKeyAlgo = PublicKeyAlgorithm(buf[1]) switch sig.PubKeyAlgo { - case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA, PubKeyAlgoECDSA, PubKeyAlgoEdDSA: + case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly, PubKeyAlgoDSA, PubKeyAlgoECDSA, PubKeyAlgoEdDSA, PubKeyAlgoEd25519, PubKeyAlgoEd448: default: err = errors.UnsupportedError("public key algorithm " + strconv.Itoa(int(sig.PubKeyAlgo))) return @@ -150,7 +216,17 @@ func (sig *Signature) parse(r io.Reader) (err error) { return errors.UnsupportedError("hash function " + strconv.Itoa(int(buf[2]))) } - hashedSubpacketsLength := int(buf[3])<<8 | int(buf[4]) + var hashedSubpacketsLength int + if sig.Version == 6 { + // For a v6 signature, a four-octet length is used. + hashedSubpacketsLength = + int(buf[3])<<24 | + int(buf[4])<<16 | + int(buf[5])<<8 | + int(buf[6]) + } else { + hashedSubpacketsLength = int(buf[3])<<8 | int(buf[4]) + } hashedSubpackets := make([]byte, hashedSubpacketsLength) _, err = readFull(r, hashedSubpackets) if err != nil { @@ -166,11 +242,21 @@ func (sig *Signature) parse(r io.Reader) (err error) { return } - _, err = readFull(r, buf[:2]) + if sig.Version == 6 { + _, err = readFull(r, buf[:4]) + } else { + _, err = readFull(r, buf[:2]) + } + if err != nil { return } - unhashedSubpacketsLength := int(buf[0])<<8 | int(buf[1]) + var unhashedSubpacketsLength uint32 + if sig.Version == 6 { + unhashedSubpacketsLength = uint32(buf[0])<<24 | uint32(buf[1])<<16 | uint32(buf[2])<<8 | uint32(buf[3]) + } else { + unhashedSubpacketsLength = uint32(buf[0])<<8 | uint32(buf[1]) + } unhashedSubpackets := make([]byte, unhashedSubpacketsLength) _, err = readFull(r, unhashedSubpackets) if err != nil { @@ -186,6 +272,30 @@ func (sig *Signature) parse(r io.Reader) (err error) { return } + if sig.Version == 6 { + // Only for v6 signatures, a variable-length field containing the salt + _, err = readFull(r, buf[:1]) + if err != nil { + return + } + saltLength := int(buf[0]) + var expectedSaltLength int + expectedSaltLength, err = SaltLengthForHash(sig.Hash) + if err != nil { + return + } + if saltLength != expectedSaltLength { + err = errors.StructuralError("unexpected salt size for the given hash algorithm") + return + } + salt := make([]byte, expectedSaltLength) + _, err = readFull(r, salt) + if err != nil { + return + } + sig.salt = salt + } + switch sig.PubKeyAlgo { case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly: sig.RSASignature = new(encoding.MPI) @@ -216,6 +326,16 @@ func (sig *Signature) parse(r io.Reader) (err error) { if _, err = sig.EdDSASigS.ReadFrom(r); err != nil { return } + case PubKeyAlgoEd25519: + sig.EdSig, err = ed25519.ReadSignature(r) + if err != nil { + return + } + case PubKeyAlgoEd448: + sig.EdSig, err = ed448.ReadSignature(r) + if err != nil { + return + } default: panic("unreachable") } @@ -223,7 +343,7 @@ func (sig *Signature) parse(r io.Reader) (err error) { } // parseSignatureSubpackets parses subpackets of the main signature packet. See -// RFC 4880, section 5.2.3.1. +// RFC 9580, section 5.2.3.1. func parseSignatureSubpackets(sig *Signature, subpackets []byte, isHashed bool) (err error) { for len(subpackets) > 0 { subpackets, err = parseSignatureSubpacket(sig, subpackets, isHashed) @@ -244,6 +364,7 @@ type signatureSubpacketType uint8 const ( creationTimeSubpacket signatureSubpacketType = 2 signatureExpirationSubpacket signatureSubpacketType = 3 + exportableCertSubpacket signatureSubpacketType = 4 trustSubpacket signatureSubpacketType = 5 regularExpressionSubpacket signatureSubpacketType = 6 keyExpirationSubpacket signatureSubpacketType = 9 @@ -252,6 +373,8 @@ const ( notationDataSubpacket signatureSubpacketType = 20 prefHashAlgosSubpacket signatureSubpacketType = 21 prefCompressionSubpacket signatureSubpacketType = 22 + keyserverPrefsSubpacket signatureSubpacketType = 23 + prefKeyserverSubpacket signatureSubpacketType = 24 primaryUserIdSubpacket signatureSubpacketType = 25 policyUriSubpacket signatureSubpacketType = 26 keyFlagsSubpacket signatureSubpacketType = 27 @@ -260,12 +383,13 @@ const ( featuresSubpacket signatureSubpacketType = 30 embeddedSignatureSubpacket signatureSubpacketType = 32 issuerFingerprintSubpacket signatureSubpacketType = 33 + intendedRecipientSubpacket signatureSubpacketType = 35 prefCipherSuitesSubpacket signatureSubpacketType = 39 ) // parseSignatureSubpacket parses a single subpacket. len(subpacket) is >= 1. func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (rest []byte, err error) { - // RFC 4880, section 5.2.3.1 + // RFC 9580, section 5.2.3.7 var ( length uint32 packetType signatureSubpacketType @@ -323,19 +447,24 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r t := binary.BigEndian.Uint32(subpacket) sig.CreationTime = time.Unix(int64(t), 0) case signatureExpirationSubpacket: - // Signature expiration time, section 5.2.3.10 + // Signature expiration time, section 5.2.3.18 if len(subpacket) != 4 { err = errors.StructuralError("expiration subpacket with bad length") return } sig.SigLifetimeSecs = new(uint32) *sig.SigLifetimeSecs = binary.BigEndian.Uint32(subpacket) + case exportableCertSubpacket: + if subpacket[0] == 0 { + err = errors.UnsupportedError("signature with non-exportable certification") + return + } case trustSubpacket: if len(subpacket) != 2 { err = errors.StructuralError("trust subpacket with bad length") return } - // Trust level and amount, section 5.2.3.13 + // Trust level and amount, section 5.2.3.21 sig.TrustLevel = TrustLevel(subpacket[0]) sig.TrustAmount = TrustAmount(subpacket[1]) case regularExpressionSubpacket: @@ -343,7 +472,7 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r err = errors.StructuralError("regexp subpacket with bad length") return } - // Trust regular expression, section 5.2.3.14 + // Trust regular expression, section 5.2.3.22 // RFC specifies the string should be null-terminated; remove a null byte from the end if subpacket[len(subpacket)-1] != 0x00 { err = errors.StructuralError("expected regular expression to be null-terminated") @@ -352,7 +481,7 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r trustRegularExpression := string(subpacket[:len(subpacket)-1]) sig.TrustRegularExpression = &trustRegularExpression case keyExpirationSubpacket: - // Key expiration time, section 5.2.3.6 + // Key expiration time, section 5.2.3.13 if len(subpacket) != 4 { err = errors.StructuralError("key expiration subpacket with bad length") return @@ -360,23 +489,25 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r sig.KeyLifetimeSecs = new(uint32) *sig.KeyLifetimeSecs = binary.BigEndian.Uint32(subpacket) case prefSymmetricAlgosSubpacket: - // Preferred symmetric algorithms, section 5.2.3.7 + // Preferred symmetric algorithms, section 5.2.3.14 sig.PreferredSymmetric = make([]byte, len(subpacket)) copy(sig.PreferredSymmetric, subpacket) case issuerSubpacket: - // Issuer, section 5.2.3.5 - if sig.Version > 4 { - err = errors.StructuralError("issuer subpacket found in v5 key") + // Issuer, section 5.2.3.12 + if sig.Version > 4 && isHashed { + err = errors.StructuralError("issuer subpacket found in v6 key") return } if len(subpacket) != 8 { err = errors.StructuralError("issuer subpacket with bad length") return } - sig.IssuerKeyId = new(uint64) - *sig.IssuerKeyId = binary.BigEndian.Uint64(subpacket) + if sig.Version <= 4 { + sig.IssuerKeyId = new(uint64) + *sig.IssuerKeyId = binary.BigEndian.Uint64(subpacket) + } case notationDataSubpacket: - // Notation data, section 5.2.3.16 + // Notation data, section 5.2.3.24 if len(subpacket) < 8 { err = errors.StructuralError("notation data subpacket with bad length") return @@ -398,15 +529,27 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r sig.Notations = append(sig.Notations, ¬ation) case prefHashAlgosSubpacket: - // Preferred hash algorithms, section 5.2.3.8 + // Preferred hash algorithms, section 5.2.3.16 sig.PreferredHash = make([]byte, len(subpacket)) copy(sig.PreferredHash, subpacket) case prefCompressionSubpacket: - // Preferred compression algorithms, section 5.2.3.9 + // Preferred compression algorithms, section 5.2.3.17 sig.PreferredCompression = make([]byte, len(subpacket)) copy(sig.PreferredCompression, subpacket) + case keyserverPrefsSubpacket: + // Keyserver preferences, section 5.2.3.25 + sig.KeyserverPrefsValid = true + if len(subpacket) == 0 { + return + } + if subpacket[0]&KeyserverPrefNoModify != 0 { + sig.KeyserverPrefNoModify = true + } + case prefKeyserverSubpacket: + // Preferred keyserver, section 5.2.3.26 + sig.PreferredKeyserver = string(subpacket) case primaryUserIdSubpacket: - // Primary User ID, section 5.2.3.19 + // Primary User ID, section 5.2.3.27 if len(subpacket) != 1 { err = errors.StructuralError("primary user id subpacket with bad length") return @@ -416,12 +559,11 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r *sig.IsPrimaryId = true } case keyFlagsSubpacket: - // Key flags, section 5.2.3.21 + // Key flags, section 5.2.3.29 + sig.FlagsValid = true if len(subpacket) == 0 { - err = errors.StructuralError("empty key flags subpacket") return } - sig.FlagsValid = true if subpacket[0]&KeyFlagCertify != 0 { sig.FlagCertify = true } @@ -447,16 +589,16 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r userId := string(subpacket) sig.SignerUserId = &userId case reasonForRevocationSubpacket: - // Reason For Revocation, section 5.2.3.23 + // Reason For Revocation, section 5.2.3.31 if len(subpacket) == 0 { err = errors.StructuralError("empty revocation reason subpacket") return } sig.RevocationReason = new(ReasonForRevocation) - *sig.RevocationReason = ReasonForRevocation(subpacket[0]) + *sig.RevocationReason = NewReasonForRevocation(subpacket[0]) sig.RevocationReasonText = string(subpacket[1:]) case featuresSubpacket: - // Features subpacket, section 5.2.3.24 specifies a very general + // Features subpacket, section 5.2.3.32 specifies a very general // mechanism for OpenPGP implementations to signal support for new // features. if len(subpacket) > 0 { @@ -470,16 +612,13 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r } case embeddedSignatureSubpacket: // Only usage is in signatures that cross-certify - // signing subkeys. section 5.2.3.26 describes the + // signing subkeys. section 5.2.3.34 describes the // format, with its usage described in section 11.1 if sig.EmbeddedSignature != nil { err = errors.StructuralError("Cannot have multiple embedded signatures") return } sig.EmbeddedSignature = new(Signature) - // Embedded signatures are required to be v4 signatures see - // section 12.1. However, we only parse v4 signatures in this - // file anyway. if err := sig.EmbeddedSignature.parse(bytes.NewBuffer(subpacket)); err != nil { return nil, err } @@ -487,7 +626,7 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r return nil, errors.StructuralError("cross-signature has unexpected type " + strconv.Itoa(int(sigType))) } case policyUriSubpacket: - // Policy URI, section 5.2.3.20 + // Policy URI, section 5.2.3.28 sig.PolicyURI = string(subpacket) case issuerFingerprintSubpacket: if len(subpacket) == 0 { @@ -495,20 +634,31 @@ func parseSignatureSubpacket(sig *Signature, subpacket []byte, isHashed bool) (r return } v, l := subpacket[0], len(subpacket[1:]) - if v == 5 && l != 32 || v != 5 && l != 20 { + if v >= 5 && l != 32 || v < 5 && l != 20 { return nil, errors.StructuralError("bad fingerprint length") } sig.IssuerFingerprint = make([]byte, l) copy(sig.IssuerFingerprint, subpacket[1:]) sig.IssuerKeyId = new(uint64) - if v == 5 { + if v >= 5 { *sig.IssuerKeyId = binary.BigEndian.Uint64(subpacket[1:9]) } else { *sig.IssuerKeyId = binary.BigEndian.Uint64(subpacket[13:21]) } + case intendedRecipientSubpacket: + // Intended Recipient Fingerprint, section 5.2.3.36 + if len(subpacket) < 1 { + return nil, errors.StructuralError("invalid intended recipient fingerpring length") + } + version, length := subpacket[0], len(subpacket[1:]) + if version >= 5 && length != 32 || version < 5 && length != 20 { + return nil, errors.StructuralError("invalid fingerprint length") + } + fingerprint := make([]byte, length) + copy(fingerprint, subpacket[1:]) + sig.IntendedRecipients = append(sig.IntendedRecipients, &Recipient{int(version), fingerprint}) case prefCipherSuitesSubpacket: - // Preferred AEAD cipher suites - // See https://www.ietf.org/archive/id/draft-ietf-openpgp-crypto-refresh-07.html#name-preferred-aead-ciphersuites + // Preferred AEAD cipher suites, section 5.2.3.15 if len(subpacket)%2 != 0 { err = errors.StructuralError("invalid aead cipher suite length") return @@ -550,9 +700,16 @@ func (sig *Signature) CheckKeyIdOrFingerprint(pk *PublicKey) bool { return sig.IssuerKeyId != nil && *sig.IssuerKeyId == pk.KeyId } +func (sig *Signature) CheckKeyIdOrFingerprintExplicit(fingerprint []byte, keyId uint64) bool { + if sig.IssuerFingerprint != nil && len(sig.IssuerFingerprint) >= 20 && fingerprint != nil { + return bytes.Equal(sig.IssuerFingerprint, fingerprint) + } + return sig.IssuerKeyId != nil && *sig.IssuerKeyId == keyId +} + // serializeSubpacketLength marshals the given length into to. func serializeSubpacketLength(to []byte, length int) int { - // RFC 4880, Section 4.2.2. + // RFC 9580, Section 4.2.1. if length < 192 { to[0] = byte(length) return 1 @@ -598,20 +755,19 @@ func serializeSubpackets(to []byte, subpackets []outputSubpacket, hashed bool) { to = to[n:] } } - return } // SigExpired returns whether sig is a signature that has expired or is created // in the future. func (sig *Signature) SigExpired(currentTime time.Time) bool { - if sig.CreationTime.After(currentTime) { + if sig.CreationTime.Unix() > currentTime.Unix() { return true } if sig.SigLifetimeSecs == nil || *sig.SigLifetimeSecs == 0 { return false } expiry := sig.CreationTime.Add(time.Duration(*sig.SigLifetimeSecs) * time.Second) - return currentTime.After(expiry) + return currentTime.Unix() > expiry.Unix() } // buildHashSuffix constructs the HashSuffix member of sig in preparation for signing. @@ -635,20 +791,36 @@ func (sig *Signature) buildHashSuffix(hashedSubpackets []byte) (err error) { uint8(sig.SigType), uint8(sig.PubKeyAlgo), uint8(hashId), - uint8(len(hashedSubpackets) >> 8), - uint8(len(hashedSubpackets)), }) + hashedSubpacketsLength := len(hashedSubpackets) + if sig.Version == 6 { + // v6 signatures store the length in 4 octets + hashedFields.Write([]byte{ + uint8(hashedSubpacketsLength >> 24), + uint8(hashedSubpacketsLength >> 16), + uint8(hashedSubpacketsLength >> 8), + uint8(hashedSubpacketsLength), + }) + } else { + hashedFields.Write([]byte{ + uint8(hashedSubpacketsLength >> 8), + uint8(hashedSubpacketsLength), + }) + } + lenPrefix := hashedFields.Len() hashedFields.Write(hashedSubpackets) - var l uint64 = uint64(6 + len(hashedSubpackets)) + var l uint64 = uint64(lenPrefix + len(hashedSubpackets)) if sig.Version == 5 { + // v5 case hashedFields.Write([]byte{0x05, 0xff}) hashedFields.Write([]byte{ uint8(l >> 56), uint8(l >> 48), uint8(l >> 40), uint8(l >> 32), uint8(l >> 24), uint8(l >> 16), uint8(l >> 8), uint8(l), }) } else { - hashedFields.Write([]byte{0x04, 0xff}) + // v4 and v6 case + hashedFields.Write([]byte{byte(sig.Version), 0xff}) hashedFields.Write([]byte{ uint8(l >> 24), uint8(l >> 16), uint8(l >> 8), uint8(l), }) @@ -676,6 +848,67 @@ func (sig *Signature) signPrepareHash(h hash.Hash) (digest []byte, err error) { return } +// PrepareSign must be called to create a hash object before Sign for v6 signatures. +// The created hash object initially hashes a randomly generated salt +// as required by v6 signatures. The generated salt is stored in sig. If the signature is not v6, +// the method returns an empty hash object. +// See RFC 9580 Section 5.2.4. +func (sig *Signature) PrepareSign(config *Config) (hash.Hash, error) { + if !sig.Hash.Available() { + return nil, errors.UnsupportedError("hash function") + } + hasher := sig.Hash.New() + if sig.Version == 6 { + if sig.salt == nil { + var err error + sig.salt, err = SignatureSaltForHash(sig.Hash, config.Random()) + if err != nil { + return nil, err + } + } + hasher.Write(sig.salt) + } + return hasher, nil +} + +// SetSalt sets the signature salt for v6 signatures. +// Assumes salt is generated correctly and checks if length matches. +// If the signature is not v6, the method ignores the salt. +// Use PrepareSign whenever possible instead of generating and +// hashing the salt externally. +// See RFC 9580 Section 5.2.4. +func (sig *Signature) SetSalt(salt []byte) error { + if sig.Version == 6 { + expectedSaltLength, err := SaltLengthForHash(sig.Hash) + if err != nil { + return err + } + if salt == nil || len(salt) != expectedSaltLength { + return errors.InvalidArgumentError("unexpected salt size for the given hash algorithm") + } + sig.salt = salt + } + return nil +} + +// PrepareVerify must be called to create a hash object before verifying v6 signatures. +// The created hash object initially hashes the internally stored salt. +// If the signature is not v6, the method returns an empty hash object. +// See RFC 9580 Section 5.2.4. +func (sig *Signature) PrepareVerify() (hash.Hash, error) { + if !sig.Hash.Available() { + return nil, errors.UnsupportedError("hash function") + } + hasher := sig.Hash.New() + if sig.Version == 6 { + if sig.salt == nil { + return nil, errors.StructuralError("v6 requires a salt for the hash to be signed") + } + hasher.Write(sig.salt) + } + return hasher, nil +} + // Sign signs a message with a private key. The hash, h, must contain // the hash of the message to be signed and will be mutated by this function. // On success, the signature is stored in sig. Call Serialize to write it out. @@ -686,6 +919,20 @@ func (sig *Signature) Sign(h hash.Hash, priv *PrivateKey, config *Config) (err e } sig.Version = priv.PublicKey.Version sig.IssuerFingerprint = priv.PublicKey.Fingerprint + if sig.Version < 6 && config.RandomizeSignaturesViaNotation() { + sig.removeNotationsWithName(SaltNotationName) + salt, err := SignatureSaltForHash(sig.Hash, config.Random()) + if err != nil { + return err + } + notation := Notation{ + Name: SaltNotationName, + Value: salt, + IsCritical: false, + IsHumanReadable: false, + } + sig.Notations = append(sig.Notations, ¬ation) + } sig.outSubpackets, err = sig.buildSubpackets(priv.PublicKey) if err != nil { return err @@ -715,8 +962,16 @@ func (sig *Signature) Sign(h hash.Hash, priv *PrivateKey, config *Config) (err e sig.DSASigS = new(encoding.MPI).SetBig(s) } case PubKeyAlgoECDSA: - sk := priv.PrivateKey.(*ecdsa.PrivateKey) - r, s, err := ecdsa.Sign(config.Random(), sk, digest) + var r, s *big.Int + if sk, ok := priv.PrivateKey.(*ecdsa.PrivateKey); ok { + r, s, err = ecdsa.Sign(config.Random(), sk, digest) + } else { + var b []byte + b, err = priv.PrivateKey.(crypto.Signer).Sign(config.Random(), digest, sig.Hash) + if err == nil { + r, s, err = unwrapECDSASig(b) + } + } if err == nil { sig.ECDSASigR = new(encoding.MPI).SetBig(r) @@ -729,6 +984,18 @@ func (sig *Signature) Sign(h hash.Hash, priv *PrivateKey, config *Config) (err e sig.EdDSASigR = encoding.NewMPI(r) sig.EdDSASigS = encoding.NewMPI(s) } + case PubKeyAlgoEd25519: + sk := priv.PrivateKey.(*ed25519.PrivateKey) + signature, err := ed25519.Sign(sk, digest) + if err == nil { + sig.EdSig = signature + } + case PubKeyAlgoEd448: + sk := priv.PrivateKey.(*ed448.PrivateKey) + signature, err := ed448.Sign(sk, digest) + if err == nil { + sig.EdSig = signature + } default: err = errors.UnsupportedError("public key algorithm: " + strconv.Itoa(int(sig.PubKeyAlgo))) } @@ -736,6 +1003,18 @@ func (sig *Signature) Sign(h hash.Hash, priv *PrivateKey, config *Config) (err e return } +// unwrapECDSASig parses the two integer components of an ASN.1-encoded ECDSA signature. +func unwrapECDSASig(b []byte) (r, s *big.Int, err error) { + var ecsdaSig struct { + R, S *big.Int + } + _, err = asn1.Unmarshal(b, &ecsdaSig) + if err != nil { + return + } + return ecsdaSig.R, ecsdaSig.S, nil +} + // SignUserId computes a signature from priv, asserting that pub is a valid // key for the identity id. On success, the signature is stored in sig. Call // Serialize to write it out. @@ -744,11 +1023,32 @@ func (sig *Signature) SignUserId(id string, pub *PublicKey, priv *PrivateKey, co if priv.Dummy() { return errors.ErrDummyPrivateKey("dummy key found") } - h, err := userIdSignatureHash(id, pub, sig.Hash) + prepareHash, err := sig.PrepareSign(config) if err != nil { return err } - return sig.Sign(h, priv, config) + if err := userIdSignatureHash(id, pub, prepareHash); err != nil { + return err + } + return sig.Sign(prepareHash, priv, config) +} + +// SignDirectKeyBinding computes a signature from priv +// On success, the signature is stored in sig. +// Call Serialize to write it out. +// If config is nil, sensible defaults will be used. +func (sig *Signature) SignDirectKeyBinding(pub *PublicKey, priv *PrivateKey, config *Config) error { + if priv.Dummy() { + return errors.ErrDummyPrivateKey("dummy key found") + } + prepareHash, err := sig.PrepareSign(config) + if err != nil { + return err + } + if err := directKeySignatureHash(pub, prepareHash); err != nil { + return err + } + return sig.Sign(prepareHash, priv, config) } // CrossSignKey computes a signature from signingKey on pub hashed using hashKey. On success, @@ -756,7 +1056,11 @@ func (sig *Signature) SignUserId(id string, pub *PublicKey, priv *PrivateKey, co // If config is nil, sensible defaults will be used. func (sig *Signature) CrossSignKey(pub *PublicKey, hashKey *PublicKey, signingKey *PrivateKey, config *Config) error { - h, err := keySignatureHash(hashKey, pub, sig.Hash) + prepareHash, err := sig.PrepareSign(config) + if err != nil { + return err + } + h, err := keySignatureHash(hashKey, pub, prepareHash) if err != nil { return err } @@ -770,7 +1074,11 @@ func (sig *Signature) SignKey(pub *PublicKey, priv *PrivateKey, config *Config) if priv.Dummy() { return errors.ErrDummyPrivateKey("dummy key found") } - h, err := keySignatureHash(&priv.PublicKey, pub, sig.Hash) + prepareHash, err := sig.PrepareSign(config) + if err != nil { + return err + } + h, err := keySignatureHash(&priv.PublicKey, pub, prepareHash) if err != nil { return err } @@ -781,11 +1089,14 @@ func (sig *Signature) SignKey(pub *PublicKey, priv *PrivateKey, config *Config) // stored in sig. Call Serialize to write it out. // If config is nil, sensible defaults will be used. func (sig *Signature) RevokeKey(pub *PublicKey, priv *PrivateKey, config *Config) error { - h, err := keyRevocationHash(pub, sig.Hash) + prepareHash, err := sig.PrepareSign(config) if err != nil { return err } - return sig.Sign(h, priv, config) + if err := keyRevocationHash(pub, prepareHash); err != nil { + return err + } + return sig.Sign(prepareHash, priv, config) } // RevokeSubkey computes a subkey revocation signature of pub using priv. @@ -802,7 +1113,7 @@ func (sig *Signature) Serialize(w io.Writer) (err error) { if len(sig.outSubpackets) == 0 { sig.outSubpackets = sig.rawSubpackets } - if sig.RSASignature == nil && sig.DSASigR == nil && sig.ECDSASigR == nil && sig.EdDSASigR == nil { + if sig.RSASignature == nil && sig.DSASigR == nil && sig.ECDSASigR == nil && sig.EdDSASigR == nil && sig.EdSig == nil { return errors.InvalidArgumentError("Signature: need to call Sign, SignUserId or SignKey before Serialize") } @@ -819,16 +1130,24 @@ func (sig *Signature) Serialize(w io.Writer) (err error) { case PubKeyAlgoEdDSA: sigLength = int(sig.EdDSASigR.EncodedLength()) sigLength += int(sig.EdDSASigS.EncodedLength()) + case PubKeyAlgoEd25519: + sigLength = ed25519.SignatureSize + case PubKeyAlgoEd448: + sigLength = ed448.SignatureSize default: panic("impossible") } + hashedSubpacketsLen := subpacketsLength(sig.outSubpackets, true) unhashedSubpacketsLen := subpacketsLength(sig.outSubpackets, false) - length := len(sig.HashSuffix) - 6 /* trailer not included */ + + length := 4 + /* length of version|signature type|public-key algorithm|hash algorithm */ + 2 /* length of hashed subpackets */ + hashedSubpacketsLen + 2 /* length of unhashed subpackets */ + unhashedSubpacketsLen + 2 /* hash tag */ + sigLength - if sig.Version == 5 { - length -= 4 // eight-octet instead of four-octet big endian + if sig.Version == 6 { + length += 4 + /* the two length fields are four-octet instead of two */ + 1 + /* salt length */ + len(sig.salt) /* length salt */ } err = serializeHeader(w, packetTypeSignature, length) if err != nil { @@ -842,18 +1161,41 @@ func (sig *Signature) Serialize(w io.Writer) (err error) { } func (sig *Signature) serializeBody(w io.Writer) (err error) { - hashedSubpacketsLen := uint16(uint16(sig.HashSuffix[4])<<8) | uint16(sig.HashSuffix[5]) - fields := sig.HashSuffix[:6+hashedSubpacketsLen] + var fields []byte + if sig.Version == 6 { + // v6 signatures use 4 octets for length + hashedSubpacketsLen := + uint32(uint32(sig.HashSuffix[4])<<24) | + uint32(uint32(sig.HashSuffix[5])<<16) | + uint32(uint32(sig.HashSuffix[6])<<8) | + uint32(sig.HashSuffix[7]) + fields = sig.HashSuffix[:8+hashedSubpacketsLen] + } else { + hashedSubpacketsLen := uint16(uint16(sig.HashSuffix[4])<<8) | + uint16(sig.HashSuffix[5]) + fields = sig.HashSuffix[:6+hashedSubpacketsLen] + + } _, err = w.Write(fields) if err != nil { return } unhashedSubpacketsLen := subpacketsLength(sig.outSubpackets, false) - unhashedSubpackets := make([]byte, 2+unhashedSubpacketsLen) - unhashedSubpackets[0] = byte(unhashedSubpacketsLen >> 8) - unhashedSubpackets[1] = byte(unhashedSubpacketsLen) - serializeSubpackets(unhashedSubpackets[2:], sig.outSubpackets, false) + var unhashedSubpackets []byte + if sig.Version == 6 { + unhashedSubpackets = make([]byte, 4+unhashedSubpacketsLen) + unhashedSubpackets[0] = byte(unhashedSubpacketsLen >> 24) + unhashedSubpackets[1] = byte(unhashedSubpacketsLen >> 16) + unhashedSubpackets[2] = byte(unhashedSubpacketsLen >> 8) + unhashedSubpackets[3] = byte(unhashedSubpacketsLen) + serializeSubpackets(unhashedSubpackets[4:], sig.outSubpackets, false) + } else { + unhashedSubpackets = make([]byte, 2+unhashedSubpacketsLen) + unhashedSubpackets[0] = byte(unhashedSubpacketsLen >> 8) + unhashedSubpackets[1] = byte(unhashedSubpacketsLen) + serializeSubpackets(unhashedSubpackets[2:], sig.outSubpackets, false) + } _, err = w.Write(unhashedSubpackets) if err != nil { @@ -864,6 +1206,18 @@ func (sig *Signature) serializeBody(w io.Writer) (err error) { return } + if sig.Version == 6 { + // write salt for v6 signatures + _, err = w.Write([]byte{uint8(len(sig.salt))}) + if err != nil { + return + } + _, err = w.Write(sig.salt) + if err != nil { + return + } + } + switch sig.PubKeyAlgo { case PubKeyAlgoRSA, PubKeyAlgoRSASignOnly: _, err = w.Write(sig.RSASignature.EncodedBytes()) @@ -882,6 +1236,10 @@ func (sig *Signature) serializeBody(w io.Writer) (err error) { return } _, err = w.Write(sig.EdDSASigS.EncodedBytes()) + case PubKeyAlgoEd25519: + err = ed25519.WriteSignature(w, sig.EdSig) + case PubKeyAlgoEd448: + err = ed448.WriteSignature(w, sig.EdSig) default: panic("impossible") } @@ -899,28 +1257,81 @@ type outputSubpacket struct { func (sig *Signature) buildSubpackets(issuer PublicKey) (subpackets []outputSubpacket, err error) { creationTime := make([]byte, 4) binary.BigEndian.PutUint32(creationTime, uint32(sig.CreationTime.Unix())) - subpackets = append(subpackets, outputSubpacket{true, creationTimeSubpacket, false, creationTime}) - + // Signature Creation Time + subpackets = append(subpackets, outputSubpacket{true, creationTimeSubpacket, true, creationTime}) + // Signature Expiration Time + if sig.SigLifetimeSecs != nil && *sig.SigLifetimeSecs != 0 { + sigLifetime := make([]byte, 4) + binary.BigEndian.PutUint32(sigLifetime, *sig.SigLifetimeSecs) + subpackets = append(subpackets, outputSubpacket{true, signatureExpirationSubpacket, true, sigLifetime}) + } + // Trust Signature + if sig.TrustLevel != 0 { + subpackets = append(subpackets, outputSubpacket{true, trustSubpacket, true, []byte{byte(sig.TrustLevel), byte(sig.TrustAmount)}}) + } + // Regular Expression + if sig.TrustRegularExpression != nil { + // RFC specifies the string should be null-terminated; add a null byte to the end + subpackets = append(subpackets, outputSubpacket{true, regularExpressionSubpacket, true, []byte(*sig.TrustRegularExpression + "\000")}) + } + // Key Expiration Time + if sig.KeyLifetimeSecs != nil && *sig.KeyLifetimeSecs != 0 { + keyLifetime := make([]byte, 4) + binary.BigEndian.PutUint32(keyLifetime, *sig.KeyLifetimeSecs) + subpackets = append(subpackets, outputSubpacket{true, keyExpirationSubpacket, true, keyLifetime}) + } + // Preferred Symmetric Ciphers for v1 SEIPD + if len(sig.PreferredSymmetric) > 0 { + subpackets = append(subpackets, outputSubpacket{true, prefSymmetricAlgosSubpacket, false, sig.PreferredSymmetric}) + } + // Issuer Key ID if sig.IssuerKeyId != nil && sig.Version == 4 { keyId := make([]byte, 8) binary.BigEndian.PutUint64(keyId, *sig.IssuerKeyId) - subpackets = append(subpackets, outputSubpacket{true, issuerSubpacket, false, keyId}) + subpackets = append(subpackets, outputSubpacket{true, issuerSubpacket, true, keyId}) } - if sig.IssuerFingerprint != nil { - contents := append([]uint8{uint8(issuer.Version)}, sig.IssuerFingerprint...) - subpackets = append(subpackets, outputSubpacket{true, issuerFingerprintSubpacket, sig.Version == 5, contents}) + // Notation Data + for _, notation := range sig.Notations { + subpackets = append( + subpackets, + outputSubpacket{ + true, + notationDataSubpacket, + notation.IsCritical, + notation.getData(), + }) } - if sig.SignerUserId != nil { - subpackets = append(subpackets, outputSubpacket{true, signerUserIdSubpacket, false, []byte(*sig.SignerUserId)}) + // Preferred Hash Algorithms + if len(sig.PreferredHash) > 0 { + subpackets = append(subpackets, outputSubpacket{true, prefHashAlgosSubpacket, false, sig.PreferredHash}) } - if sig.SigLifetimeSecs != nil && *sig.SigLifetimeSecs != 0 { - sigLifetime := make([]byte, 4) - binary.BigEndian.PutUint32(sigLifetime, *sig.SigLifetimeSecs) - subpackets = append(subpackets, outputSubpacket{true, signatureExpirationSubpacket, true, sigLifetime}) + // Preferred Compression Algorithms + if len(sig.PreferredCompression) > 0 { + subpackets = append(subpackets, outputSubpacket{true, prefCompressionSubpacket, false, sig.PreferredCompression}) } - + // Keyserver Preferences + // Keyserver preferences may only appear in self-signatures or certification signatures. + if sig.KeyserverPrefsValid { + var prefs byte + if sig.KeyserverPrefNoModify { + prefs |= KeyserverPrefNoModify + } + subpackets = append(subpackets, outputSubpacket{true, keyserverPrefsSubpacket, false, []byte{prefs}}) + } + // Preferred Keyserver + if len(sig.PreferredKeyserver) > 0 { + subpackets = append(subpackets, outputSubpacket{true, prefKeyserverSubpacket, false, []uint8(sig.PreferredKeyserver)}) + } + // Primary User ID + if sig.IsPrimaryId != nil && *sig.IsPrimaryId { + subpackets = append(subpackets, outputSubpacket{true, primaryUserIdSubpacket, false, []byte{1}}) + } + // Policy URI + if len(sig.PolicyURI) > 0 { + subpackets = append(subpackets, outputSubpacket{true, policyUriSubpacket, false, []uint8(sig.PolicyURI)}) + } + // Key Flags // Key flags may only appear in self-signatures or certification signatures. - if sig.FlagsValid { var flags byte if sig.FlagCertify { @@ -944,22 +1355,19 @@ func (sig *Signature) buildSubpackets(issuer PublicKey) (subpackets []outputSubp if sig.FlagGroupKey { flags |= KeyFlagGroupKey } - subpackets = append(subpackets, outputSubpacket{true, keyFlagsSubpacket, false, []byte{flags}}) + subpackets = append(subpackets, outputSubpacket{true, keyFlagsSubpacket, true, []byte{flags}}) } - - for _, notation := range sig.Notations { - subpackets = append( - subpackets, - outputSubpacket{ - true, - notationDataSubpacket, - notation.IsCritical, - notation.getData(), - }) + // Signer's User ID + if sig.SignerUserId != nil { + subpackets = append(subpackets, outputSubpacket{true, signerUserIdSubpacket, false, []byte(*sig.SignerUserId)}) } - - // The following subpackets may only appear in self-signatures. - + // Reason for Revocation + // Revocation reason appears only in revocation signatures and is serialized as per section 5.2.3.31. + if sig.RevocationReason != nil { + subpackets = append(subpackets, outputSubpacket{true, reasonForRevocationSubpacket, true, + append([]uint8{uint8(*sig.RevocationReason)}, []uint8(sig.RevocationReasonText)...)}) + } + // Features var features = byte(0x00) if sig.SEIPDv1 { features |= 0x01 @@ -967,46 +1375,36 @@ func (sig *Signature) buildSubpackets(issuer PublicKey) (subpackets []outputSubp if sig.SEIPDv2 { features |= 0x08 } - if features != 0x00 { subpackets = append(subpackets, outputSubpacket{true, featuresSubpacket, false, []byte{features}}) } - - if sig.TrustLevel != 0 { - subpackets = append(subpackets, outputSubpacket{true, trustSubpacket, true, []byte{byte(sig.TrustLevel), byte(sig.TrustAmount)}}) - } - - if sig.TrustRegularExpression != nil { - // RFC specifies the string should be null-terminated; add a null byte to the end - subpackets = append(subpackets, outputSubpacket{true, regularExpressionSubpacket, true, []byte(*sig.TrustRegularExpression + "\000")}) - } - - if sig.KeyLifetimeSecs != nil && *sig.KeyLifetimeSecs != 0 { - keyLifetime := make([]byte, 4) - binary.BigEndian.PutUint32(keyLifetime, *sig.KeyLifetimeSecs) - subpackets = append(subpackets, outputSubpacket{true, keyExpirationSubpacket, true, keyLifetime}) - } - - if sig.IsPrimaryId != nil && *sig.IsPrimaryId { - subpackets = append(subpackets, outputSubpacket{true, primaryUserIdSubpacket, false, []byte{1}}) - } - - if len(sig.PreferredSymmetric) > 0 { - subpackets = append(subpackets, outputSubpacket{true, prefSymmetricAlgosSubpacket, false, sig.PreferredSymmetric}) - } - - if len(sig.PreferredHash) > 0 { - subpackets = append(subpackets, outputSubpacket{true, prefHashAlgosSubpacket, false, sig.PreferredHash}) + // Embedded Signature + // EmbeddedSignature appears only in subkeys capable of signing and is serialized as per section 5.2.3.34. + if sig.EmbeddedSignature != nil { + var buf bytes.Buffer + err = sig.EmbeddedSignature.serializeBody(&buf) + if err != nil { + return + } + subpackets = append(subpackets, outputSubpacket{true, embeddedSignatureSubpacket, true, buf.Bytes()}) } - - if len(sig.PreferredCompression) > 0 { - subpackets = append(subpackets, outputSubpacket{true, prefCompressionSubpacket, false, sig.PreferredCompression}) + // Issuer Fingerprint + if sig.IssuerFingerprint != nil { + contents := append([]uint8{uint8(issuer.Version)}, sig.IssuerFingerprint...) + subpackets = append(subpackets, outputSubpacket{true, issuerFingerprintSubpacket, sig.Version >= 5, contents}) } - - if len(sig.PolicyURI) > 0 { - subpackets = append(subpackets, outputSubpacket{true, policyUriSubpacket, false, []uint8(sig.PolicyURI)}) + // Intended Recipient Fingerprint + for _, recipient := range sig.IntendedRecipients { + subpackets = append( + subpackets, + outputSubpacket{ + true, + intendedRecipientSubpacket, + false, + recipient.Serialize(), + }) } - + // Preferred AEAD Ciphersuites if len(sig.PreferredCipherSuites) > 0 { serialized := make([]byte, len(sig.PreferredCipherSuites)*2) for i, cipherSuite := range sig.PreferredCipherSuites { @@ -1015,23 +1413,6 @@ func (sig *Signature) buildSubpackets(issuer PublicKey) (subpackets []outputSubp } subpackets = append(subpackets, outputSubpacket{true, prefCipherSuitesSubpacket, false, serialized}) } - - // Revocation reason appears only in revocation signatures and is serialized as per section 5.2.3.23. - if sig.RevocationReason != nil { - subpackets = append(subpackets, outputSubpacket{true, reasonForRevocationSubpacket, true, - append([]uint8{uint8(*sig.RevocationReason)}, []uint8(sig.RevocationReasonText)...)}) - } - - // EmbeddedSignature appears only in subkeys capable of signing and is serialized as per section 5.2.3.26. - if sig.EmbeddedSignature != nil { - var buf bytes.Buffer - err = sig.EmbeddedSignature.serializeBody(&buf) - if err != nil { - return - } - subpackets = append(subpackets, outputSubpacket{true, embeddedSignatureSubpacket, true, buf.Bytes()}) - } - return } @@ -1073,8 +1454,6 @@ func (sig *Signature) AddMetadataToHashSuffix() { binary.BigEndian.PutUint32(buf[:], lit.Time) suffix.Write(buf[:]) - // Update the counter and restore trailing bytes - l = uint64(suffix.Len()) suffix.Write([]byte{0x05, 0xff}) suffix.Write([]byte{ uint8(l >> 56), uint8(l >> 48), uint8(l >> 40), uint8(l >> 32), @@ -1082,3 +1461,49 @@ func (sig *Signature) AddMetadataToHashSuffix() { }) sig.HashSuffix = suffix.Bytes() } + +// SaltLengthForHash selects the required salt length for the given hash algorithm, +// as per Table 23 (Hash algorithm registry) of the crypto refresh. +// See RFC 9580 Section 9.5. +func SaltLengthForHash(hash crypto.Hash) (int, error) { + switch hash { + case crypto.SHA256, crypto.SHA224, crypto.SHA3_256: + return 16, nil + case crypto.SHA384: + return 24, nil + case crypto.SHA512, crypto.SHA3_512: + return 32, nil + default: + return 0, errors.UnsupportedError("hash function not supported for V6 signatures") + } +} + +// SignatureSaltForHash generates a random signature salt +// with the length for the given hash algorithm. +// See RFC 9580 Section 9.5. +func SignatureSaltForHash(hash crypto.Hash, randReader io.Reader) ([]byte, error) { + saltLength, err := SaltLengthForHash(hash) + if err != nil { + return nil, err + } + salt := make([]byte, saltLength) + _, err = io.ReadFull(randReader, salt) + if err != nil { + return nil, err + } + return salt, nil +} + +// removeNotationsWithName removes all notations in this signature with the given name. +func (sig *Signature) removeNotationsWithName(name string) { + if sig == nil || sig.Notations == nil { + return + } + updatedNotations := make([]*Notation, 0, len(sig.Notations)) + for _, notation := range sig.Notations { + if notation.Name != name { + updatedNotations = append(updatedNotations, notation) + } + } + sig.Notations = updatedNotations +} diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/symmetric_key_encrypted.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/symmetric_key_encrypted.go index bac2b132ea..2812a1db88 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/symmetric_key_encrypted.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/symmetric_key_encrypted.go @@ -7,11 +7,13 @@ package packet import ( "bytes" "crypto/cipher" + "crypto/sha256" "io" "strconv" "github.com/ProtonMail/go-crypto/openpgp/errors" "github.com/ProtonMail/go-crypto/openpgp/s2k" + "golang.org/x/crypto/hkdf" ) // This is the largest session key that we'll support. Since at most 256-bit cipher @@ -39,10 +41,21 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) error { return err } ske.Version = int(buf[0]) - if ske.Version != 4 && ske.Version != 5 { + if ske.Version != 4 && ske.Version != 5 && ske.Version != 6 { return errors.UnsupportedError("unknown SymmetricKeyEncrypted version") } + if V5Disabled && ske.Version == 5 { + return errors.UnsupportedError("support for parsing v5 entities is disabled; build with `-tags v5` if needed") + } + + if ske.Version > 5 { + // Scalar octet count + if _, err := readFull(r, buf[:]); err != nil { + return err + } + } + // Cipher function if _, err := readFull(r, buf[:]); err != nil { return err @@ -52,7 +65,7 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) error { return errors.UnsupportedError("unknown cipher: " + strconv.Itoa(int(buf[0]))) } - if ske.Version == 5 { + if ske.Version >= 5 { // AEAD mode if _, err := readFull(r, buf[:]); err != nil { return errors.StructuralError("cannot read AEAD octet from packet") @@ -60,6 +73,13 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) error { ske.Mode = AEADMode(buf[0]) } + if ske.Version > 5 { + // Scalar octet count + if _, err := readFull(r, buf[:]); err != nil { + return err + } + } + var err error if ske.s2k, err = s2k.Parse(r); err != nil { if _, ok := err.(errors.ErrDummyPrivateKey); ok { @@ -68,7 +88,7 @@ func (ske *SymmetricKeyEncrypted) parse(r io.Reader) error { return err } - if ske.Version == 5 { + if ske.Version >= 5 { // AEAD IV iv := make([]byte, ske.Mode.IvLength()) _, err := readFull(r, iv) @@ -109,8 +129,8 @@ func (ske *SymmetricKeyEncrypted) Decrypt(passphrase []byte) ([]byte, CipherFunc case 4: plaintextKey, cipherFunc, err := ske.decryptV4(key) return plaintextKey, cipherFunc, err - case 5: - plaintextKey, err := ske.decryptV5(key) + case 5, 6: + plaintextKey, err := ske.aeadDecrypt(ske.Version, key) return plaintextKey, CipherFunction(0), err } err := errors.UnsupportedError("unknown SymmetricKeyEncrypted version") @@ -136,9 +156,9 @@ func (ske *SymmetricKeyEncrypted) decryptV4(key []byte) ([]byte, CipherFunction, return plaintextKey, cipherFunc, nil } -func (ske *SymmetricKeyEncrypted) decryptV5(key []byte) ([]byte, error) { - adata := []byte{0xc3, byte(5), byte(ske.CipherFunc), byte(ske.Mode)} - aead := getEncryptedKeyAeadInstance(ske.CipherFunc, ske.Mode, key, adata) +func (ske *SymmetricKeyEncrypted) aeadDecrypt(version int, key []byte) ([]byte, error) { + adata := []byte{0xc3, byte(version), byte(ske.CipherFunc), byte(ske.Mode)} + aead := getEncryptedKeyAeadInstance(ske.CipherFunc, ske.Mode, key, adata, version) plaintextKey, err := aead.Open(nil, ske.iv, ske.encryptedKey, adata) if err != nil { @@ -175,10 +195,22 @@ func SerializeSymmetricKeyEncrypted(w io.Writer, passphrase []byte, config *Conf // the given passphrase. The returned session key must be passed to // SerializeSymmetricallyEncrypted. // If config is nil, sensible defaults will be used. +// Deprecated: Use SerializeSymmetricKeyEncryptedAEADReuseKey instead. func SerializeSymmetricKeyEncryptedReuseKey(w io.Writer, sessionKey []byte, passphrase []byte, config *Config) (err error) { + return SerializeSymmetricKeyEncryptedAEADReuseKey(w, sessionKey, passphrase, config.AEAD() != nil, config) +} + +// SerializeSymmetricKeyEncryptedAEADReuseKey serializes a symmetric key packet to w. +// The packet contains the given session key, encrypted by a key derived from +// the given passphrase. The returned session key must be passed to +// SerializeSymmetricallyEncrypted. +// If aeadSupported is set, SKESK v6 is used, otherwise v4. +// Note: aeadSupported MUST match the value passed to SerializeSymmetricallyEncrypted. +// If config is nil, sensible defaults will be used. +func SerializeSymmetricKeyEncryptedAEADReuseKey(w io.Writer, sessionKey []byte, passphrase []byte, aeadSupported bool, config *Config) (err error) { var version int - if config.AEAD() != nil { - version = 5 + if aeadSupported { + version = 6 } else { version = 4 } @@ -203,11 +235,15 @@ func SerializeSymmetricKeyEncryptedReuseKey(w io.Writer, sessionKey []byte, pass switch version { case 4: packetLength = 2 /* header */ + len(s2kBytes) + 1 /* cipher type */ + keySize - case 5: + case 5, 6: ivLen := config.AEAD().Mode().IvLength() tagLen := config.AEAD().Mode().TagLength() packetLength = 3 + len(s2kBytes) + ivLen + keySize + tagLen } + if version > 5 { + packetLength += 2 // additional octet count fields + } + err = serializeHeader(w, packetTypeSymmetricKeyEncrypted, packetLength) if err != nil { return @@ -216,13 +252,22 @@ func SerializeSymmetricKeyEncryptedReuseKey(w io.Writer, sessionKey []byte, pass // Symmetric Key Encrypted Version buf := []byte{byte(version)} + if version > 5 { + // Scalar octet count + buf = append(buf, byte(3+len(s2kBytes)+config.AEAD().Mode().IvLength())) + } + // Cipher function buf = append(buf, byte(cipherFunc)) - if version == 5 { + if version >= 5 { // AEAD mode buf = append(buf, byte(config.AEAD().Mode())) } + if version > 5 { + // Scalar octet count + buf = append(buf, byte(len(s2kBytes))) + } _, err = w.Write(buf) if err != nil { return @@ -243,10 +288,10 @@ func SerializeSymmetricKeyEncryptedReuseKey(w io.Writer, sessionKey []byte, pass if err != nil { return } - case 5: + case 5, 6: mode := config.AEAD().Mode() - adata := []byte{0xc3, byte(5), byte(cipherFunc), byte(mode)} - aead := getEncryptedKeyAeadInstance(cipherFunc, mode, keyEncryptingKey, adata) + adata := []byte{0xc3, byte(version), byte(cipherFunc), byte(mode)} + aead := getEncryptedKeyAeadInstance(cipherFunc, mode, keyEncryptingKey, adata, version) // Sample iv using random reader iv := make([]byte, config.AEAD().Mode().IvLength()) @@ -270,7 +315,17 @@ func SerializeSymmetricKeyEncryptedReuseKey(w io.Writer, sessionKey []byte, pass return } -func getEncryptedKeyAeadInstance(c CipherFunction, mode AEADMode, inputKey, associatedData []byte) (aead cipher.AEAD) { - blockCipher := c.new(inputKey) +func getEncryptedKeyAeadInstance(c CipherFunction, mode AEADMode, inputKey, associatedData []byte, version int) (aead cipher.AEAD) { + var blockCipher cipher.Block + if version > 5 { + hkdfReader := hkdf.New(sha256.New, inputKey, []byte{}, associatedData) + + encryptionKey := make([]byte, c.KeySize()) + _, _ = readFull(hkdfReader, encryptionKey) + + blockCipher = c.new(encryptionKey) + } else { + blockCipher = c.new(inputKey) + } return mode.new(blockCipher) } diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/symmetrically_encrypted.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/symmetrically_encrypted.go index e9bbf0327e..0e898742cf 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/symmetrically_encrypted.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/symmetrically_encrypted.go @@ -74,6 +74,10 @@ func (se *SymmetricallyEncrypted) Decrypt(c CipherFunction, key []byte) (io.Read // SerializeSymmetricallyEncrypted serializes a symmetrically encrypted packet // to w and returns a WriteCloser to which the to-be-encrypted packets can be // written. +// If aeadSupported is set to true, SEIPDv2 is used with the indicated CipherSuite. +// Otherwise, SEIPDv1 is used with the indicated CipherFunction. +// Note: aeadSupported MUST match the value passed to SerializeEncryptedKeyAEAD +// and/or SerializeSymmetricKeyEncryptedAEADReuseKey. // If config is nil, sensible defaults will be used. func SerializeSymmetricallyEncrypted(w io.Writer, c CipherFunction, aeadSupported bool, cipherSuite CipherSuite, key []byte, config *Config) (Contents io.WriteCloser, err error) { writeCloser := noOpCloser{w} diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/symmetrically_encrypted_aead.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/symmetrically_encrypted_aead.go index e96252c196..3957b2d53e 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/symmetrically_encrypted_aead.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/symmetrically_encrypted_aead.go @@ -7,7 +7,9 @@ package packet import ( "crypto/cipher" "crypto/sha256" + "fmt" "io" + "strconv" "github.com/ProtonMail/go-crypto/openpgp/errors" "golang.org/x/crypto/hkdf" @@ -25,19 +27,19 @@ func (se *SymmetricallyEncrypted) parseAead(r io.Reader) error { se.Cipher = CipherFunction(headerData[0]) // cipherFunc must have block size 16 to use AEAD if se.Cipher.blockSize() != 16 { - return errors.UnsupportedError("invalid aead cipher: " + string(se.Cipher)) + return errors.UnsupportedError("invalid aead cipher: " + strconv.Itoa(int(se.Cipher))) } // Mode se.Mode = AEADMode(headerData[1]) if se.Mode.TagLength() == 0 { - return errors.UnsupportedError("unknown aead mode: " + string(se.Mode)) + return errors.UnsupportedError("unknown aead mode: " + strconv.Itoa(int(se.Mode))) } // Chunk size se.ChunkSizeByte = headerData[2] if se.ChunkSizeByte > 16 { - return errors.UnsupportedError("invalid aead chunk size byte: " + string(se.ChunkSizeByte)) + return errors.UnsupportedError("invalid aead chunk size byte: " + strconv.Itoa(int(se.ChunkSizeByte))) } // Salt @@ -62,8 +64,11 @@ func (se *SymmetricallyEncrypted) associatedData() []byte { // decryptAead decrypts a V2 SEIPD packet (AEAD) as specified in // https://www.ietf.org/archive/id/draft-ietf-openpgp-crypto-refresh-07.html#section-5.13.2 func (se *SymmetricallyEncrypted) decryptAead(inputKey []byte) (io.ReadCloser, error) { - aead, nonce := getSymmetricallyEncryptedAeadInstance(se.Cipher, se.Mode, inputKey, se.Salt[:], se.associatedData()) + if se.Cipher.KeySize() != len(inputKey) { + return nil, errors.StructuralError(fmt.Sprintf("invalid session key length for cipher: got %d bytes, but expected %d bytes", len(inputKey), se.Cipher.KeySize())) + } + aead, nonce := getSymmetricallyEncryptedAeadInstance(se.Cipher, se.Mode, inputKey, se.Salt[:], se.associatedData()) // Carry the first tagLen bytes tagLen := se.Mode.TagLength() peekedBytes := make([]byte, tagLen) @@ -115,7 +120,7 @@ func serializeSymmetricallyEncryptedAead(ciphertext io.WriteCloser, cipherSuite // Random salt salt := make([]byte, aeadSaltSize) - if _, err := rand.Read(salt); err != nil { + if _, err := io.ReadFull(rand, salt); err != nil { return nil, err } diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/symmetrically_encrypted_mdc.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/symmetrically_encrypted_mdc.go index fa26bebe38..8b18623684 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/symmetrically_encrypted_mdc.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/symmetrically_encrypted_mdc.go @@ -148,7 +148,7 @@ const mdcPacketTagByte = byte(0x80) | 0x40 | 19 func (ser *seMDCReader) Close() error { if ser.error { - return errors.ErrMDCMissing + return errors.ErrMDCHashMismatch } for !ser.eof { @@ -159,7 +159,7 @@ func (ser *seMDCReader) Close() error { break } if err != nil { - return errors.ErrMDCMissing + return errors.ErrMDCHashMismatch } } @@ -172,7 +172,7 @@ func (ser *seMDCReader) Close() error { // The hash already includes the MDC header, but we still check its value // to confirm encryption correctness if ser.trailer[0] != mdcPacketTagByte || ser.trailer[1] != sha1.Size { - return errors.ErrMDCMissing + return errors.ErrMDCHashMismatch } return nil } @@ -237,9 +237,9 @@ func serializeSymmetricallyEncryptedMdc(ciphertext io.WriteCloser, c CipherFunct block := c.new(key) blockSize := block.BlockSize() iv := make([]byte, blockSize) - _, err = config.Random().Read(iv) + _, err = io.ReadFull(config.Random(), iv) if err != nil { - return + return nil, err } s, prefix := NewOCFBEncrypter(block, iv, OCFBNoResync) _, err = ciphertext.Write(prefix) diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/userattribute.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/userattribute.go index 88ec72c6c4..63814ed132 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/userattribute.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/userattribute.go @@ -9,7 +9,6 @@ import ( "image" "image/jpeg" "io" - "io/ioutil" ) const UserAttrImageSubpacket = 1 @@ -63,7 +62,7 @@ func NewUserAttribute(contents ...*OpaqueSubpacket) *UserAttribute { func (uat *UserAttribute) parse(r io.Reader) (err error) { // RFC 4880, section 5.13 - b, err := ioutil.ReadAll(r) + b, err := io.ReadAll(r) if err != nil { return } diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/userid.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/userid.go index 614fbafd5e..3c7451a3c3 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/userid.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/packet/userid.go @@ -6,7 +6,6 @@ package packet import ( "io" - "io/ioutil" "strings" ) @@ -66,7 +65,7 @@ func NewUserId(name, comment, email string) *UserId { func (uid *UserId) parse(r io.Reader) (err error) { // RFC 4880, section 5.11 - b, err := ioutil.ReadAll(r) + b, err := io.ReadAll(r) if err != nil { return } diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/read.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/read.go index 8499c73790..e6dd9b5fd3 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/read.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/read.go @@ -46,6 +46,7 @@ type MessageDetails struct { DecryptedWith Key // the private key used to decrypt the message, if any. IsSigned bool // true if the message is signed. SignedByKeyId uint64 // the key id of the signer, if any. + SignedByFingerprint []byte // the key fingerprint of the signer, if any. SignedBy *Key // the key of the signer, if available. LiteralData *packet.LiteralData // the metadata of the contents UnverifiedBody io.Reader // the contents of the message. @@ -117,7 +118,7 @@ ParsePackets: // This packet contains the decryption key encrypted to a public key. md.EncryptedToKeyIds = append(md.EncryptedToKeyIds, p.KeyId) switch p.Algo { - case packet.PubKeyAlgoRSA, packet.PubKeyAlgoRSAEncryptOnly, packet.PubKeyAlgoElGamal, packet.PubKeyAlgoECDH: + case packet.PubKeyAlgoRSA, packet.PubKeyAlgoRSAEncryptOnly, packet.PubKeyAlgoElGamal, packet.PubKeyAlgoECDH, packet.PubKeyAlgoX25519, packet.PubKeyAlgoX448: break default: continue @@ -232,7 +233,7 @@ FindKey: } mdFinal, sensitiveParsingErr := readSignedMessage(packets, md, keyring, config) if sensitiveParsingErr != nil { - return nil, errors.StructuralError("parsing error") + return nil, errors.HandleSensitiveParsingError(sensitiveParsingErr, md.decrypted != nil) } return mdFinal, nil } @@ -270,13 +271,17 @@ FindLiteralData: prevLast = true } - h, wrappedHash, err = hashForSignature(p.Hash, p.SigType) + h, wrappedHash, err = hashForSignature(p.Hash, p.SigType, p.Salt) if err != nil { md.SignatureError = err } md.IsSigned = true + if p.Version == 6 { + md.SignedByFingerprint = p.KeyFingerprint + } md.SignedByKeyId = p.KeyId + if keyring != nil { keys := keyring.KeysByIdUsage(p.KeyId, packet.KeyFlagSign) if len(keys) > 0 { @@ -292,7 +297,7 @@ FindLiteralData: if md.IsSigned && md.SignatureError == nil { md.UnverifiedBody = &signatureCheckReader{packets, h, wrappedHash, md, config} } else if md.decrypted != nil { - md.UnverifiedBody = checkReader{md} + md.UnverifiedBody = &checkReader{md, false} } else { md.UnverifiedBody = md.LiteralData.Body } @@ -300,12 +305,22 @@ FindLiteralData: return md, nil } +func wrapHashForSignature(hashFunc hash.Hash, sigType packet.SignatureType) (hash.Hash, error) { + switch sigType { + case packet.SigTypeBinary: + return hashFunc, nil + case packet.SigTypeText: + return NewCanonicalTextHash(hashFunc), nil + } + return nil, errors.UnsupportedError("unsupported signature type: " + strconv.Itoa(int(sigType))) +} + // hashForSignature returns a pair of hashes that can be used to verify a // signature. The signature may specify that the contents of the signed message // should be preprocessed (i.e. to normalize line endings). Thus this function // returns two hashes. The second should be used to hash the message itself and // performs any needed preprocessing. -func hashForSignature(hashFunc crypto.Hash, sigType packet.SignatureType) (hash.Hash, hash.Hash, error) { +func hashForSignature(hashFunc crypto.Hash, sigType packet.SignatureType, sigSalt []byte) (hash.Hash, hash.Hash, error) { if _, ok := algorithm.HashToHashIdWithSha1(hashFunc); !ok { return nil, nil, errors.UnsupportedError("unsupported hash function") } @@ -313,14 +328,19 @@ func hashForSignature(hashFunc crypto.Hash, sigType packet.SignatureType) (hash. return nil, nil, errors.UnsupportedError("hash not available: " + strconv.Itoa(int(hashFunc))) } h := hashFunc.New() - + if sigSalt != nil { + h.Write(sigSalt) + } + wrappedHash, err := wrapHashForSignature(h, sigType) + if err != nil { + return nil, nil, err + } switch sigType { case packet.SigTypeBinary: - return h, h, nil + return h, wrappedHash, nil case packet.SigTypeText: - return h, NewCanonicalTextHash(h), nil + return h, wrappedHash, nil } - return nil, nil, errors.UnsupportedError("unsupported signature type: " + strconv.Itoa(int(sigType))) } @@ -328,21 +348,27 @@ func hashForSignature(hashFunc crypto.Hash, sigType packet.SignatureType) (hash. // it closes the ReadCloser from any SymmetricallyEncrypted packet to trigger // MDC checks. type checkReader struct { - md *MessageDetails + md *MessageDetails + checked bool } -func (cr checkReader) Read(buf []byte) (int, error) { +func (cr *checkReader) Read(buf []byte) (int, error) { n, sensitiveParsingError := cr.md.LiteralData.Body.Read(buf) if sensitiveParsingError == io.EOF { + if cr.checked { + // Only check once + return n, io.EOF + } mdcErr := cr.md.decrypted.Close() if mdcErr != nil { return n, mdcErr } + cr.checked = true return n, io.EOF } if sensitiveParsingError != nil { - return n, errors.StructuralError("parsing error") + return n, errors.HandleSensitiveParsingError(sensitiveParsingError, true) } return n, nil @@ -366,6 +392,7 @@ func (scr *signatureCheckReader) Read(buf []byte) (int, error) { scr.wrappedHash.Write(buf[:n]) } + readsDecryptedData := scr.md.decrypted != nil if sensitiveParsingError == io.EOF { var p packet.Packet var readError error @@ -384,7 +411,7 @@ func (scr *signatureCheckReader) Read(buf []byte) (int, error) { key := scr.md.SignedBy signatureError := key.PublicKey.VerifySignature(scr.h, sig) if signatureError == nil { - signatureError = checkSignatureDetails(key, sig, scr.config) + signatureError = checkMessageSignatureDetails(key, sig, scr.config) } scr.md.Signature = sig scr.md.SignatureError = signatureError @@ -408,16 +435,15 @@ func (scr *signatureCheckReader) Read(buf []byte) (int, error) { // unsigned hash of its own. In order to check this we need to // close that Reader. if scr.md.decrypted != nil { - mdcErr := scr.md.decrypted.Close() - if mdcErr != nil { - return n, mdcErr + if sensitiveParsingError := scr.md.decrypted.Close(); sensitiveParsingError != nil { + return n, errors.HandleSensitiveParsingError(sensitiveParsingError, true) } } return n, io.EOF } if sensitiveParsingError != nil { - return n, errors.StructuralError("parsing error") + return n, errors.HandleSensitiveParsingError(sensitiveParsingError, readsDecryptedData) } return n, nil @@ -428,14 +454,13 @@ func (scr *signatureCheckReader) Read(buf []byte) (int, error) { // if any, and a possible signature verification error. // If the signer isn't known, ErrUnknownIssuer is returned. func VerifyDetachedSignature(keyring KeyRing, signed, signature io.Reader, config *packet.Config) (sig *packet.Signature, signer *Entity, err error) { - var expectedHashes []crypto.Hash - return verifyDetachedSignature(keyring, signed, signature, expectedHashes, config) + return verifyDetachedSignature(keyring, signed, signature, nil, false, config) } // VerifyDetachedSignatureAndHash performs the same actions as // VerifyDetachedSignature and checks that the expected hash functions were used. func VerifyDetachedSignatureAndHash(keyring KeyRing, signed, signature io.Reader, expectedHashes []crypto.Hash, config *packet.Config) (sig *packet.Signature, signer *Entity, err error) { - return verifyDetachedSignature(keyring, signed, signature, expectedHashes, config) + return verifyDetachedSignature(keyring, signed, signature, expectedHashes, true, config) } // CheckDetachedSignature takes a signed file and a detached signature and @@ -443,25 +468,24 @@ func VerifyDetachedSignatureAndHash(keyring KeyRing, signed, signature io.Reader // signature verification error. If the signer isn't known, // ErrUnknownIssuer is returned. func CheckDetachedSignature(keyring KeyRing, signed, signature io.Reader, config *packet.Config) (signer *Entity, err error) { - var expectedHashes []crypto.Hash - return CheckDetachedSignatureAndHash(keyring, signed, signature, expectedHashes, config) + _, signer, err = verifyDetachedSignature(keyring, signed, signature, nil, false, config) + return } // CheckDetachedSignatureAndHash performs the same actions as // CheckDetachedSignature and checks that the expected hash functions were used. func CheckDetachedSignatureAndHash(keyring KeyRing, signed, signature io.Reader, expectedHashes []crypto.Hash, config *packet.Config) (signer *Entity, err error) { - _, signer, err = verifyDetachedSignature(keyring, signed, signature, expectedHashes, config) + _, signer, err = verifyDetachedSignature(keyring, signed, signature, expectedHashes, true, config) return } -func verifyDetachedSignature(keyring KeyRing, signed, signature io.Reader, expectedHashes []crypto.Hash, config *packet.Config) (sig *packet.Signature, signer *Entity, err error) { +func verifyDetachedSignature(keyring KeyRing, signed, signature io.Reader, expectedHashes []crypto.Hash, checkHashes bool, config *packet.Config) (sig *packet.Signature, signer *Entity, err error) { var issuerKeyId uint64 var hashFunc crypto.Hash var sigType packet.SignatureType var keys []Key var p packet.Packet - expectedHashesLen := len(expectedHashes) packets := packet.NewReader(signature) for { p, err = packets.Next() @@ -483,16 +507,19 @@ func verifyDetachedSignature(keyring KeyRing, signed, signature io.Reader, expec issuerKeyId = *sig.IssuerKeyId hashFunc = sig.Hash sigType = sig.SigType - - for i, expectedHash := range expectedHashes { - if hashFunc == expectedHash { - break + if checkHashes { + matchFound := false + // check for hashes + for _, expectedHash := range expectedHashes { + if hashFunc == expectedHash { + matchFound = true + break + } } - if i+1 == expectedHashesLen { - return nil, nil, errors.StructuralError("hash algorithm mismatch with cleartext message headers") + if !matchFound { + return nil, nil, errors.StructuralError("hash algorithm or salt mismatch with cleartext message headers") } } - keys = keyring.KeysByIdUsage(issuerKeyId, packet.KeyFlagSign) if len(keys) > 0 { break @@ -503,7 +530,11 @@ func verifyDetachedSignature(keyring KeyRing, signed, signature io.Reader, expec panic("unreachable") } - h, wrappedHash, err := hashForSignature(hashFunc, sigType) + h, err := sig.PrepareVerify() + if err != nil { + return nil, nil, err + } + wrappedHash, err := wrapHashForSignature(h, sigType) if err != nil { return nil, nil, err } @@ -515,7 +546,7 @@ func verifyDetachedSignature(keyring KeyRing, signed, signature io.Reader, expec for _, key := range keys { err = key.PublicKey.VerifySignature(h, sig) if err == nil { - return sig, key.Entity, checkSignatureDetails(&key, sig, config) + return sig, key.Entity, checkMessageSignatureDetails(&key, sig, config) } } @@ -533,7 +564,7 @@ func CheckArmoredDetachedSignature(keyring KeyRing, signed, signature io.Reader, return CheckDetachedSignature(keyring, signed, body, config) } -// checkSignatureDetails returns an error if: +// checkMessageSignatureDetails returns an error if: // - The signature (or one of the binding signatures mentioned below) // has a unknown critical notation data subpacket // - The primary key of the signing entity is revoked @@ -551,15 +582,11 @@ func CheckArmoredDetachedSignature(keyring KeyRing, signed, signature io.Reader, // NOTE: The order of these checks is important, as the caller may choose to // ignore ErrSignatureExpired or ErrKeyExpired errors, but should never // ignore any other errors. -// -// TODO: Also return an error if: -// - The primary key is expired according to a direct-key signature -// - (For V5 keys only:) The direct-key signature (exists and) is expired -func checkSignatureDetails(key *Key, signature *packet.Signature, config *packet.Config) error { +func checkMessageSignatureDetails(key *Key, signature *packet.Signature, config *packet.Config) error { now := config.Now() - primaryIdentity := key.Entity.PrimaryIdentity() + primarySelfSignature, primaryIdentity := key.Entity.PrimarySelfSignature() signedBySubKey := key.PublicKey != key.Entity.PrimaryKey - sigsToCheck := []*packet.Signature{signature, primaryIdentity.SelfSignature} + sigsToCheck := []*packet.Signature{signature, primarySelfSignature} if signedBySubKey { sigsToCheck = append(sigsToCheck, key.SelfSignature, key.SelfSignature.EmbeddedSignature) } @@ -572,10 +599,10 @@ func checkSignatureDetails(key *Key, signature *packet.Signature, config *packet } if key.Entity.Revoked(now) || // primary key is revoked (signedBySubKey && key.Revoked(now)) || // subkey is revoked - primaryIdentity.Revoked(now) { // primary identity is revoked + (primaryIdentity != nil && primaryIdentity.Revoked(now)) { // primary identity is revoked for v4 return errors.ErrKeyRevoked } - if key.Entity.PrimaryKey.KeyExpired(primaryIdentity.SelfSignature, now) { // primary key is expired + if key.Entity.PrimaryKey.KeyExpired(primarySelfSignature, now) { // primary key is expired return errors.ErrKeyExpired } if signedBySubKey { diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/read_write_test_data.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/read_write_test_data.go index db6dad5c0b..670d60226a 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/read_write_test_data.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/read_write_test_data.go @@ -26,6 +26,8 @@ const testKeys1And2PrivateHex = "9501d8044d3c5c10010400b1d13382944bd5aba23a43129 const dsaElGamalTestKeysHex = "9501e1044dfcb16a110400aa3e5c1a1f43dd28c2ffae8abf5cfce555ee874134d8ba0a0f7b868ce2214beddc74e5e1e21ded354a95d18acdaf69e5e342371a71fbb9093162e0c5f3427de413a7f2c157d83f5cd2f9d791256dc4f6f0e13f13c3302af27f2384075ab3021dff7a050e14854bbde0a1094174855fc02f0bae8e00a340d94a1f22b32e48485700a0cec672ac21258fb95f61de2ce1af74b2c4fa3e6703ff698edc9be22c02ae4d916e4fa223f819d46582c0516235848a77b577ea49018dcd5e9e15cff9dbb4663a1ae6dd7580fa40946d40c05f72814b0f88481207e6c0832c3bded4853ebba0a7e3bd8e8c66df33d5a537cd4acf946d1080e7a3dcea679cb2b11a72a33a2b6a9dc85f466ad2ddf4c3db6283fa645343286971e3dd700703fc0c4e290d45767f370831a90187e74e9972aae5bff488eeff7d620af0362bfb95c1a6c3413ab5d15a2e4139e5d07a54d72583914661ed6a87cce810be28a0aa8879a2dd39e52fb6fe800f4f181ac7e328f740cde3d09a05cecf9483e4cca4253e60d4429ffd679d9996a520012aad119878c941e3cf151459873bdfc2a9563472fe0303027a728f9feb3b864260a1babe83925ce794710cfd642ee4ae0e5b9d74cee49e9c67b6cd0ea5dfbb582132195a121356a1513e1bca73e5b80c58c7ccb4164453412f456c47616d616c2054657374204b65792031886204131102002205024dfcb16a021b03060b090807030206150802090a0b0416020301021e01021780000a091033af447ccd759b09fadd00a0b8fd6f5a790bad7e9f2dbb7632046dc4493588db009c087c6a9ba9f7f49fab221587a74788c00db4889ab00200009d0157044dfcb16a1004008dec3f9291205255ccff8c532318133a6840739dd68b03ba942676f9038612071447bf07d00d559c5c0875724ea16a4c774f80d8338b55fca691a0522e530e604215b467bbc9ccfd483a1da99d7bc2648b4318fdbd27766fc8bfad3fddb37c62b8ae7ccfe9577e9b8d1e77c1d417ed2c2ef02d52f4da11600d85d3229607943700030503ff506c94c87c8cab778e963b76cf63770f0a79bf48fb49d3b4e52234620fc9f7657f9f8d56c96a2b7c7826ae6b57ebb2221a3fe154b03b6637cea7e6d98e3e45d87cf8dc432f723d3d71f89c5192ac8d7290684d2c25ce55846a80c9a7823f6acd9bb29fa6cd71f20bc90eccfca20451d0c976e460e672b000df49466408d527affe0303027a728f9feb3b864260abd761730327bca2aaa4ea0525c175e92bf240682a0e83b226f97ecb2e935b62c9a133858ce31b271fa8eb41f6a1b3cd72a63025ce1a75ee4180dcc284884904181102000905024dfcb16a021b0c000a091033af447ccd759b09dd0b009e3c3e7296092c81bee5a19929462caaf2fff3ae26009e218c437a2340e7ea628149af1ec98ec091a43992b00200009501e1044dfcb1be1104009f61faa61aa43df75d128cbe53de528c4aec49ce9360c992e70c77072ad5623de0a3a6212771b66b39a30dad6781799e92608316900518ec01184a85d872365b7d2ba4bacfb5882ea3c2473d3750dc6178cc1cf82147fb58caa28b28e9f12f6d1efcb0534abed644156c91cca4ab78834268495160b2400bc422beb37d237c2300a0cac94911b6d493bda1e1fbc6feeca7cb7421d34b03fe22cec6ccb39675bb7b94a335c2b7be888fd3906a1125f33301d8aa6ec6ee6878f46f73961c8d57a3e9544d8ef2a2cbfd4d52da665b1266928cfe4cb347a58c412815f3b2d2369dec04b41ac9a71cc9547426d5ab941cccf3b18575637ccfb42df1a802df3cfe0a999f9e7109331170e3a221991bf868543960f8c816c28097e503fe319db10fb98049f3a57d7c80c420da66d56f3644371631fad3f0ff4040a19a4fedc2d07727a1b27576f75a4d28c47d8246f27071e12d7a8de62aad216ddbae6aa02efd6b8a3e2818cda48526549791ab277e447b3a36c57cefe9b592f5eab73959743fcc8e83cbefec03a329b55018b53eec196765ae40ef9e20521a603c551efe0303020950d53a146bf9c66034d00c23130cce95576a2ff78016ca471276e8227fb30b1ffbd92e61804fb0c3eff9e30b1a826ee8f3e4730b4d86273ca977b4164453412f456c47616d616c2054657374204b65792032886204131102002205024dfcb1be021b03060b090807030206150802090a0b0416020301021e01021780000a0910a86bf526325b21b22bd9009e34511620415c974750a20df5cb56b182f3b48e6600a0a9466cb1a1305a84953445f77d461593f1d42bc1b00200009d0157044dfcb1be1004009565a951da1ee87119d600c077198f1c1bceb0f7aa54552489298e41ff788fa8f0d43a69871f0f6f77ebdfb14a4260cf9fbeb65d5844b4272a1904dd95136d06c3da745dc46327dd44a0f16f60135914368c8039a34033862261806bb2c5ce1152e2840254697872c85441ccb7321431d75a747a4bfb1d2c66362b51ce76311700030503fc0ea76601c196768070b7365a200e6ddb09307f262d5f39eec467b5f5784e22abdf1aa49226f59ab37cb49969d8f5230ea65caf56015abda62604544ed526c5c522bf92bed178a078789f6c807b6d34885688024a5bed9e9f8c58d11d4b82487b44c5f470c5606806a0443b79cadb45e0f897a561a53f724e5349b9267c75ca17fe0303020950d53a146bf9c660bc5f4ce8f072465e2d2466434320c1e712272fafc20e342fe7608101580fa1a1a367e60486a7cd1246b7ef5586cf5e10b32762b710a30144f12dd17dd4884904181102000905024dfcb1be021b0c000a0910a86bf526325b21b2904c00a0b2b66b4b39ccffda1d10f3ea8d58f827e30a8b8e009f4255b2d8112a184e40cde43a34e8655ca7809370b0020000" +const ed25519wX25519Key = "c54b0663877fe31b00000020f94da7bb48d60a61e567706a6587d0331999bb9d891a08242ead84543df895a3001972817b12be707e8d5f586ce61361201d344eb266a2c82fde6835762b65b0b7c2b1061f1b0a00000042058263877fe3030b090705150a0e080c021600029b03021e09222106cb186c4f0609a697e4d52dfa6c722b0c1f1e27c18a56708f6525ec27bad9acc905270902070200000000ad2820103e2d7d227ec0e6d7ce4471db36bfc97083253690271498a7ef0576c07faae14585b3b903b0127ec4fda2f023045a2ec76bcb4f9571a9651e14aee1137a1d668442c88f951e33c4ffd33fb9a17d511eed758fc6d9cc50cb5fd793b2039d5804c74b0663877fe319000000208693248367f9e5015db922f8f48095dda784987f2d5985b12fbad16caf5e4435004d600a4f794d44775c57a26e0feefed558e9afffd6ad0d582d57fb2ba2dcedb8c29b06181b0a0000002c050263877fe322a106cb186c4f0609a697e4d52dfa6c722b0c1f1e27c18a56708f6525ec27bad9acc9021b0c00000000defa20a6e9186d9d5935fc8fe56314cdb527486a5a5120f9b762a235a729f039010a56b89c658568341fbef3b894e9834ad9bc72afae2f4c9c47a43855e65f1cb0a3f77bbc5f61085c1f8249fe4e7ca59af5f0bcee9398e0fa8d76e522e1d8ab42bb0d" + const signedMessageHex = "a3019bc0cbccc0c4b8d8b74ee2108fe16ec6d3ca490cbe362d3f8333d3f352531472538b8b13d353b97232f352158c20943157c71c16064626063656269052062e4e01987e9b6fccff4b7df3a34c534b23e679cbec3bc0f8f6e64dfb4b55fe3f8efa9ce110ddb5cd79faf1d753c51aecfa669f7e7aa043436596cccc3359cb7dd6bbe9ecaa69e5989d9e57209571edc0b2fa7f57b9b79a64ee6e99ce1371395fee92fec2796f7b15a77c386ff668ee27f6d38f0baa6c438b561657377bf6acff3c5947befd7bf4c196252f1d6e5c524d0300" const signedTextMessageHex = "a3019bc0cbccc8c4b8d8b74ee2108fe16ec6d36a250cbece0c178233d3f352531472538b8b13d35379b97232f352158ca0b4312f57c71c1646462606365626906a062e4e019811591798ff99bf8afee860b0d8a8c2a85c3387e3bcf0bb3b17987f2bbcfab2aa526d930cbfd3d98757184df3995c9f3e7790e36e3e9779f06089d4c64e9e47dd6202cb6e9bc73c5d11bb59fbaf89d22d8dc7cf199ddf17af96e77c5f65f9bbed56f427bd8db7af37f6c9984bf9385efaf5f184f986fb3e6adb0ecfe35bbf92d16a7aa2a344fb0bc52fb7624f0200" @@ -160,18 +162,78 @@ TcIYl5/Uyoi+FOvPLcNw4hOv2nwUzSSVAw== =IiS2 -----END PGP PRIVATE KEY BLOCK-----` -// Generated with the above private key -const v5PrivKeyMsg = `-----BEGIN PGP MESSAGE----- -Version: OpenPGP.js v4.10.7 -Comment: https://openpgpjs.org +// See OpenPGP crypto refresh Section A.3. +const v6PrivKey = `-----BEGIN PGP PRIVATE KEY BLOCK----- + +xUsGY4d/4xsAAAAg+U2nu0jWCmHlZ3BqZYfQMxmZu52JGggkLq2EVD34laMAGXKB +exK+cH6NX1hs5hNhIB00TrJmosgv3mg1ditlsLfCsQYfGwoAAABCBYJjh3/jAwsJ +BwUVCg4IDAIWAAKbAwIeCSIhBssYbE8GCaaX5NUt+mxyKwwfHifBilZwj2Ul7Ce6 +2azJBScJAgcCAAAAAK0oIBA+LX0ifsDm185Ecds2v8lwgyU2kCcUmKfvBXbAf6rh +RYWzuQOwEn7E/aLwIwRaLsdry0+VcallHhSu4RN6HWaEQsiPlR4zxP/TP7mhfVEe +7XWPxtnMUMtf15OyA51YBMdLBmOHf+MZAAAAIIaTJINn+eUBXbki+PSAld2nhJh/ +LVmFsS+60WyvXkQ1AE1gCk95TUR3XFeibg/u/tVY6a//1q0NWC1X+yui3O24wpsG +GBsKAAAALAWCY4d/4wKbDCIhBssYbE8GCaaX5NUt+mxyKwwfHifBilZwj2Ul7Ce6 +2azJAAAAAAQBIKbpGG2dWTX8j+VjFM21J0hqWlEg+bdiojWnKfA5AQpWUWtnNwDE +M0g12vYxoWM8Y81W+bHBw805I8kWVkXU6vFOi+HWvv/ira7ofJu16NnoUkhclkUr +k0mXubZvyl4GBg== +-----END PGP PRIVATE KEY BLOCK-----` + +// See OpenPGP crypto refresh merge request: +// https://gitlab.com/openpgp-wg/rfc4880bis/-/merge_requests/304 +const v6PrivKeyMsg = `-----BEGIN PGP MESSAGE----- + +wV0GIQYSyD8ecG9jCP4VGkF3Q6HwM3kOk+mXhIjR2zeNqZMIhRmHzxjV8bU/gXzO +WgBM85PMiVi93AZfJfhK9QmxfdNnZBjeo1VDeVZheQHgaVf7yopqR6W1FT6NOrfS +aQIHAgZhZBZTW+CwcW1g4FKlbExAf56zaw76/prQoN+bAzxpohup69LA7JW/Vp0l +yZnuSj3hcFj0DfqLTGgr4/u717J+sPWbtQBfgMfG9AOIwwrUBqsFE9zW+f1zdlYo +bhF30A+IitsxxA== +-----END PGP MESSAGE-----` + +// See OpenPGP crypto refresh merge request: +// https://gitlab.com/openpgp-wg/rfc4880bis/-/merge_requests/305 +const v6PrivKeyInlineSignMsg = `-----BEGIN PGP MESSAGE----- -xA0DAQoWGTR7yYckZAIByxF1B21zZy50eHRfbIGSdGVzdMJ3BQEWCgAGBQJf -bIGSACMiIQUZNHvJhyRkAl+Z3z7C4AAO2YhIkuH3s+pMlACRWVabVDQvAP9G -y29VPonFXqi2zKkpZrvyvZxg+n5e8Nt9wNbuxeCd3QD/TtO2s+JvjrE4Siwv -UQdl5MlBka1QSNbMq2Bz7XwNPg4= -=6lbM +wV0GIQYSyD8ecG9jCP4VGkF3Q6HwM3kOk+mXhIjR2zeNqZMIhRmHzxjV8bU/gXzO +WgBM85PMiVi93AZfJfhK9QmxfdNnZBjeo1VDeVZheQHgaVf7yopqR6W1FT6NOrfS +aQIHAgZhZBZTW+CwcW1g4FKlbExAf56zaw76/prQoN+bAzxpohup69LA7JW/Vp0l +yZnuSj3hcFj0DfqLTGgr4/u717J+sPWbtQBfgMfG9AOIwwrUBqsFE9zW+f1zdlYo +bhF30A+IitsxxA== -----END PGP MESSAGE-----` +// See https://gitlab.com/openpgp-wg/rfc4880bis/-/merge_requests/274 +// decryption password: "correct horse battery staple" +const v6ArgonSealedPrivKey = `-----BEGIN PGP PRIVATE KEY BLOCK----- + +xYIGY4d/4xsAAAAg+U2nu0jWCmHlZ3BqZYfQMxmZu52JGggkLq2EVD34laP9JgkC +FARdb9ccngltHraRe25uHuyuAQQVtKipJ0+r5jL4dacGWSAheCWPpITYiyfyIOPS +3gIDyg8f7strd1OB4+LZsUhcIjOMpVHgmiY/IutJkulneoBYwrEGHxsKAAAAQgWC +Y4d/4wMLCQcFFQoOCAwCFgACmwMCHgkiIQbLGGxPBgmml+TVLfpscisMHx4nwYpW +cI9lJewnutmsyQUnCQIHAgAAAACtKCAQPi19In7A5tfORHHbNr/JcIMlNpAnFJin +7wV2wH+q4UWFs7kDsBJ+xP2i8CMEWi7Ha8tPlXGpZR4UruETeh1mhELIj5UeM8T/ +0z+5oX1RHu11j8bZzFDLX9eTsgOdWATHggZjh3/jGQAAACCGkySDZ/nlAV25Ivj0 +gJXdp4SYfy1ZhbEvutFsr15ENf0mCQIUBA5hhGgp2oaavg6mFUXcFMwBBBUuE8qf +9Ock+xwusd+GAglBr5LVyr/lup3xxQvHXFSjjA2haXfoN6xUGRdDEHI6+uevKjVR +v5oAxgu7eJpaXNjCmwYYGwoAAAAsBYJjh3/jApsMIiEGyxhsTwYJppfk1S36bHIr +DB8eJ8GKVnCPZSXsJ7rZrMkAAAAABAEgpukYbZ1ZNfyP5WMUzbUnSGpaUSD5t2Ki +Nacp8DkBClZRa2c3AMQzSDXa9jGhYzxjzVb5scHDzTkjyRZWRdTq8U6L4da+/+Kt +ruh8m7Xo2ehSSFyWRSuTSZe5tm/KXgYG +-----END PGP PRIVATE KEY BLOCK-----` + +const v4Key25519 = `-----BEGIN PGP PRIVATE KEY BLOCK----- + +xUkEZB3qzRto01j2k2pwN5ux9w70stPinAdXULLr20CRW7U7h2GSeACch0M+ +qzQg8yjFQ8VBvu3uwgKH9senoHmj72lLSCLTmhFKzQR0ZXN0wogEEBsIAD4F +gmQd6s0ECwkHCAmQIf45+TuC+xMDFQgKBBYAAgECGQECmwMCHgEWIQSWEzMi +jJUHvyIbVKIh/jn5O4L7EwAAUhaHNlgudvxARdPPETUzVgjuWi+YIz8w1xIb +lHQMvIrbe2sGCQIethpWofd0x7DHuv/ciHg+EoxJ/Td6h4pWtIoKx0kEZB3q +zRm4CyA7quliq7yx08AoOqHTuuCgvpkSdEhpp3pEyejQOgBo0p6ywIiLPllY +0t+jpNspHpAGfXID6oqjpYuJw3AfVRBlwnQEGBsIACoFgmQd6s0JkCH+Ofk7 +gvsTApsMFiEElhMzIoyVB78iG1SiIf45+TuC+xMAAGgQuN9G73446ykvJ/mL +sCZ7zGFId2gBd1EnG0FTC4npfOKpck0X8dngByrCxU8LDSfvjsEp/xDAiKsQ +aU71tdtNBQ== +=e7jT +-----END PGP PRIVATE KEY BLOCK-----` + const keyWithExpiredCrossSig = `-----BEGIN PGP PUBLIC KEY BLOCK----- xsDNBF2lnPIBDAC5cL9PQoQLTMuhjbYvb4Ncuuo0bfmgPRFywX53jPhoFf4Zg6mv @@ -272,3 +334,124 @@ AtNTq6ihLMD5v1d82ZC7tNatdlDMGWnIdvEMCv2GZcuIqDQ9rXWs49e7tq1NncLY hz3tYjKhoFTKEIq3y3Pp =h/aX -----END PGP PUBLIC KEY BLOCK-----` + +const keyv5Test = `-----BEGIN PGP PRIVATE KEY BLOCK----- +Comment: Bob's OpenPGP Transferable Secret Key + +lQVYBF2lnPIBDAC5cL9PQoQLTMuhjbYvb4Ncuuo0bfmgPRFywX53jPhoFf4Zg6mv +/seOXpgecTdOcVttfzC8ycIKrt3aQTiwOG/ctaR4Bk/t6ayNFfdUNxHWk4WCKzdz +/56fW2O0F23qIRd8UUJp5IIlN4RDdRCtdhVQIAuzvp2oVy/LaS2kxQoKvph/5pQ/ +5whqsyroEWDJoSV0yOb25B/iwk/pLUFoyhDG9bj0kIzDxrEqW+7Ba8nocQlecMF3 +X5KMN5kp2zraLv9dlBBpWW43XktjcCZgMy20SouraVma8Je/ECwUWYUiAZxLIlMv +9CurEOtxUw6N3RdOtLmYZS9uEnn5y1UkF88o8Nku890uk6BrewFzJyLAx5wRZ4F0 +qV/yq36UWQ0JB/AUGhHVPdFf6pl6eaxBwT5GXvbBUibtf8YI2og5RsgTWtXfU7eb +SGXrl5ZMpbA6mbfhd0R8aPxWfmDWiIOhBufhMCvUHh1sApMKVZnvIff9/0Dca3wb +vLIwa3T4CyshfT0AEQEAAQAL/RZqbJW2IqQDCnJi4Ozm++gPqBPiX1RhTWSjwxfM +cJKUZfzLj414rMKm6Jh1cwwGY9jekROhB9WmwaaKT8HtcIgrZNAlYzANGRCM4TLK +3VskxfSwKKna8l+s+mZglqbAjUg3wmFuf9Tj2xcUZYmyRm1DEmcN2ZzpvRtHgX7z +Wn1mAKUlSDJZSQks0zjuMNbupcpyJokdlkUg2+wBznBOTKzgMxVNC9b2g5/tMPUs +hGGWmF1UH+7AHMTaS6dlmr2ZBIyogdnfUqdNg5sZwsxSNrbglKP4sqe7X61uEAIQ +bD7rT3LonLbhkrj3I8wilUD8usIwt5IecoHhd9HziqZjRCc1BUBkboUEoyedbDV4 +i4qfsFZ6CEWoLuD5pW7dEp0M+WeuHXO164Rc+LnH6i1VQrpb1Okl4qO6ejIpIjBI +1t3GshtUu/mwGBBxs60KBX5g77mFQ9lLCRj8lSYqOsHRKBhUp4qM869VA+fD0BRP +fqPT0I9IH4Oa/A3jYJcg622GwQYA1LhnP208Waf6PkQSJ6kyr8ymY1yVh9VBE/g6 +fRDYA+pkqKnw9wfH2Qho3ysAA+OmVOX8Hldg+Pc0Zs0e5pCavb0En8iFLvTA0Q2E +LR5rLue9uD7aFuKFU/VdcddY9Ww/vo4k5p/tVGp7F8RYCFn9rSjIWbfvvZi1q5Tx ++akoZbga+4qQ4WYzB/obdX6SCmi6BndcQ1QdjCCQU6gpYx0MddVERbIp9+2SXDyL +hpxjSyz+RGsZi/9UAshT4txP4+MZBgDfK3ZqtW+h2/eMRxkANqOJpxSjMyLO/FXN +WxzTDYeWtHNYiAlOwlQZEPOydZFty9IVzzNFQCIUCGjQ/nNyhw7adSgUk3+BXEx/ +MyJPYY0BYuhLxLYcrfQ9nrhaVKxRJj25SVHj2ASsiwGJRZW4CC3uw40OYxfKEvNC +mer/VxM3kg8qqGf9KUzJ1dVdAvjyx2Hz6jY2qWCyRQ6IMjWHyd43C4r3jxooYKUC +YnstRQyb/gCSKahveSEjo07CiXMr88UGALwzEr3npFAsPW3osGaFLj49y1oRe11E +he9gCHFm+fuzbXrWmdPjYU5/ZdqdojzDqfu4ThfnipknpVUM1o6MQqkjM896FHm8 +zbKVFSMhEP6DPHSCexMFrrSgN03PdwHTO6iBaIBBFqmGY01tmJ03SxvSpiBPON9P +NVvy/6UZFedTq8A07OUAxO62YUSNtT5pmK2vzs3SAZJmbFbMh+NN204TRI72GlqT +t5hcfkuv8hrmwPS/ZR6q312mKQ6w/1pqO9qitCFCb2IgQmFiYmFnZSA8Ym9iQG9w +ZW5wZ3AuZXhhbXBsZT6JAc4EEwEKADgCGwMFCwkIBwIGFQoJCAsCBBYCAwECHgEC +F4AWIQTRpm4aI7GCyZgPeIz7/MgqAV5zMAUCXaWe+gAKCRD7/MgqAV5zMG9sC/9U +2T3RrqEbw533FPNfEflhEVRIZ8gDXKM8hU6cqqEzCmzZT6xYTe6sv4y+PJBGXJFX +yhj0g6FDkSyboM5litOcTupURObVqMgA/Y4UKERznm4fzzH9qek85c4ljtLyNufe +doL2pp3vkGtn7eD0QFRaLLmnxPKQ/TlZKdLE1G3u8Uot8QHicaR6GnAdc5UXQJE3 +BiV7jZuDyWmZ1cUNwJkKL6oRtp+ZNDOQCrLNLecKHcgCqrpjSQG5oouba1I1Q6Vl +sP44dhA1nkmLHtxlTOzpeHj4jnk1FaXmyasurrrI5CgU/L2Oi39DGKTH/A/cywDN +4ZplIQ9zR8enkbXquUZvFDe+Xz+6xRXtb5MwQyWODB3nHw85HocLwRoIN9WdQEI+ +L8a/56AuOwhs8llkSuiITjR7r9SgKJC2WlAHl7E8lhJ3VDW3ELC56KH308d6mwOG +ZRAqIAKzM1T5FGjMBhq7ZV0eqdEntBh3EcOIfj2M8rg1MzJv+0mHZOIjByawikad +BVgEXaWc8gEMANYwv1xsYyunXYK0X1vY/rP1NNPvhLyLIE7NpK90YNBj+xS1ldGD +bUdZqZeef2xJe8gMQg05DoD1DF3GipZ0Ies65beh+d5hegb7N4pzh0LzrBrVNHar +29b5ExdI7i4iYD5TO6Vr/qTUOiAN/byqELEzAb+L+b2DVz/RoCm4PIp1DU9ewcc2 +WB38Ofqut3nLYA5tqJ9XvAiEQme+qAVcM3ZFcaMt4I4dXhDZZNg+D9LiTWcxdUPB +leu8iwDRjAgyAhPzpFp+nWoqWA81uIiULWD1Fj+IVoY3ZvgivoYOiEFBJ9lbb4te +g9m5UT/AaVDTWuHzbspVlbiVe+qyB77C2daWzNyx6UYBPLOo4r0t0c91kbNE5lgj +Z7xz6los0N1U8vq91EFSeQJoSQ62XWavYmlCLmdNT6BNfgh4icLsT7Vr1QMX9jzn +JtTPxdXytSdHvpSpULsqJ016l0dtmONcK3z9mj5N5z0k1tg1AH970TGYOe2aUcSx +IRDMXDOPyzEfjwARAQABAAv9F2CwsjS+Sjh1M1vegJbZjei4gF1HHpEM0K0PSXsp +SfVvpR4AoSJ4He6CXSMWg0ot8XKtDuZoV9jnJaES5UL9pMAD7JwIOqZm/DYVJM5h +OASCh1c356/wSbFbzRHPtUdZO9Q30WFNJM5pHbCJPjtNoRmRGkf71RxtvHBzy7np +Ga+W6U/NVKHw0i0CYwMI0YlKDakYW3Pm+QL+gHZFvngGweTod0f9l2VLLAmeQR/c ++EZs7lNumhuZ8mXcwhUc9JQIhOkpO+wreDysEFkAcsKbkQP3UDUsA1gFx9pbMzT0 +tr1oZq2a4QBtxShHzP/ph7KLpN+6qtjks3xB/yjTgaGmtrwM8tSe0wD1RwXS+/1o +BHpXTnQ7TfeOGUAu4KCoOQLv6ELpKWbRBLWuiPwMdbGpvVFALO8+kvKAg9/r+/ny +zM2GQHY+J3Jh5JxPiJnHfXNZjIKLbFbIPdSKNyJBuazXW8xIa//mEHMI5OcvsZBK +clAIp7LXzjEjKXIwHwDcTn9pBgDpdOKTHOtJ3JUKx0rWVsDH6wq6iKV/FTVSY5jl +zN+puOEsskF1Lfxn9JsJihAVO3yNsp6RvkKtyNlFazaCVKtDAmkjoh60XNxcNRqr +gCnwdpbgdHP6v/hvZY54ZaJjz6L2e8unNEkYLxDt8cmAyGPgH2XgL7giHIp9jrsQ +aS381gnYwNX6wE1aEikgtY91nqJjwPlibF9avSyYQoMtEqM/1UjTjB2KdD/MitK5 +fP0VpvuXpNYZedmyq4UOMwdkiNMGAOrfmOeT0olgLrTMT5H97Cn3Yxbk13uXHNu/ +ZUZZNe8s+QtuLfUlKAJtLEUutN33TlWQY522FV0m17S+b80xJib3yZVJteVurrh5 +HSWHAM+zghQAvCesg5CLXa2dNMkTCmZKgCBvfDLZuZbjFwnwCI6u/NhOY9egKuUf +SA/je/RXaT8m5VxLYMxwqQXKApzD87fv0tLPlVIEvjEsaf992tFEFSNPcG1l/jpd +5AVXw6kKuf85UkJtYR1x2MkQDrqY1QX/XMw00kt8y9kMZUre19aCArcmor+hDhRJ +E3Gt4QJrD9z/bICESw4b4z2DbgD/Xz9IXsA/r9cKiM1h5QMtXvuhyfVeM01enhxM +GbOH3gjqqGNKysx0UODGEwr6AV9hAd8RWXMchJLaExK9J5SRawSg671ObAU24SdY +vMQ9Z4kAQ2+1ReUZzf3ogSMRZtMT+d18gT6L90/y+APZIaoArLPhebIAGq39HLmJ +26x3z0WAgrpA1kNsjXEXkoiZGPLKIGoe3hqJAbYEGAEKACAWIQTRpm4aI7GCyZgP +eIz7/MgqAV5zMAUCXaWc8gIbDAAKCRD7/MgqAV5zMOn/C/9ugt+HZIwX308zI+QX +c5vDLReuzmJ3ieE0DMO/uNSC+K1XEioSIZP91HeZJ2kbT9nn9fuReuoff0T0Dief +rbwcIQQHFFkrqSp1K3VWmUGp2JrUsXFVdjy/fkBIjTd7c5boWljv/6wAsSfiv2V0 +JSM8EFU6TYXxswGjFVfc6X97tJNeIrXL+mpSmPPqy2bztcCCHkWS5lNLWQw+R7Vg +71Fe6yBSNVrqC2/imYG2J9zlowjx1XU63Wdgqp2Wxt0l8OmsB/W80S1fRF5G4SDH +s9HXglXXqPsBRZJYfP+VStm9L5P/sKjCcX6WtZR7yS6G8zj/X767MLK/djANvpPd +NVniEke6hM3CNBXYPAMhQBMWhCulcoz+0lxi8L34rMN+Dsbma96psdUrn7uLaB91 +6we0CTfF8qqm7BsVAgalon/UUiuMY80U3ueoj3okiSTiHIjD/YtpXSPioC8nMng7 +xqAY9Bwizt4FWgXuLm1a4+So4V9j1TRCXd12Uc2l2RNmgDE= +=miES +-----END PGP PRIVATE KEY BLOCK----- +` + +const certv5Test = `-----BEGIN PGP PRIVATE KEY BLOCK----- + +lGEFXJH05BYAAAAtCSsGAQQB2kcPAQEHQFhZlVcVVtwf+21xNQPX+ecMJJBL0MPd +fj75iux+my8QAAAAAAAiAQCHZ1SnSUmWqxEsoI6facIVZQu6mph3cBFzzTvcm5lA +Ng5ctBhlbW1hLmdvbGRtYW5AZXhhbXBsZS5uZXSIlgUTFggASCIhBRk0e8mHJGQC +X5nfPsLgAA7ZiEiS4fez6kyUAJFZVptUBQJckfTkAhsDBQsJCAcCAyICAQYVCgkI +CwIEFgIDAQIeBwIXgAAA9cAA/jiR3yMsZMeEQ40u6uzEoXa6UXeV/S3wwJAXRJy9 +M8s0AP9vuL/7AyTfFXwwzSjDnYmzS0qAhbLDQ643N+MXGBJ2BZxmBVyR9OQSAAAA +MgorBgEEAZdVAQUBAQdA+nysrzml2UCweAqtpDuncSPlvrcBWKU0yfU0YvYWWAoD +AQgHAAAAAAAiAP9OdAPppjU1WwpqjIItkxr+VPQRT8Zm/Riw7U3F6v3OiBFHiHoF +GBYIACwiIQUZNHvJhyRkAl+Z3z7C4AAO2YhIkuH3s+pMlACRWVabVAUCXJH05AIb +DAAAOSQBAP4BOOIR/sGLNMOfeb5fPs/02QMieoiSjIBnijhob2U5AQC+RtOHCHx7 +TcIYl5/Uyoi+FOvPLcNw4hOv2nwUzSSVAw== +=IiS2 +-----END PGP PRIVATE KEY BLOCK----- +` + +const msgv5Test = `-----BEGIN PGP MESSAGE----- + +wcDMA3wvqk35PDeyAQv+PcQiLsoYTH30nJYQh3j3cJaO2+jErtVCrIQRIU0+ +rmgMddERYST4A9mA0DQIiTI4FQ0Lp440D3BWCgpq3LlNWewGzduaWwym5rN6 +cwHz5ccDqOcqbd9X0GXXGy/ZH/ljSgzuVMIytMAXKdF/vrRrVgH/+I7cxvm9 +HwnhjMN5dF0j4aEt996H2T7cbtzSr2GN9SWGW8Gyu7I8Zx73hgrGUI7gDiJB +Afaff+P6hfkkHSGOItr94dde8J/7AUF4VEwwxdVVPvsNEFyvv6gRIbYtOCa2 +6RE6h1V/QTxW2O7zZgzWALrE2ui0oaYr9QuqQSssd9CdgExLfdPbI+3/ZAnE +v31Idzpk3/6ILiakYHtXkElPXvf46mCNpobty8ysT34irF+fy3C1p3oGwAsx +5VDV9OSFU6z5U+UPbSPYAy9rkc5ZssuIKxCER2oTvZ2L8Q5cfUvEUiJtRGGn +CJlHrVDdp3FssKv2tlKgLkvxJLyoOjuEkj44H1qRk+D02FzmmUT/0sAHAYYx +lTir6mjHeLpcGjn4waUuWIAJyph8SxUexP60bic0L0NBa6Qp5SxxijKsPIDb +FPHxWwfJSDZRrgUyYT7089YFB/ZM4FHyH9TZcnxn0f0xIB7NS6YNDsxzN2zT +EVEYf+De4qT/dQTsdww78Chtcv9JY9r2kDm77dk2MUGHL2j7n8jasbLtgA7h +pn2DMIWLrGamMLWRmlwslolKr1sMV5x8w+5Ias6C33iBMl9phkg42an0gYmc +byVJHvLO/XErtC+GNIJeMg== +=liRq +-----END PGP MESSAGE----- +` diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/s2k/s2k.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/s2k/s2k.go index a43695964b..6871b84fc9 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/s2k/s2k.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/s2k/s2k.go @@ -87,10 +87,10 @@ func decodeCount(c uint8) int { // encodeMemory converts the Argon2 "memory" in the range parallelism*8 to // 2**31, inclusive, to an encoded memory. The return value is the // octet that is actually stored in the GPG file. encodeMemory panics -// if is not in the above range +// if is not in the above range // See OpenPGP crypto refresh Section 3.7.1.4. func encodeMemory(memory uint32, parallelism uint8) uint8 { - if memory < (8 * uint32(parallelism)) || memory > uint32(2147483648) { + if memory < (8*uint32(parallelism)) || memory > uint32(2147483648) { panic("Memory argument memory is outside the required range") } @@ -199,8 +199,8 @@ func Generate(rand io.Reader, c *Config) (*Params, error) { } params = &Params{ - mode: SaltedS2K, - hashId: hashId, + mode: SaltedS2K, + hashId: hashId, } } else { // Enforce IteratedSaltedS2K method otherwise hashId, ok := algorithm.HashToHashId(c.hash()) @@ -211,7 +211,7 @@ func Generate(rand io.Reader, c *Config) (*Params, error) { c.S2KMode = IteratedSaltedS2K } params = &Params{ - mode: IteratedSaltedS2K, + mode: IteratedSaltedS2K, hashId: hashId, countByte: c.EncodedCount(), } @@ -283,6 +283,9 @@ func ParseIntoParams(r io.Reader) (params *Params, err error) { params.passes = buf[Argon2SaltSize] params.parallelism = buf[Argon2SaltSize+1] params.memoryExp = buf[Argon2SaltSize+2] + if err := validateArgon2Params(params); err != nil { + return nil, err + } return params, nil case GnuS2K: // This is a GNU extension. See @@ -300,15 +303,22 @@ func ParseIntoParams(r io.Reader) (params *Params, err error) { return nil, errors.UnsupportedError("S2K function") } +func (params *Params) Mode() Mode { + return params.mode +} + func (params *Params) Dummy() bool { return params != nil && params.mode == GnuS2K } func (params *Params) salt() []byte { switch params.mode { - case SaltedS2K, IteratedSaltedS2K: return params.saltBytes[:8] - case Argon2S2K: return params.saltBytes[:Argon2SaltSize] - default: return nil + case SaltedS2K, IteratedSaltedS2K: + return params.saltBytes[:8] + case Argon2S2K: + return params.saltBytes[:Argon2SaltSize] + default: + return nil } } @@ -405,3 +415,22 @@ func Serialize(w io.Writer, key []byte, rand io.Reader, passphrase []byte, c *Co f(key, passphrase) return nil } + +// validateArgon2Params checks that the argon2 parameters are valid according to RFC9580. +func validateArgon2Params(params *Params) error { + // The number of passes t and the degree of parallelism p MUST be non-zero. + if params.parallelism == 0 { + return errors.StructuralError("invalid argon2 params: parallelism is 0") + } + if params.passes == 0 { + return errors.StructuralError("invalid argon2 params: iterations is 0") + } + + // The encoded memory size MUST be a value from 3+ceil(log2(p)) to 31, + // such that the decoded memory size m is a value from 8*p to 2^31. + if params.memoryExp > 31 || decodeMemory(params.memoryExp) < 8*uint32(params.parallelism) { + return errors.StructuralError("invalid argon2 params: memory is out of bounds") + } + + return nil +} diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/s2k/s2k_cache.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/s2k/s2k_cache.go index 25a4442dfb..616e0d12c6 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/s2k/s2k_cache.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/s2k/s2k_cache.go @@ -5,7 +5,7 @@ package s2k // the same parameters. type Cache map[Params][]byte -// GetOrComputeDerivedKey tries to retrieve the key +// GetOrComputeDerivedKey tries to retrieve the key // for the given s2k parameters from the cache. // If there is no hit, it derives the key with the s2k function from the passphrase, // updates the cache, and returns the key. diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/s2k/s2k_config.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/s2k/s2k_config.go index b40be5228f..b93db1ab85 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/s2k/s2k_config.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/s2k/s2k_config.go @@ -50,9 +50,9 @@ type Config struct { type Argon2Config struct { NumberOfPasses uint8 DegreeOfParallelism uint8 - // The memory parameter for Argon2 specifies desired memory usage in kibibytes. + // Memory specifies the desired Argon2 memory usage in kibibytes. // For example memory=64*1024 sets the memory cost to ~64 MB. - Memory uint32 + Memory uint32 } func (c *Config) Mode() Mode { @@ -115,7 +115,7 @@ func (c *Argon2Config) EncodedMemory() uint8 { } memory := c.Memory - lowerBound := uint32(c.Parallelism())*8 + lowerBound := uint32(c.Parallelism()) * 8 upperBound := uint32(2147483648) switch { diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/write.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/write.go index 7fdd13a3dd..b0f6ef7b09 100644 --- a/vendor/github.com/ProtonMail/go-crypto/openpgp/write.go +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/write.go @@ -76,7 +76,11 @@ func detachSign(w io.Writer, signer *Entity, message io.Reader, sigType packet.S sig := createSignaturePacket(signingKey.PublicKey, sigType, config) - h, wrappedHash, err := hashForSignature(sig.Hash, sig.SigType) + h, err := sig.PrepareSign(config) + if err != nil { + return + } + wrappedHash, err := wrapHashForSignature(h, sig.SigType) if err != nil { return } @@ -275,14 +279,28 @@ func writeAndSign(payload io.WriteCloser, candidateHashes []uint8, signed *Entit return nil, errors.InvalidArgumentError("cannot encrypt because no candidate hash functions are compiled in. (Wanted " + name + " in this case.)") } + var salt []byte if signer != nil { + var opsVersion = 3 + if signer.Version == 6 { + opsVersion = signer.Version + } ops := &packet.OnePassSignature{ + Version: opsVersion, SigType: sigType, Hash: hash, PubKeyAlgo: signer.PubKeyAlgo, KeyId: signer.KeyId, IsLast: true, } + if opsVersion == 6 { + ops.KeyFingerprint = signer.Fingerprint + salt, err = packet.SignatureSaltForHash(hash, config.Random()) + if err != nil { + return nil, err + } + ops.Salt = salt + } if err := ops.Serialize(payload); err != nil { return nil, err } @@ -310,19 +328,19 @@ func writeAndSign(payload io.WriteCloser, candidateHashes []uint8, signed *Entit } if signer != nil { - h, wrappedHash, err := hashForSignature(hash, sigType) + h, wrappedHash, err := hashForSignature(hash, sigType, salt) if err != nil { return nil, err } metadata := &packet.LiteralData{ - Format: 't', + Format: 'u', FileName: hints.FileName, Time: epochSeconds, } if hints.IsBinary { metadata.Format = 'b' } - return signatureWriter{payload, literalData, hash, wrappedHash, h, signer, sigType, config, metadata}, nil + return signatureWriter{payload, literalData, hash, wrappedHash, h, salt, signer, sigType, config, metadata}, nil } return literalData, nil } @@ -380,15 +398,19 @@ func encrypt(keyWriter io.Writer, dataWriter io.Writer, to []*Entity, signed *En return nil, errors.InvalidArgumentError("cannot encrypt a message to key id " + strconv.FormatUint(to[i].PrimaryKey.KeyId, 16) + " because it has no valid encryption keys") } - sig := to[i].PrimaryIdentity().SelfSignature - if !sig.SEIPDv2 { + primarySelfSignature, _ := to[i].PrimarySelfSignature() + if primarySelfSignature == nil { + return nil, errors.InvalidArgumentError("entity without a self-signature") + } + + if !primarySelfSignature.SEIPDv2 { aeadSupported = false } - candidateCiphers = intersectPreferences(candidateCiphers, sig.PreferredSymmetric) - candidateHashes = intersectPreferences(candidateHashes, sig.PreferredHash) - candidateCipherSuites = intersectCipherSuites(candidateCipherSuites, sig.PreferredCipherSuites) - candidateCompression = intersectPreferences(candidateCompression, sig.PreferredCompression) + candidateCiphers = intersectPreferences(candidateCiphers, primarySelfSignature.PreferredSymmetric) + candidateHashes = intersectPreferences(candidateHashes, primarySelfSignature.PreferredHash) + candidateCipherSuites = intersectCipherSuites(candidateCipherSuites, primarySelfSignature.PreferredCipherSuites) + candidateCompression = intersectPreferences(candidateCompression, primarySelfSignature.PreferredCompression) } // In the event that the intersection of supported algorithms is empty we use the ones @@ -422,13 +444,19 @@ func encrypt(keyWriter io.Writer, dataWriter io.Writer, to []*Entity, signed *En } } - symKey := make([]byte, cipher.KeySize()) + var symKey []byte + if aeadSupported { + symKey = make([]byte, aeadCipherSuite.Cipher.KeySize()) + } else { + symKey = make([]byte, cipher.KeySize()) + } + if _, err := io.ReadFull(config.Random(), symKey); err != nil { return nil, err } for _, key := range encryptKeys { - if err := packet.SerializeEncryptedKey(keyWriter, key.PublicKey, cipher, symKey, config); err != nil { + if err := packet.SerializeEncryptedKeyAEAD(keyWriter, key.PublicKey, cipher, aeadSupported, symKey, config); err != nil { return nil, err } } @@ -465,13 +493,17 @@ func Sign(output io.Writer, signed *Entity, hints *FileHints, config *packet.Con hashToHashId(crypto.SHA3_512), } defaultHashes := candidateHashes[0:1] - preferredHashes := signed.PrimaryIdentity().SelfSignature.PreferredHash + primarySelfSignature, _ := signed.PrimarySelfSignature() + if primarySelfSignature == nil { + return nil, errors.StructuralError("signed entity has no self-signature") + } + preferredHashes := primarySelfSignature.PreferredHash if len(preferredHashes) == 0 { preferredHashes = defaultHashes } candidateHashes = intersectPreferences(candidateHashes, preferredHashes) if len(candidateHashes) == 0 { - return nil, errors.InvalidArgumentError("cannot sign because signing key shares no common algorithms with candidate hashes") + return nil, errors.StructuralError("cannot sign because signing key shares no common algorithms with candidate hashes") } return writeAndSign(noOpCloser{output}, candidateHashes, signed, hints, packet.SigTypeBinary, config) @@ -486,6 +518,7 @@ type signatureWriter struct { hashType crypto.Hash wrappedHash hash.Hash h hash.Hash + salt []byte // v6 only signer *packet.PrivateKey sigType packet.SignatureType config *packet.Config @@ -509,6 +542,10 @@ func (s signatureWriter) Close() error { sig.Hash = s.hashType sig.Metadata = s.metadata + if err := sig.SetSalt(s.salt); err != nil { + return err + } + if err := sig.Sign(s.h, s.signer, s.config); err != nil { return err } diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/x25519/x25519.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/x25519/x25519.go new file mode 100644 index 0000000000..38afcc74fa --- /dev/null +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/x25519/x25519.go @@ -0,0 +1,221 @@ +package x25519 + +import ( + "crypto/sha256" + "crypto/subtle" + "io" + + "github.com/ProtonMail/go-crypto/openpgp/aes/keywrap" + "github.com/ProtonMail/go-crypto/openpgp/errors" + x25519lib "github.com/cloudflare/circl/dh/x25519" + "golang.org/x/crypto/hkdf" +) + +const ( + hkdfInfo = "OpenPGP X25519" + aes128KeySize = 16 + // The size of a public or private key in bytes. + KeySize = x25519lib.Size +) + +type PublicKey struct { + // Point represents the encoded elliptic curve point of the public key. + Point []byte +} + +type PrivateKey struct { + PublicKey + // Secret represents the secret of the private key. + Secret []byte +} + +// NewPrivateKey creates a new empty private key including the public key. +func NewPrivateKey(key PublicKey) *PrivateKey { + return &PrivateKey{ + PublicKey: key, + } +} + +// Validate validates that the provided public key matches the private key. +func Validate(pk *PrivateKey) (err error) { + var expectedPublicKey, privateKey x25519lib.Key + subtle.ConstantTimeCopy(1, privateKey[:], pk.Secret) + x25519lib.KeyGen(&expectedPublicKey, &privateKey) + if subtle.ConstantTimeCompare(expectedPublicKey[:], pk.PublicKey.Point) == 0 { + return errors.KeyInvalidError("x25519: invalid key") + } + return nil +} + +// GenerateKey generates a new x25519 key pair. +func GenerateKey(rand io.Reader) (*PrivateKey, error) { + var privateKey, publicKey x25519lib.Key + privateKeyOut := new(PrivateKey) + err := generateKey(rand, &privateKey, &publicKey) + if err != nil { + return nil, err + } + privateKeyOut.PublicKey.Point = publicKey[:] + privateKeyOut.Secret = privateKey[:] + return privateKeyOut, nil +} + +func generateKey(rand io.Reader, privateKey *x25519lib.Key, publicKey *x25519lib.Key) error { + maxRounds := 10 + isZero := true + for round := 0; isZero; round++ { + if round == maxRounds { + return errors.InvalidArgumentError("x25519: zero keys only, randomness source might be corrupt") + } + _, err := io.ReadFull(rand, privateKey[:]) + if err != nil { + return err + } + isZero = constantTimeIsZero(privateKey[:]) + } + x25519lib.KeyGen(publicKey, privateKey) + return nil +} + +// Encrypt encrypts a sessionKey with x25519 according to +// the OpenPGP crypto refresh specification section 5.1.6. The function assumes that the +// sessionKey has the correct format and padding according to the specification. +func Encrypt(rand io.Reader, publicKey *PublicKey, sessionKey []byte) (ephemeralPublicKey *PublicKey, encryptedSessionKey []byte, err error) { + var ephemeralPrivate, ephemeralPublic, staticPublic, shared x25519lib.Key + // Check that the input static public key has 32 bytes + if len(publicKey.Point) != KeySize { + err = errors.KeyInvalidError("x25519: the public key has the wrong size") + return + } + copy(staticPublic[:], publicKey.Point) + // Generate ephemeral keyPair + err = generateKey(rand, &ephemeralPrivate, &ephemeralPublic) + if err != nil { + return + } + // Compute shared key + ok := x25519lib.Shared(&shared, &ephemeralPrivate, &staticPublic) + if !ok { + err = errors.KeyInvalidError("x25519: the public key is a low order point") + return + } + // Derive the encryption key from the shared secret + encryptionKey := applyHKDF(ephemeralPublic[:], publicKey.Point[:], shared[:]) + ephemeralPublicKey = &PublicKey{ + Point: ephemeralPublic[:], + } + // Encrypt the sessionKey with aes key wrapping + encryptedSessionKey, err = keywrap.Wrap(encryptionKey, sessionKey) + return +} + +// Decrypt decrypts a session key stored in ciphertext with the provided x25519 +// private key and ephemeral public key. +func Decrypt(privateKey *PrivateKey, ephemeralPublicKey *PublicKey, ciphertext []byte) (encodedSessionKey []byte, err error) { + var ephemeralPublic, staticPrivate, shared x25519lib.Key + // Check that the input ephemeral public key has 32 bytes + if len(ephemeralPublicKey.Point) != KeySize { + err = errors.KeyInvalidError("x25519: the public key has the wrong size") + return + } + copy(ephemeralPublic[:], ephemeralPublicKey.Point) + subtle.ConstantTimeCopy(1, staticPrivate[:], privateKey.Secret) + // Compute shared key + ok := x25519lib.Shared(&shared, &staticPrivate, &ephemeralPublic) + if !ok { + err = errors.KeyInvalidError("x25519: the ephemeral public key is a low order point") + return + } + // Derive the encryption key from the shared secret + encryptionKey := applyHKDF(ephemeralPublicKey.Point[:], privateKey.PublicKey.Point[:], shared[:]) + // Decrypt the session key with aes key wrapping + encodedSessionKey, err = keywrap.Unwrap(encryptionKey, ciphertext) + return +} + +func applyHKDF(ephemeralPublicKey []byte, publicKey []byte, sharedSecret []byte) []byte { + inputKey := make([]byte, 3*KeySize) + // ephemeral public key | recipient public key | shared secret + subtle.ConstantTimeCopy(1, inputKey[:KeySize], ephemeralPublicKey) + subtle.ConstantTimeCopy(1, inputKey[KeySize:2*KeySize], publicKey) + subtle.ConstantTimeCopy(1, inputKey[2*KeySize:], sharedSecret) + hkdfReader := hkdf.New(sha256.New, inputKey, []byte{}, []byte(hkdfInfo)) + encryptionKey := make([]byte, aes128KeySize) + _, _ = io.ReadFull(hkdfReader, encryptionKey) + return encryptionKey +} + +func constantTimeIsZero(bytes []byte) bool { + isZero := byte(0) + for _, b := range bytes { + isZero |= b + } + return isZero == 0 +} + +// ENCODING/DECODING ciphertexts: + +// EncodeFieldsLength returns the length of the ciphertext encoding +// given the encrypted session key. +func EncodedFieldsLength(encryptedSessionKey []byte, v6 bool) int { + lenCipherFunction := 0 + if !v6 { + lenCipherFunction = 1 + } + return KeySize + 1 + len(encryptedSessionKey) + lenCipherFunction +} + +// EncodeField encodes x25519 session key encryption fields as +// ephemeral x25519 public key | follow byte length | cipherFunction (v3 only) | encryptedSessionKey +// and writes it to writer. +func EncodeFields(writer io.Writer, ephemeralPublicKey *PublicKey, encryptedSessionKey []byte, cipherFunction byte, v6 bool) (err error) { + lenAlgorithm := 0 + if !v6 { + lenAlgorithm = 1 + } + if _, err = writer.Write(ephemeralPublicKey.Point); err != nil { + return err + } + if _, err = writer.Write([]byte{byte(len(encryptedSessionKey) + lenAlgorithm)}); err != nil { + return err + } + if !v6 { + if _, err = writer.Write([]byte{cipherFunction}); err != nil { + return err + } + } + _, err = writer.Write(encryptedSessionKey) + return err +} + +// DecodeField decodes a x25519 session key encryption as +// ephemeral x25519 public key | follow byte length | cipherFunction (v3 only) | encryptedSessionKey. +func DecodeFields(reader io.Reader, v6 bool) (ephemeralPublicKey *PublicKey, encryptedSessionKey []byte, cipherFunction byte, err error) { + var buf [1]byte + ephemeralPublicKey = &PublicKey{ + Point: make([]byte, KeySize), + } + // 32 octets representing an ephemeral x25519 public key. + if _, err = io.ReadFull(reader, ephemeralPublicKey.Point); err != nil { + return nil, nil, 0, err + } + // A one-octet size of the following fields. + if _, err = io.ReadFull(reader, buf[:]); err != nil { + return nil, nil, 0, err + } + followingLen := buf[0] + // The one-octet algorithm identifier, if it was passed (in the case of a v3 PKESK packet). + if !v6 { + if _, err = io.ReadFull(reader, buf[:]); err != nil { + return nil, nil, 0, err + } + cipherFunction = buf[0] + followingLen -= 1 + } + // The encrypted session key. + encryptedSessionKey = make([]byte, followingLen) + if _, err = io.ReadFull(reader, encryptedSessionKey); err != nil { + return nil, nil, 0, err + } + return ephemeralPublicKey, encryptedSessionKey, cipherFunction, nil +} diff --git a/vendor/github.com/ProtonMail/go-crypto/openpgp/x448/x448.go b/vendor/github.com/ProtonMail/go-crypto/openpgp/x448/x448.go new file mode 100644 index 0000000000..65a082dabd --- /dev/null +++ b/vendor/github.com/ProtonMail/go-crypto/openpgp/x448/x448.go @@ -0,0 +1,229 @@ +package x448 + +import ( + "crypto/sha512" + "crypto/subtle" + "io" + + "github.com/ProtonMail/go-crypto/openpgp/aes/keywrap" + "github.com/ProtonMail/go-crypto/openpgp/errors" + x448lib "github.com/cloudflare/circl/dh/x448" + "golang.org/x/crypto/hkdf" +) + +const ( + hkdfInfo = "OpenPGP X448" + aes256KeySize = 32 + // The size of a public or private key in bytes. + KeySize = x448lib.Size +) + +type PublicKey struct { + // Point represents the encoded elliptic curve point of the public key. + Point []byte +} + +type PrivateKey struct { + PublicKey + // Secret represents the secret of the private key. + Secret []byte +} + +// NewPrivateKey creates a new empty private key including the public key. +func NewPrivateKey(key PublicKey) *PrivateKey { + return &PrivateKey{ + PublicKey: key, + } +} + +// Validate validates that the provided public key matches +// the private key. +func Validate(pk *PrivateKey) (err error) { + var expectedPublicKey, privateKey x448lib.Key + subtle.ConstantTimeCopy(1, privateKey[:], pk.Secret) + x448lib.KeyGen(&expectedPublicKey, &privateKey) + if subtle.ConstantTimeCompare(expectedPublicKey[:], pk.PublicKey.Point) == 0 { + return errors.KeyInvalidError("x448: invalid key") + } + return nil +} + +// GenerateKey generates a new x448 key pair. +func GenerateKey(rand io.Reader) (*PrivateKey, error) { + var privateKey, publicKey x448lib.Key + privateKeyOut := new(PrivateKey) + err := generateKey(rand, &privateKey, &publicKey) + if err != nil { + return nil, err + } + privateKeyOut.PublicKey.Point = publicKey[:] + privateKeyOut.Secret = privateKey[:] + return privateKeyOut, nil +} + +func generateKey(rand io.Reader, privateKey *x448lib.Key, publicKey *x448lib.Key) error { + maxRounds := 10 + isZero := true + for round := 0; isZero; round++ { + if round == maxRounds { + return errors.InvalidArgumentError("x448: zero keys only, randomness source might be corrupt") + } + _, err := io.ReadFull(rand, privateKey[:]) + if err != nil { + return err + } + isZero = constantTimeIsZero(privateKey[:]) + } + x448lib.KeyGen(publicKey, privateKey) + return nil +} + +// Encrypt encrypts a sessionKey with x448 according to +// the OpenPGP crypto refresh specification section 5.1.7. The function assumes that the +// sessionKey has the correct format and padding according to the specification. +func Encrypt(rand io.Reader, publicKey *PublicKey, sessionKey []byte) (ephemeralPublicKey *PublicKey, encryptedSessionKey []byte, err error) { + var ephemeralPrivate, ephemeralPublic, staticPublic, shared x448lib.Key + // Check that the input static public key has 56 bytes. + if len(publicKey.Point) != KeySize { + err = errors.KeyInvalidError("x448: the public key has the wrong size") + return nil, nil, err + } + copy(staticPublic[:], publicKey.Point) + // Generate ephemeral keyPair. + if err = generateKey(rand, &ephemeralPrivate, &ephemeralPublic); err != nil { + return nil, nil, err + } + // Compute shared key. + ok := x448lib.Shared(&shared, &ephemeralPrivate, &staticPublic) + if !ok { + err = errors.KeyInvalidError("x448: the public key is a low order point") + return nil, nil, err + } + // Derive the encryption key from the shared secret. + encryptionKey := applyHKDF(ephemeralPublic[:], publicKey.Point[:], shared[:]) + ephemeralPublicKey = &PublicKey{ + Point: ephemeralPublic[:], + } + // Encrypt the sessionKey with aes key wrapping. + encryptedSessionKey, err = keywrap.Wrap(encryptionKey, sessionKey) + if err != nil { + return nil, nil, err + } + return ephemeralPublicKey, encryptedSessionKey, nil +} + +// Decrypt decrypts a session key stored in ciphertext with the provided x448 +// private key and ephemeral public key. +func Decrypt(privateKey *PrivateKey, ephemeralPublicKey *PublicKey, ciphertext []byte) (encodedSessionKey []byte, err error) { + var ephemeralPublic, staticPrivate, shared x448lib.Key + // Check that the input ephemeral public key has 56 bytes. + if len(ephemeralPublicKey.Point) != KeySize { + err = errors.KeyInvalidError("x448: the public key has the wrong size") + return nil, err + } + copy(ephemeralPublic[:], ephemeralPublicKey.Point) + subtle.ConstantTimeCopy(1, staticPrivate[:], privateKey.Secret) + // Compute shared key. + ok := x448lib.Shared(&shared, &staticPrivate, &ephemeralPublic) + if !ok { + err = errors.KeyInvalidError("x448: the ephemeral public key is a low order point") + return nil, err + } + // Derive the encryption key from the shared secret. + encryptionKey := applyHKDF(ephemeralPublicKey.Point[:], privateKey.PublicKey.Point[:], shared[:]) + // Decrypt the session key with aes key wrapping. + encodedSessionKey, err = keywrap.Unwrap(encryptionKey, ciphertext) + if err != nil { + return nil, err + } + return encodedSessionKey, nil +} + +func applyHKDF(ephemeralPublicKey []byte, publicKey []byte, sharedSecret []byte) []byte { + inputKey := make([]byte, 3*KeySize) + // ephemeral public key | recipient public key | shared secret. + subtle.ConstantTimeCopy(1, inputKey[:KeySize], ephemeralPublicKey) + subtle.ConstantTimeCopy(1, inputKey[KeySize:2*KeySize], publicKey) + subtle.ConstantTimeCopy(1, inputKey[2*KeySize:], sharedSecret) + hkdfReader := hkdf.New(sha512.New, inputKey, []byte{}, []byte(hkdfInfo)) + encryptionKey := make([]byte, aes256KeySize) + _, _ = io.ReadFull(hkdfReader, encryptionKey) + return encryptionKey +} + +func constantTimeIsZero(bytes []byte) bool { + isZero := byte(0) + for _, b := range bytes { + isZero |= b + } + return isZero == 0 +} + +// ENCODING/DECODING ciphertexts: + +// EncodeFieldsLength returns the length of the ciphertext encoding +// given the encrypted session key. +func EncodedFieldsLength(encryptedSessionKey []byte, v6 bool) int { + lenCipherFunction := 0 + if !v6 { + lenCipherFunction = 1 + } + return KeySize + 1 + len(encryptedSessionKey) + lenCipherFunction +} + +// EncodeField encodes x448 session key encryption fields as +// ephemeral x448 public key | follow byte length | cipherFunction (v3 only) | encryptedSessionKey +// and writes it to writer. +func EncodeFields(writer io.Writer, ephemeralPublicKey *PublicKey, encryptedSessionKey []byte, cipherFunction byte, v6 bool) (err error) { + lenAlgorithm := 0 + if !v6 { + lenAlgorithm = 1 + } + if _, err = writer.Write(ephemeralPublicKey.Point); err != nil { + return err + } + if _, err = writer.Write([]byte{byte(len(encryptedSessionKey) + lenAlgorithm)}); err != nil { + return err + } + if !v6 { + if _, err = writer.Write([]byte{cipherFunction}); err != nil { + return err + } + } + if _, err = writer.Write(encryptedSessionKey); err != nil { + return err + } + return nil +} + +// DecodeField decodes a x448 session key encryption as +// ephemeral x448 public key | follow byte length | cipherFunction (v3 only) | encryptedSessionKey. +func DecodeFields(reader io.Reader, v6 bool) (ephemeralPublicKey *PublicKey, encryptedSessionKey []byte, cipherFunction byte, err error) { + var buf [1]byte + ephemeralPublicKey = &PublicKey{ + Point: make([]byte, KeySize), + } + // 56 octets representing an ephemeral x448 public key. + if _, err = io.ReadFull(reader, ephemeralPublicKey.Point); err != nil { + return nil, nil, 0, err + } + // A one-octet size of the following fields. + if _, err = io.ReadFull(reader, buf[:]); err != nil { + return nil, nil, 0, err + } + followingLen := buf[0] + // The one-octet algorithm identifier, if it was passed (in the case of a v3 PKESK packet). + if !v6 { + if _, err = io.ReadFull(reader, buf[:]); err != nil { + return nil, nil, 0, err + } + cipherFunction = buf[0] + followingLen -= 1 + } + // The encrypted session key. + encryptedSessionKey = make([]byte, followingLen) + if _, err = io.ReadFull(reader, encryptedSessionKey); err != nil { + return nil, nil, 0, err + } + return ephemeralPublicKey, encryptedSessionKey, cipherFunction, nil +} diff --git a/vendor/github.com/cyphar/filepath-securejoin/CHANGELOG.md b/vendor/github.com/cyphar/filepath-securejoin/CHANGELOG.md new file mode 100644 index 0000000000..cb1252b53e --- /dev/null +++ b/vendor/github.com/cyphar/filepath-securejoin/CHANGELOG.md @@ -0,0 +1,209 @@ +# Changelog # +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](http://keepachangelog.com/) +and this project adheres to [Semantic Versioning](http://semver.org/). + +## [Unreleased] ## + +## [0.3.6] - 2024-12-17 ## + +### Compatibility ### +- The minimum Go version requirement for `filepath-securejoin` is now Go 1.18 + (we use generics internally). + + For reference, `filepath-securejoin@v0.3.0` somewhat-arbitrarily bumped the + Go version requirement to 1.21. + + While we did make some use of Go 1.21 stdlib features (and in principle Go + versions <= 1.21 are no longer even supported by upstream anymore), some + downstreams have complained that the version bump has meant that they have to + do workarounds when backporting fixes that use the new `filepath-securejoin` + API onto old branches. This is not an ideal situation, but since using this + library is probably better for most downstreams than a hand-rolled + workaround, we now have compatibility shims that allow us to build on older + Go versions. +- Lower minimum version requirement for `golang.org/x/sys` to `v0.18.0` (we + need the wrappers for `fsconfig(2)`), which should also make backporting + patches to older branches easier. + +## [0.3.5] - 2024-12-06 ## + +### Fixed ### +- `MkdirAll` will now no longer return an `EEXIST` error if two racing + processes are creating the same directory. We will still verify that the path + is a directory, but this will avoid spurious errors when multiple threads or + programs are trying to `MkdirAll` the same path. opencontainers/runc#4543 + +## [0.3.4] - 2024-10-09 ## + +### Fixed ### +- Previously, some testing mocks we had resulted in us doing `import "testing"` + in non-`_test.go` code, which made some downstreams like Kubernetes unhappy. + This has been fixed. (#32) + +## [0.3.3] - 2024-09-30 ## + +### Fixed ### +- The mode and owner verification logic in `MkdirAll` has been removed. This + was originally intended to protect against some theoretical attacks but upon + further consideration these protections don't actually buy us anything and + they were causing spurious errors with more complicated filesystem setups. +- The "is the created directory empty" logic in `MkdirAll` has also been + removed. This was not causing us issues yet, but some pseudofilesystems (such + as `cgroup`) create non-empty directories and so this logic would've been + wrong for such cases. + +## [0.3.2] - 2024-09-13 ## + +### Changed ### +- Passing the `S_ISUID` or `S_ISGID` modes to `MkdirAllInRoot` will now return + an explicit error saying that those bits are ignored by `mkdirat(2)`. In the + past a different error was returned, but since the silent ignoring behaviour + is codified in the man pages a more explicit error seems apt. While silently + ignoring these bits would be the most compatible option, it could lead to + users thinking their code sets these bits when it doesn't. Programs that need + to deal with compatibility can mask the bits themselves. (#23, #25) + +### Fixed ### +- If a directory has `S_ISGID` set, then all child directories will have + `S_ISGID` set when created and a different gid will be used for any inode + created under the directory. Previously, the "expected owner and mode" + validation in `securejoin.MkdirAll` did not correctly handle this. We now + correctly handle this case. (#24, #25) + +## [0.3.1] - 2024-07-23 ## + +### Changed ### +- By allowing `Open(at)InRoot` to opt-out of the extra work done by `MkdirAll` + to do the necessary "partial lookups", `Open(at)InRoot` now does less work + for both implementations (resulting in a many-fold decrease in the number of + operations for `openat2`, and a modest improvement for non-`openat2`) and is + far more guaranteed to match the correct `openat2(RESOLVE_IN_ROOT)` + behaviour. +- We now use `readlinkat(fd, "")` where possible. For `Open(at)InRoot` this + effectively just means that we no longer risk getting spurious errors during + rename races. However, for our hardened procfs handler, this in theory should + prevent mount attacks from tricking us when doing magic-link readlinks (even + when using the unsafe host `/proc` handle). Unfortunately `Reopen` is still + potentially vulnerable to those kinds of somewhat-esoteric attacks. + + Technically this [will only work on post-2.6.39 kernels][linux-readlinkat-emptypath] + but it seems incredibly unlikely anyone is using `filepath-securejoin` on a + pre-2011 kernel. + +### Fixed ### +- Several improvements were made to the errors returned by `Open(at)InRoot` and + `MkdirAll` when dealing with invalid paths under the emulated (ie. + non-`openat2`) implementation. Previously, some paths would return the wrong + error (`ENOENT` when the last component was a non-directory), and other paths + would be returned as though they were acceptable (trailing-slash components + after a non-directory would be ignored by `Open(at)InRoot`). + + These changes were done to match `openat2`'s behaviour and purely is a + consistency fix (most users are going to be using `openat2` anyway). + +[linux-readlinkat-emptypath]: https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/commit/?id=65cfc6722361570bfe255698d9cd4dccaf47570d + +## [0.3.0] - 2024-07-11 ## + +### Added ### +- A new set of `*os.File`-based APIs have been added. These are adapted from + [libpathrs][] and we strongly suggest using them if possible (as they provide + far more protection against attacks than `SecureJoin`): + + - `Open(at)InRoot` resolves a path inside a rootfs and returns an `*os.File` + handle to the path. Note that the handle returned is an `O_PATH` handle, + which cannot be used for reading or writing (as well as some other + operations -- [see open(2) for more details][open.2]) + + - `Reopen` takes an `O_PATH` file handle and safely re-opens it to upgrade + it to a regular handle. This can also be used with non-`O_PATH` handles, + but `O_PATH` is the most obvious application. + + - `MkdirAll` is an implementation of `os.MkdirAll` that is safe to use to + create a directory tree within a rootfs. + + As these are new APIs, they may change in the future. However, they should be + safe to start migrating to as we have extensive tests ensuring they behave + correctly and are safe against various races and other attacks. + +[libpathrs]: https://github.com/openSUSE/libpathrs +[open.2]: https://www.man7.org/linux/man-pages/man2/open.2.html + +## [0.2.5] - 2024-05-03 ## + +### Changed ### +- Some minor changes were made to how lexical components (like `..` and `.`) + are handled during path generation in `SecureJoin`. There is no behaviour + change as a result of this fix (the resulting paths are the same). + +### Fixed ### +- The error returned when we hit a symlink loop now references the correct + path. (#10) + +## [0.2.4] - 2023-09-06 ## + +### Security ### +- This release fixes a potential security issue in filepath-securejoin when + used on Windows ([GHSA-6xv5-86q9-7xr8][], which could be used to generate + paths outside of the provided rootfs in certain cases), as well as improving + the overall behaviour of filepath-securejoin when dealing with Windows paths + that contain volume names. Thanks to Paulo Gomes for discovering and fixing + these issues. + +### Fixed ### +- Switch to GitHub Actions for CI so we can test on Windows as well as Linux + and MacOS. + +[GHSA-6xv5-86q9-7xr8]: https://github.com/advisories/GHSA-6xv5-86q9-7xr8 + +## [0.2.3] - 2021-06-04 ## + +### Changed ### +- Switch to Go 1.13-style `%w` error wrapping, letting us drop the dependency + on `github.com/pkg/errors`. + +## [0.2.2] - 2018-09-05 ## + +### Changed ### +- Use `syscall.ELOOP` as the base error for symlink loops, rather than our own + (internal) error. This allows callers to more easily use `errors.Is` to check + for this case. + +## [0.2.1] - 2018-09-05 ## + +### Fixed ### +- Use our own `IsNotExist` implementation, which lets us handle `ENOTDIR` + properly within `SecureJoin`. + +## [0.2.0] - 2017-07-19 ## + +We now have 100% test coverage! + +### Added ### +- Add a `SecureJoinVFS` API that can be used for mocking (as we do in our new + tests) or for implementing custom handling of lookup operations (such as for + rootless containers, where work is necessary to access directories with weird + modes because we don't have `CAP_DAC_READ_SEARCH` or `CAP_DAC_OVERRIDE`). + +## 0.1.0 - 2017-07-19 + +This is our first release of `github.com/cyphar/filepath-securejoin`, +containing a full implementation with a coverage of 93.5% (the only missing +cases are the error cases, which are hard to mocktest at the moment). + +[Unreleased]: https://github.com/cyphar/filepath-securejoin/compare/v0.3.6...HEAD +[0.3.6]: https://github.com/cyphar/filepath-securejoin/compare/v0.3.5...v0.3.6 +[0.3.5]: https://github.com/cyphar/filepath-securejoin/compare/v0.3.4...v0.3.5 +[0.3.4]: https://github.com/cyphar/filepath-securejoin/compare/v0.3.3...v0.3.4 +[0.3.3]: https://github.com/cyphar/filepath-securejoin/compare/v0.3.2...v0.3.3 +[0.3.2]: https://github.com/cyphar/filepath-securejoin/compare/v0.3.1...v0.3.2 +[0.3.1]: https://github.com/cyphar/filepath-securejoin/compare/v0.3.0...v0.3.1 +[0.3.0]: https://github.com/cyphar/filepath-securejoin/compare/v0.2.5...v0.3.0 +[0.2.5]: https://github.com/cyphar/filepath-securejoin/compare/v0.2.4...v0.2.5 +[0.2.4]: https://github.com/cyphar/filepath-securejoin/compare/v0.2.3...v0.2.4 +[0.2.3]: https://github.com/cyphar/filepath-securejoin/compare/v0.2.2...v0.2.3 +[0.2.2]: https://github.com/cyphar/filepath-securejoin/compare/v0.2.1...v0.2.2 +[0.2.1]: https://github.com/cyphar/filepath-securejoin/compare/v0.2.0...v0.2.1 +[0.2.0]: https://github.com/cyphar/filepath-securejoin/compare/v0.1.0...v0.2.0 diff --git a/vendor/github.com/cyphar/filepath-securejoin/LICENSE b/vendor/github.com/cyphar/filepath-securejoin/LICENSE index bec842f294..cb1ab88da0 100644 --- a/vendor/github.com/cyphar/filepath-securejoin/LICENSE +++ b/vendor/github.com/cyphar/filepath-securejoin/LICENSE @@ -1,5 +1,5 @@ Copyright (C) 2014-2015 Docker Inc & Go Authors. All rights reserved. -Copyright (C) 2017 SUSE LLC. All rights reserved. +Copyright (C) 2017-2024 SUSE LLC. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/vendor/github.com/cyphar/filepath-securejoin/README.md b/vendor/github.com/cyphar/filepath-securejoin/README.md index 4eca0f2355..eaeb53fcd0 100644 --- a/vendor/github.com/cyphar/filepath-securejoin/README.md +++ b/vendor/github.com/cyphar/filepath-securejoin/README.md @@ -1,32 +1,26 @@ ## `filepath-securejoin` ## +[![Go Documentation](https://pkg.go.dev/badge/github.com/cyphar/filepath-securejoin.svg)](https://pkg.go.dev/github.com/cyphar/filepath-securejoin) [![Build Status](https://github.com/cyphar/filepath-securejoin/actions/workflows/ci.yml/badge.svg)](https://github.com/cyphar/filepath-securejoin/actions/workflows/ci.yml) -An implementation of `SecureJoin`, a [candidate for inclusion in the Go -standard library][go#20126]. The purpose of this function is to be a "secure" -alternative to `filepath.Join`, and in particular it provides certain -guarantees that are not provided by `filepath.Join`. - -> **NOTE**: This code is *only* safe if you are not at risk of other processes -> modifying path components after you've used `SecureJoin`. If it is possible -> for a malicious process to modify path components of the resolved path, then -> you will be vulnerable to some fairly trivial TOCTOU race conditions. [There -> are some Linux kernel patches I'm working on which might allow for a better -> solution.][lwn-obeneath] -> -> In addition, with a slightly modified API it might be possible to use -> `O_PATH` and verify that the opened path is actually the resolved one -- but -> I have not done that yet. I might add it in the future as a helper function -> to help users verify the path (we can't just return `/proc/self/fd/` -> because that doesn't always work transparently for all users). - -This is the function prototype: +### Old API ### -```go -func SecureJoin(root, unsafePath string) (string, error) -``` +This library was originally just an implementation of `SecureJoin` which was +[intended to be included in the Go standard library][go#20126] as a safer +`filepath.Join` that would restrict the path lookup to be inside a root +directory. + +The implementation was based on code that existed in several container +runtimes. Unfortunately, this API is **fundamentally unsafe** against attackers +that can modify path components after `SecureJoin` returns and before the +caller uses the path, allowing for some fairly trivial TOCTOU attacks. + +`SecureJoin` (and `SecureJoinVFS`) are still provided by this library to +support legacy users, but new users are strongly suggested to avoid using +`SecureJoin` and instead use the [new api](#new-api) or switch to +[libpathrs][libpathrs]. -This library **guarantees** the following: +With the above limitations in mind, this library guarantees the following: * If no error is set, the resulting string **must** be a child path of `root` and will not contain any symlink path components (they will all be @@ -47,7 +41,7 @@ This library **guarantees** the following: A (trivial) implementation of this function on GNU/Linux systems could be done with the following (note that this requires root privileges and is far more opaque than the implementation in this library, and also requires that -`readlink` is inside the `root` path): +`readlink` is inside the `root` path and is trustworthy): ```go package securejoin @@ -70,9 +64,105 @@ func SecureJoin(root, unsafePath string) (string, error) { } ``` -[lwn-obeneath]: https://lwn.net/Articles/767547/ +[libpathrs]: https://github.com/openSUSE/libpathrs [go#20126]: https://github.com/golang/go/issues/20126 +### New API ### + +While we recommend users switch to [libpathrs][libpathrs] as soon as it has a +stable release, some methods implemented by libpathrs have been ported to this +library to ease the transition. These APIs are only supported on Linux. + +These APIs are implemented such that `filepath-securejoin` will +opportunistically use certain newer kernel APIs that make these operations far +more secure. In particular: + +* All of the lookup operations will use [`openat2`][openat2.2] on new enough + kernels (Linux 5.6 or later) to restrict lookups through magic-links and + bind-mounts (for certain operations) and to make use of `RESOLVE_IN_ROOT` to + efficiently resolve symlinks within a rootfs. + +* The APIs provide hardening against a malicious `/proc` mount to either detect + or avoid being tricked by a `/proc` that is not legitimate. This is done + using [`openat2`][openat2.2] for all users, and privileged users will also be + further protected by using [`fsopen`][fsopen.2] and [`open_tree`][open_tree.2] + (Linux 5.2 or later). + +[openat2.2]: https://www.man7.org/linux/man-pages/man2/openat2.2.html +[fsopen.2]: https://github.com/brauner/man-pages-md/blob/main/fsopen.md +[open_tree.2]: https://github.com/brauner/man-pages-md/blob/main/open_tree.md + +#### `OpenInRoot` #### + +```go +func OpenInRoot(root, unsafePath string) (*os.File, error) +func OpenatInRoot(root *os.File, unsafePath string) (*os.File, error) +func Reopen(handle *os.File, flags int) (*os.File, error) +``` + +`OpenInRoot` is a much safer version of + +```go +path, err := securejoin.SecureJoin(root, unsafePath) +file, err := os.OpenFile(path, unix.O_PATH|unix.O_CLOEXEC) +``` + +that protects against various race attacks that could lead to serious security +issues, depending on the application. Note that the returned `*os.File` is an +`O_PATH` file descriptor, which is quite restricted. Callers will probably need +to use `Reopen` to get a more usable handle (this split is done to provide +useful features like PTY spawning and to avoid users accidentally opening bad +inodes that could cause a DoS). + +Callers need to be careful in how they use the returned `*os.File`. Usually it +is only safe to operate on the handle directly, and it is very easy to create a +security issue. [libpathrs][libpathrs] provides far more helpers to make using +these handles safer -- there is currently no plan to port them to +`filepath-securejoin`. + +`OpenatInRoot` is like `OpenInRoot` except that the root is provided using an +`*os.File`. This allows you to ensure that multiple `OpenatInRoot` (or +`MkdirAllHandle`) calls are operating on the same rootfs. + +> **NOTE**: Unlike `SecureJoin`, `OpenInRoot` will error out as soon as it hits +> a dangling symlink or non-existent path. This is in contrast to `SecureJoin` +> which treated non-existent components as though they were real directories, +> and would allow for partial resolution of dangling symlinks. These behaviours +> are at odds with how Linux treats non-existent paths and dangling symlinks, +> and so these are no longer allowed. + +#### `MkdirAll` #### + +```go +func MkdirAll(root, unsafePath string, mode int) error +func MkdirAllHandle(root *os.File, unsafePath string, mode int) (*os.File, error) +``` + +`MkdirAll` is a much safer version of + +```go +path, err := securejoin.SecureJoin(root, unsafePath) +err = os.MkdirAll(path, mode) +``` + +that protects against the same kinds of races that `OpenInRoot` protects +against. + +`MkdirAllHandle` is like `MkdirAll` except that the root is provided using an +`*os.File` (the reason for this is the same as with `OpenatInRoot`) and an +`*os.File` of the final created directory is returned (this directory is +guaranteed to be effectively identical to the directory created by +`MkdirAllHandle`, which is not possible to ensure by just using `OpenatInRoot` +after `MkdirAll`). + +> **NOTE**: Unlike `SecureJoin`, `MkdirAll` will error out as soon as it hits +> a dangling symlink or non-existent path. This is in contrast to `SecureJoin` +> which treated non-existent components as though they were real directories, +> and would allow for partial resolution of dangling symlinks. These behaviours +> are at odds with how Linux treats non-existent paths and dangling symlinks, +> and so these are no longer allowed. This means that `MkdirAll` will not +> create non-existent directories referenced by a dangling symlink. + ### License ### The license of this project is the same as Go, which is a BSD 3-clause license diff --git a/vendor/github.com/cyphar/filepath-securejoin/VERSION b/vendor/github.com/cyphar/filepath-securejoin/VERSION index abd410582d..449d7e73a9 100644 --- a/vendor/github.com/cyphar/filepath-securejoin/VERSION +++ b/vendor/github.com/cyphar/filepath-securejoin/VERSION @@ -1 +1 @@ -0.2.4 +0.3.6 diff --git a/vendor/github.com/cyphar/filepath-securejoin/doc.go b/vendor/github.com/cyphar/filepath-securejoin/doc.go new file mode 100644 index 0000000000..1ec7d065ef --- /dev/null +++ b/vendor/github.com/cyphar/filepath-securejoin/doc.go @@ -0,0 +1,39 @@ +// Copyright (C) 2014-2015 Docker Inc & Go Authors. All rights reserved. +// Copyright (C) 2017-2024 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package securejoin implements a set of helpers to make it easier to write Go +// code that is safe against symlink-related escape attacks. The primary idea +// is to let you resolve a path within a rootfs directory as if the rootfs was +// a chroot. +// +// securejoin has two APIs, a "legacy" API and a "modern" API. +// +// The legacy API is [SecureJoin] and [SecureJoinVFS]. These methods are +// **not** safe against race conditions where an attacker changes the +// filesystem after (or during) the [SecureJoin] operation. +// +// The new API is made up of [OpenInRoot] and [MkdirAll] (and derived +// functions). These are safe against racing attackers and have several other +// protections that are not provided by the legacy API. There are many more +// operations that most programs expect to be able to do safely, but we do not +// provide explicit support for them because we want to encourage users to +// switch to [libpathrs](https://github.com/openSUSE/libpathrs) which is a +// cross-language next-generation library that is entirely designed around +// operating on paths safely. +// +// securejoin has been used by several container runtimes (Docker, runc, +// Kubernetes, etc) for quite a few years as a de-facto standard for operating +// on container filesystem paths "safely". However, most users still use the +// legacy API which is unsafe against various attacks (there is a fairly long +// history of CVEs in dependent as a result). Users should switch to the modern +// API as soon as possible (or even better, switch to libpathrs). +// +// This project was initially intended to be included in the Go standard +// library, but [it was rejected](https://go.dev/issue/20126). There is now a +// [new Go proposal](https://go.dev/issue/67002) for a safe path resolution API +// that shares some of the goals of filepath-securejoin. However, that design +// is intended to work like `openat2(RESOLVE_BENEATH)` which does not fit the +// usecase of container runtimes and most system tools. +package securejoin diff --git a/vendor/github.com/cyphar/filepath-securejoin/gocompat_errors_go120.go b/vendor/github.com/cyphar/filepath-securejoin/gocompat_errors_go120.go new file mode 100644 index 0000000000..42452bbf9b --- /dev/null +++ b/vendor/github.com/cyphar/filepath-securejoin/gocompat_errors_go120.go @@ -0,0 +1,18 @@ +//go:build linux && go1.20 + +// Copyright (C) 2024 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package securejoin + +import ( + "fmt" +) + +// wrapBaseError is a helper that is equivalent to fmt.Errorf("%w: %w"), except +// that on pre-1.20 Go versions only errors.Is() works properly (errors.Unwrap) +// is only guaranteed to give you baseErr. +func wrapBaseError(baseErr, extraErr error) error { + return fmt.Errorf("%w: %w", extraErr, baseErr) +} diff --git a/vendor/github.com/cyphar/filepath-securejoin/gocompat_errors_unsupported.go b/vendor/github.com/cyphar/filepath-securejoin/gocompat_errors_unsupported.go new file mode 100644 index 0000000000..e7adca3fd1 --- /dev/null +++ b/vendor/github.com/cyphar/filepath-securejoin/gocompat_errors_unsupported.go @@ -0,0 +1,38 @@ +//go:build linux && !go1.20 + +// Copyright (C) 2024 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package securejoin + +import ( + "fmt" +) + +type wrappedError struct { + inner error + isError error +} + +func (err wrappedError) Is(target error) bool { + return err.isError == target +} + +func (err wrappedError) Unwrap() error { + return err.inner +} + +func (err wrappedError) Error() string { + return fmt.Sprintf("%v: %v", err.isError, err.inner) +} + +// wrapBaseError is a helper that is equivalent to fmt.Errorf("%w: %w"), except +// that on pre-1.20 Go versions only errors.Is() works properly (errors.Unwrap) +// is only guaranteed to give you baseErr. +func wrapBaseError(baseErr, extraErr error) error { + return wrappedError{ + inner: baseErr, + isError: extraErr, + } +} diff --git a/vendor/github.com/cyphar/filepath-securejoin/gocompat_generics_go121.go b/vendor/github.com/cyphar/filepath-securejoin/gocompat_generics_go121.go new file mode 100644 index 0000000000..ddd6fa9a41 --- /dev/null +++ b/vendor/github.com/cyphar/filepath-securejoin/gocompat_generics_go121.go @@ -0,0 +1,32 @@ +//go:build linux && go1.21 + +// Copyright (C) 2024 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package securejoin + +import ( + "slices" + "sync" +) + +func slices_DeleteFunc[S ~[]E, E any](slice S, delFn func(E) bool) S { + return slices.DeleteFunc(slice, delFn) +} + +func slices_Contains[S ~[]E, E comparable](slice S, val E) bool { + return slices.Contains(slice, val) +} + +func slices_Clone[S ~[]E, E any](slice S) S { + return slices.Clone(slice) +} + +func sync_OnceValue[T any](f func() T) func() T { + return sync.OnceValue(f) +} + +func sync_OnceValues[T1, T2 any](f func() (T1, T2)) func() (T1, T2) { + return sync.OnceValues(f) +} diff --git a/vendor/github.com/cyphar/filepath-securejoin/gocompat_generics_unsupported.go b/vendor/github.com/cyphar/filepath-securejoin/gocompat_generics_unsupported.go new file mode 100644 index 0000000000..f1e6fe7e71 --- /dev/null +++ b/vendor/github.com/cyphar/filepath-securejoin/gocompat_generics_unsupported.go @@ -0,0 +1,124 @@ +//go:build linux && !go1.21 + +// Copyright (C) 2024 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package securejoin + +import ( + "sync" +) + +// These are very minimal implementations of functions that appear in Go 1.21's +// stdlib, included so that we can build on older Go versions. Most are +// borrowed directly from the stdlib, and a few are modified to be "obviously +// correct" without needing to copy too many other helpers. + +// clearSlice is equivalent to the builtin clear from Go 1.21. +// Copied from the Go 1.24 stdlib implementation. +func clearSlice[S ~[]E, E any](slice S) { + var zero E + for i := range slice { + slice[i] = zero + } +} + +// Copied from the Go 1.24 stdlib implementation. +func slices_IndexFunc[S ~[]E, E any](s S, f func(E) bool) int { + for i := range s { + if f(s[i]) { + return i + } + } + return -1 +} + +// Copied from the Go 1.24 stdlib implementation. +func slices_DeleteFunc[S ~[]E, E any](s S, del func(E) bool) S { + i := slices_IndexFunc(s, del) + if i == -1 { + return s + } + // Don't start copying elements until we find one to delete. + for j := i + 1; j < len(s); j++ { + if v := s[j]; !del(v) { + s[i] = v + i++ + } + } + clearSlice(s[i:]) // zero/nil out the obsolete elements, for GC + return s[:i] +} + +// Similar to the stdlib slices.Contains, except that we don't have +// slices.Index so we need to use slices.IndexFunc for this non-Func helper. +func slices_Contains[S ~[]E, E comparable](s S, v E) bool { + return slices_IndexFunc(s, func(e E) bool { return e == v }) >= 0 +} + +// Copied from the Go 1.24 stdlib implementation. +func slices_Clone[S ~[]E, E any](s S) S { + // Preserve nil in case it matters. + if s == nil { + return nil + } + return append(S([]E{}), s...) +} + +// Copied from the Go 1.24 stdlib implementation. +func sync_OnceValue[T any](f func() T) func() T { + var ( + once sync.Once + valid bool + p any + result T + ) + g := func() { + defer func() { + p = recover() + if !valid { + panic(p) + } + }() + result = f() + f = nil + valid = true + } + return func() T { + once.Do(g) + if !valid { + panic(p) + } + return result + } +} + +// Copied from the Go 1.24 stdlib implementation. +func sync_OnceValues[T1, T2 any](f func() (T1, T2)) func() (T1, T2) { + var ( + once sync.Once + valid bool + p any + r1 T1 + r2 T2 + ) + g := func() { + defer func() { + p = recover() + if !valid { + panic(p) + } + }() + r1, r2 = f() + f = nil + valid = true + } + return func() (T1, T2) { + once.Do(g) + if !valid { + panic(p) + } + return r1, r2 + } +} diff --git a/vendor/github.com/cyphar/filepath-securejoin/join.go b/vendor/github.com/cyphar/filepath-securejoin/join.go index aa32b85fb8..e0ee3f2b57 100644 --- a/vendor/github.com/cyphar/filepath-securejoin/join.go +++ b/vendor/github.com/cyphar/filepath-securejoin/join.go @@ -1,17 +1,11 @@ // Copyright (C) 2014-2015 Docker Inc & Go Authors. All rights reserved. -// Copyright (C) 2017 SUSE LLC. All rights reserved. +// Copyright (C) 2017-2024 SUSE LLC. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package securejoin is an implementation of the hopefully-soon-to-be-included -// SecureJoin helper that is meant to be part of the "path/filepath" package. -// The purpose of this project is to provide a PoC implementation to make the -// SecureJoin proposal (https://github.com/golang/go/issues/20126) more -// tangible. package securejoin import ( - "bytes" "errors" "os" "path/filepath" @@ -19,26 +13,34 @@ import ( "syscall" ) +const maxSymlinkLimit = 255 + // IsNotExist tells you if err is an error that implies that either the path // accessed does not exist (or path components don't exist). This is -// effectively a more broad version of os.IsNotExist. +// effectively a more broad version of [os.IsNotExist]. func IsNotExist(err error) bool { // Check that it's not actually an ENOTDIR, which in some cases is a more // convoluted case of ENOENT (usually involving weird paths). return errors.Is(err, os.ErrNotExist) || errors.Is(err, syscall.ENOTDIR) || errors.Is(err, syscall.ENOENT) } -// SecureJoinVFS joins the two given path components (similar to Join) except +// SecureJoinVFS joins the two given path components (similar to [filepath.Join]) except // that the returned path is guaranteed to be scoped inside the provided root // path (when evaluated). Any symbolic links in the path are evaluated with the // given root treated as the root of the filesystem, similar to a chroot. The -// filesystem state is evaluated through the given VFS interface (if nil, the -// standard os.* family of functions are used). +// filesystem state is evaluated through the given [VFS] interface (if nil, the +// standard [os].* family of functions are used). // // Note that the guarantees provided by this function only apply if the path // components in the returned string are not modified (in other words are not // replaced with symlinks on the filesystem) after this function has returned. -// Such a symlink race is necessarily out-of-scope of SecureJoin. +// Such a symlink race is necessarily out-of-scope of SecureJoinVFS. +// +// NOTE: Due to the above limitation, Linux users are strongly encouraged to +// use [OpenInRoot] instead, which does safely protect against these kinds of +// attacks. There is no way to solve this problem with SecureJoinVFS because +// the API is fundamentally wrong (you cannot return a "safe" path string and +// guarantee it won't be modified afterwards). // // Volume names in unsafePath are always discarded, regardless if they are // provided via direct input or when evaluating symlinks. Therefore: @@ -51,75 +53,73 @@ func SecureJoinVFS(root, unsafePath string, vfs VFS) (string, error) { } unsafePath = filepath.FromSlash(unsafePath) - var path bytes.Buffer - n := 0 - for unsafePath != "" { - if n > 255 { - return "", &os.PathError{Op: "SecureJoin", Path: root + string(filepath.Separator) + unsafePath, Err: syscall.ELOOP} + var ( + currentPath string + remainingPath = unsafePath + linksWalked int + ) + for remainingPath != "" { + if v := filepath.VolumeName(remainingPath); v != "" { + remainingPath = remainingPath[len(v):] } - if v := filepath.VolumeName(unsafePath); v != "" { - unsafePath = unsafePath[len(v):] - } - - // Next path component, p. - i := strings.IndexRune(unsafePath, filepath.Separator) - var p string - if i == -1 { - p, unsafePath = unsafePath, "" + // Get the next path component. + var part string + if i := strings.IndexRune(remainingPath, filepath.Separator); i == -1 { + part, remainingPath = remainingPath, "" } else { - p, unsafePath = unsafePath[:i], unsafePath[i+1:] + part, remainingPath = remainingPath[:i], remainingPath[i+1:] } - // Create a cleaned path, using the lexical semantics of /../a, to - // create a "scoped" path component which can safely be joined to fullP - // for evaluation. At this point, path.String() doesn't contain any - // symlink components. - cleanP := filepath.Clean(string(filepath.Separator) + path.String() + p) - if cleanP == string(filepath.Separator) { - path.Reset() + // Apply the component lexically to the path we are building. + // currentPath does not contain any symlinks, and we are lexically + // dealing with a single component, so it's okay to do a filepath.Clean + // here. + nextPath := filepath.Join(string(filepath.Separator), currentPath, part) + if nextPath == string(filepath.Separator) { + currentPath = "" continue } - fullP := filepath.Clean(root + cleanP) + fullPath := root + string(filepath.Separator) + nextPath // Figure out whether the path is a symlink. - fi, err := vfs.Lstat(fullP) + fi, err := vfs.Lstat(fullPath) if err != nil && !IsNotExist(err) { return "", err } // Treat non-existent path components the same as non-symlinks (we // can't do any better here). if IsNotExist(err) || fi.Mode()&os.ModeSymlink == 0 { - path.WriteString(p) - path.WriteRune(filepath.Separator) + currentPath = nextPath continue } - // Only increment when we actually dereference a link. - n++ + // It's a symlink, so get its contents and expand it by prepending it + // to the yet-unparsed path. + linksWalked++ + if linksWalked > maxSymlinkLimit { + return "", &os.PathError{Op: "SecureJoin", Path: root + string(filepath.Separator) + unsafePath, Err: syscall.ELOOP} + } - // It's a symlink, expand it by prepending it to the yet-unparsed path. - dest, err := vfs.Readlink(fullP) + dest, err := vfs.Readlink(fullPath) if err != nil { return "", err } + remainingPath = dest + string(filepath.Separator) + remainingPath // Absolute symlinks reset any work we've already done. if filepath.IsAbs(dest) { - path.Reset() + currentPath = "" } - unsafePath = dest + string(filepath.Separator) + unsafePath } - // We have to clean path.String() here because it may contain '..' - // components that are entirely lexical, but would be misleading otherwise. - // And finally do a final clean to ensure that root is also lexically - // clean. - fullP := filepath.Clean(string(filepath.Separator) + path.String()) - return filepath.Clean(root + fullP), nil + // There should be no lexical components like ".." left in the path here, + // but for safety clean up the path before joining it to the root. + finalPath := filepath.Join(string(filepath.Separator), currentPath) + return filepath.Join(root, finalPath), nil } -// SecureJoin is a wrapper around SecureJoinVFS that just uses the os.* library -// of functions as the VFS. If in doubt, use this function over SecureJoinVFS. +// SecureJoin is a wrapper around [SecureJoinVFS] that just uses the [os].* library +// of functions as the [VFS]. If in doubt, use this function over [SecureJoinVFS]. func SecureJoin(root, unsafePath string) (string, error) { return SecureJoinVFS(root, unsafePath, nil) } diff --git a/vendor/github.com/cyphar/filepath-securejoin/lookup_linux.go b/vendor/github.com/cyphar/filepath-securejoin/lookup_linux.go new file mode 100644 index 0000000000..be81e498d7 --- /dev/null +++ b/vendor/github.com/cyphar/filepath-securejoin/lookup_linux.go @@ -0,0 +1,388 @@ +//go:build linux + +// Copyright (C) 2024 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package securejoin + +import ( + "errors" + "fmt" + "os" + "path" + "path/filepath" + "strings" + + "golang.org/x/sys/unix" +) + +type symlinkStackEntry struct { + // (dir, remainingPath) is what we would've returned if the link didn't + // exist. This matches what openat2(RESOLVE_IN_ROOT) would return in + // this case. + dir *os.File + remainingPath string + // linkUnwalked is the remaining path components from the original + // Readlink which we have yet to walk. When this slice is empty, we + // drop the link from the stack. + linkUnwalked []string +} + +func (se symlinkStackEntry) String() string { + return fmt.Sprintf("<%s>/%s [->%s]", se.dir.Name(), se.remainingPath, strings.Join(se.linkUnwalked, "/")) +} + +func (se symlinkStackEntry) Close() { + _ = se.dir.Close() +} + +type symlinkStack []*symlinkStackEntry + +func (s *symlinkStack) IsEmpty() bool { + return s == nil || len(*s) == 0 +} + +func (s *symlinkStack) Close() { + if s != nil { + for _, link := range *s { + link.Close() + } + // TODO: Switch to clear once we switch to Go 1.21. + *s = nil + } +} + +var ( + errEmptyStack = errors.New("[internal] stack is empty") + errBrokenSymlinkStack = errors.New("[internal error] broken symlink stack") +) + +func (s *symlinkStack) popPart(part string) error { + if s == nil || s.IsEmpty() { + // If there is nothing in the symlink stack, then the part was from the + // real path provided by the user, and this is a no-op. + return errEmptyStack + } + if part == "." { + // "." components are no-ops -- we drop them when doing SwapLink. + return nil + } + + tailEntry := (*s)[len(*s)-1] + + // Double-check that we are popping the component we expect. + if len(tailEntry.linkUnwalked) == 0 { + return fmt.Errorf("%w: trying to pop component %q of empty stack entry %s", errBrokenSymlinkStack, part, tailEntry) + } + headPart := tailEntry.linkUnwalked[0] + if headPart != part { + return fmt.Errorf("%w: trying to pop component %q but the last stack entry is %s (%q)", errBrokenSymlinkStack, part, tailEntry, headPart) + } + + // Drop the component, but keep the entry around in case we are dealing + // with a "tail-chained" symlink. + tailEntry.linkUnwalked = tailEntry.linkUnwalked[1:] + return nil +} + +func (s *symlinkStack) PopPart(part string) error { + if err := s.popPart(part); err != nil { + if errors.Is(err, errEmptyStack) { + // Skip empty stacks. + err = nil + } + return err + } + + // Clean up any of the trailing stack entries that are empty. + for lastGood := len(*s) - 1; lastGood >= 0; lastGood-- { + entry := (*s)[lastGood] + if len(entry.linkUnwalked) > 0 { + break + } + entry.Close() + (*s) = (*s)[:lastGood] + } + return nil +} + +func (s *symlinkStack) push(dir *os.File, remainingPath, linkTarget string) error { + if s == nil { + return nil + } + // Split the link target and clean up any "" parts. + linkTargetParts := slices_DeleteFunc( + strings.Split(linkTarget, "/"), + func(part string) bool { return part == "" || part == "." }) + + // Copy the directory so the caller doesn't close our copy. + dirCopy, err := dupFile(dir) + if err != nil { + return err + } + + // Add to the stack. + *s = append(*s, &symlinkStackEntry{ + dir: dirCopy, + remainingPath: remainingPath, + linkUnwalked: linkTargetParts, + }) + return nil +} + +func (s *symlinkStack) SwapLink(linkPart string, dir *os.File, remainingPath, linkTarget string) error { + // If we are currently inside a symlink resolution, remove the symlink + // component from the last symlink entry, but don't remove the entry even + // if it's empty. If we are a "tail-chained" symlink (a trailing symlink we + // hit during a symlink resolution) we need to keep the old symlink until + // we finish the resolution. + if err := s.popPart(linkPart); err != nil { + if !errors.Is(err, errEmptyStack) { + return err + } + // Push the component regardless of whether the stack was empty. + } + return s.push(dir, remainingPath, linkTarget) +} + +func (s *symlinkStack) PopTopSymlink() (*os.File, string, bool) { + if s == nil || s.IsEmpty() { + return nil, "", false + } + tailEntry := (*s)[0] + *s = (*s)[1:] + return tailEntry.dir, tailEntry.remainingPath, true +} + +// partialLookupInRoot tries to lookup as much of the request path as possible +// within the provided root (a-la RESOLVE_IN_ROOT) and opens the final existing +// component of the requested path, returning a file handle to the final +// existing component and a string containing the remaining path components. +func partialLookupInRoot(root *os.File, unsafePath string) (*os.File, string, error) { + return lookupInRoot(root, unsafePath, true) +} + +func completeLookupInRoot(root *os.File, unsafePath string) (*os.File, error) { + handle, remainingPath, err := lookupInRoot(root, unsafePath, false) + if remainingPath != "" && err == nil { + // should never happen + err = fmt.Errorf("[bug] non-empty remaining path when doing a non-partial lookup: %q", remainingPath) + } + // lookupInRoot(partial=false) will always close the handle if an error is + // returned, so no need to double-check here. + return handle, err +} + +func lookupInRoot(root *os.File, unsafePath string, partial bool) (Handle *os.File, _ string, _ error) { + unsafePath = filepath.ToSlash(unsafePath) // noop + + // This is very similar to SecureJoin, except that we operate on the + // components using file descriptors. We then return the last component we + // managed open, along with the remaining path components not opened. + + // Try to use openat2 if possible. + if hasOpenat2() { + return lookupOpenat2(root, unsafePath, partial) + } + + // Get the "actual" root path from /proc/self/fd. This is necessary if the + // root is some magic-link like /proc/$pid/root, in which case we want to + // make sure when we do checkProcSelfFdPath that we are using the correct + // root path. + logicalRootPath, err := procSelfFdReadlink(root) + if err != nil { + return nil, "", fmt.Errorf("get real root path: %w", err) + } + + currentDir, err := dupFile(root) + if err != nil { + return nil, "", fmt.Errorf("clone root fd: %w", err) + } + defer func() { + // If a handle is not returned, close the internal handle. + if Handle == nil { + _ = currentDir.Close() + } + }() + + // symlinkStack is used to emulate how openat2(RESOLVE_IN_ROOT) treats + // dangling symlinks. If we hit a non-existent path while resolving a + // symlink, we need to return the (dir, remainingPath) that we had when we + // hit the symlink (treating the symlink as though it were a regular file). + // The set of (dir, remainingPath) sets is stored within the symlinkStack + // and we add and remove parts when we hit symlink and non-symlink + // components respectively. We need a stack because of recursive symlinks + // (symlinks that contain symlink components in their target). + // + // Note that the stack is ONLY used for book-keeping. All of the actual + // path walking logic is still based on currentPath/remainingPath and + // currentDir (as in SecureJoin). + var symStack *symlinkStack + if partial { + symStack = new(symlinkStack) + defer symStack.Close() + } + + var ( + linksWalked int + currentPath string + remainingPath = unsafePath + ) + for remainingPath != "" { + // Save the current remaining path so if the part is not real we can + // return the path including the component. + oldRemainingPath := remainingPath + + // Get the next path component. + var part string + if i := strings.IndexByte(remainingPath, '/'); i == -1 { + part, remainingPath = remainingPath, "" + } else { + part, remainingPath = remainingPath[:i], remainingPath[i+1:] + } + // If we hit an empty component, we need to treat it as though it is + // "." so that trailing "/" and "//" components on a non-directory + // correctly return the right error code. + if part == "" { + part = "." + } + + // Apply the component lexically to the path we are building. + // currentPath does not contain any symlinks, and we are lexically + // dealing with a single component, so it's okay to do a filepath.Clean + // here. + nextPath := path.Join("/", currentPath, part) + // If we logically hit the root, just clone the root rather than + // opening the part and doing all of the other checks. + if nextPath == "/" { + if err := symStack.PopPart(part); err != nil { + return nil, "", fmt.Errorf("walking into root with part %q failed: %w", part, err) + } + // Jump to root. + rootClone, err := dupFile(root) + if err != nil { + return nil, "", fmt.Errorf("clone root fd: %w", err) + } + _ = currentDir.Close() + currentDir = rootClone + currentPath = nextPath + continue + } + + // Try to open the next component. + nextDir, err := openatFile(currentDir, part, unix.O_PATH|unix.O_NOFOLLOW|unix.O_CLOEXEC, 0) + switch { + case err == nil: + st, err := nextDir.Stat() + if err != nil { + _ = nextDir.Close() + return nil, "", fmt.Errorf("stat component %q: %w", part, err) + } + + switch st.Mode() & os.ModeType { + case os.ModeSymlink: + // readlinkat implies AT_EMPTY_PATH since Linux 2.6.39. See + // Linux commit 65cfc6722361 ("readlinkat(), fchownat() and + // fstatat() with empty relative pathnames"). + linkDest, err := readlinkatFile(nextDir, "") + // We don't need the handle anymore. + _ = nextDir.Close() + if err != nil { + return nil, "", err + } + + linksWalked++ + if linksWalked > maxSymlinkLimit { + return nil, "", &os.PathError{Op: "securejoin.lookupInRoot", Path: logicalRootPath + "/" + unsafePath, Err: unix.ELOOP} + } + + // Swap out the symlink's component for the link entry itself. + if err := symStack.SwapLink(part, currentDir, oldRemainingPath, linkDest); err != nil { + return nil, "", fmt.Errorf("walking into symlink %q failed: push symlink: %w", part, err) + } + + // Update our logical remaining path. + remainingPath = linkDest + "/" + remainingPath + // Absolute symlinks reset any work we've already done. + if path.IsAbs(linkDest) { + // Jump to root. + rootClone, err := dupFile(root) + if err != nil { + return nil, "", fmt.Errorf("clone root fd: %w", err) + } + _ = currentDir.Close() + currentDir = rootClone + currentPath = "/" + } + + default: + // If we are dealing with a directory, simply walk into it. + _ = currentDir.Close() + currentDir = nextDir + currentPath = nextPath + + // The part was real, so drop it from the symlink stack. + if err := symStack.PopPart(part); err != nil { + return nil, "", fmt.Errorf("walking into directory %q failed: %w", part, err) + } + + // If we are operating on a .., make sure we haven't escaped. + // We only have to check for ".." here because walking down + // into a regular component component cannot cause you to + // escape. This mirrors the logic in RESOLVE_IN_ROOT, except we + // have to check every ".." rather than only checking after a + // rename or mount on the system. + if part == ".." { + // Make sure the root hasn't moved. + if err := checkProcSelfFdPath(logicalRootPath, root); err != nil { + return nil, "", fmt.Errorf("root path moved during lookup: %w", err) + } + // Make sure the path is what we expect. + fullPath := logicalRootPath + nextPath + if err := checkProcSelfFdPath(fullPath, currentDir); err != nil { + return nil, "", fmt.Errorf("walking into %q had unexpected result: %w", part, err) + } + } + } + + default: + if !partial { + return nil, "", err + } + // If there are any remaining components in the symlink stack, we + // are still within a symlink resolution and thus we hit a dangling + // symlink. So pretend that the first symlink in the stack we hit + // was an ENOENT (to match openat2). + if oldDir, remainingPath, ok := symStack.PopTopSymlink(); ok { + _ = currentDir.Close() + return oldDir, remainingPath, err + } + // We have hit a final component that doesn't exist, so we have our + // partial open result. Note that we have to use the OLD remaining + // path, since the lookup failed. + return currentDir, oldRemainingPath, err + } + } + + // If the unsafePath had a trailing slash, we need to make sure we try to + // do a relative "." open so that we will correctly return an error when + // the final component is a non-directory (to match openat2). In the + // context of openat2, a trailing slash and a trailing "/." are completely + // equivalent. + if strings.HasSuffix(unsafePath, "/") { + nextDir, err := openatFile(currentDir, ".", unix.O_PATH|unix.O_NOFOLLOW|unix.O_CLOEXEC, 0) + if err != nil { + if !partial { + _ = currentDir.Close() + currentDir = nil + } + return currentDir, "", err + } + _ = currentDir.Close() + currentDir = nextDir + } + + // All of the components existed! + return currentDir, "", nil +} diff --git a/vendor/github.com/cyphar/filepath-securejoin/mkdir_linux.go b/vendor/github.com/cyphar/filepath-securejoin/mkdir_linux.go new file mode 100644 index 0000000000..5e559bb7a8 --- /dev/null +++ b/vendor/github.com/cyphar/filepath-securejoin/mkdir_linux.go @@ -0,0 +1,215 @@ +//go:build linux + +// Copyright (C) 2024 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package securejoin + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" + + "golang.org/x/sys/unix" +) + +var ( + errInvalidMode = errors.New("invalid permission mode") + errPossibleAttack = errors.New("possible attack detected") +) + +// MkdirAllHandle is equivalent to [MkdirAll], except that it is safer to use +// in two respects: +// +// - The caller provides the root directory as an *[os.File] (preferably O_PATH) +// handle. This means that the caller can be sure which root directory is +// being used. Note that this can be emulated by using /proc/self/fd/... as +// the root path with [os.MkdirAll]. +// +// - Once all of the directories have been created, an *[os.File] O_PATH handle +// to the directory at unsafePath is returned to the caller. This is done in +// an effectively-race-free way (an attacker would only be able to swap the +// final directory component), which is not possible to emulate with +// [MkdirAll]. +// +// In addition, the returned handle is obtained far more efficiently than doing +// a brand new lookup of unsafePath (such as with [SecureJoin] or openat2) after +// doing [MkdirAll]. If you intend to open the directory after creating it, you +// should use MkdirAllHandle. +func MkdirAllHandle(root *os.File, unsafePath string, mode int) (_ *os.File, Err error) { + // Make sure there are no os.FileMode bits set. + if mode&^0o7777 != 0 { + return nil, fmt.Errorf("%w for mkdir 0o%.3o", errInvalidMode, mode) + } + // On Linux, mkdirat(2) (and os.Mkdir) silently ignore the suid and sgid + // bits. We could also silently ignore them but since we have very few + // users it seems more prudent to return an error so users notice that + // these bits will not be set. + if mode&^0o1777 != 0 { + return nil, fmt.Errorf("%w for mkdir 0o%.3o: suid and sgid are ignored by mkdir", errInvalidMode, mode) + } + + // Try to open as much of the path as possible. + currentDir, remainingPath, err := partialLookupInRoot(root, unsafePath) + defer func() { + if Err != nil { + _ = currentDir.Close() + } + }() + if err != nil && !errors.Is(err, unix.ENOENT) { + return nil, fmt.Errorf("find existing subpath of %q: %w", unsafePath, err) + } + + // If there is an attacker deleting directories as we walk into them, + // detect this proactively. Note this is guaranteed to detect if the + // attacker deleted any part of the tree up to currentDir. + // + // Once we walk into a dead directory, partialLookupInRoot would not be + // able to walk further down the tree (directories must be empty before + // they are deleted), and if the attacker has removed the entire tree we + // can be sure that anything that was originally inside a dead directory + // must also be deleted and thus is a dead directory in its own right. + // + // This is mostly a quality-of-life check, because mkdir will simply fail + // later if the attacker deletes the tree after this check. + if err := isDeadInode(currentDir); err != nil { + return nil, fmt.Errorf("finding existing subpath of %q: %w", unsafePath, err) + } + + // Re-open the path to match the O_DIRECTORY reopen loop later (so that we + // always return a non-O_PATH handle). We also check that we actually got a + // directory. + if reopenDir, err := Reopen(currentDir, unix.O_DIRECTORY|unix.O_CLOEXEC); errors.Is(err, unix.ENOTDIR) { + return nil, fmt.Errorf("cannot create subdirectories in %q: %w", currentDir.Name(), unix.ENOTDIR) + } else if err != nil { + return nil, fmt.Errorf("re-opening handle to %q: %w", currentDir.Name(), err) + } else { + _ = currentDir.Close() + currentDir = reopenDir + } + + remainingParts := strings.Split(remainingPath, string(filepath.Separator)) + if slices_Contains(remainingParts, "..") { + // The path contained ".." components after the end of the "real" + // components. We could try to safely resolve ".." here but that would + // add a bunch of extra logic for something that it's not clear even + // needs to be supported. So just return an error. + // + // If we do filepath.Clean(remainingPath) then we end up with the + // problem that ".." can erase a trailing dangling symlink and produce + // a path that doesn't quite match what the user asked for. + return nil, fmt.Errorf("%w: yet-to-be-created path %q contains '..' components", unix.ENOENT, remainingPath) + } + + // Make sure the mode doesn't have any type bits. + mode &^= unix.S_IFMT + + // Create the remaining components. + for _, part := range remainingParts { + switch part { + case "", ".": + // Skip over no-op paths. + continue + } + + // NOTE: mkdir(2) will not follow trailing symlinks, so we can safely + // create the final component without worrying about symlink-exchange + // attacks. + // + // If we get -EEXIST, it's possible that another program created the + // directory at the same time as us. In that case, just continue on as + // if we created it (if the created inode is not a directory, the + // following open call will fail). + if err := unix.Mkdirat(int(currentDir.Fd()), part, uint32(mode)); err != nil && !errors.Is(err, unix.EEXIST) { + err = &os.PathError{Op: "mkdirat", Path: currentDir.Name() + "/" + part, Err: err} + // Make the error a bit nicer if the directory is dead. + if deadErr := isDeadInode(currentDir); deadErr != nil { + // TODO: Once we bump the minimum Go version to 1.20, we can use + // multiple %w verbs for this wrapping. For now we need to use a + // compatibility shim for older Go versions. + //err = fmt.Errorf("%w (%w)", err, deadErr) + err = wrapBaseError(err, deadErr) + } + return nil, err + } + + // Get a handle to the next component. O_DIRECTORY means we don't need + // to use O_PATH. + var nextDir *os.File + if hasOpenat2() { + nextDir, err = openat2File(currentDir, part, &unix.OpenHow{ + Flags: unix.O_NOFOLLOW | unix.O_DIRECTORY | unix.O_CLOEXEC, + Resolve: unix.RESOLVE_BENEATH | unix.RESOLVE_NO_SYMLINKS | unix.RESOLVE_NO_XDEV, + }) + } else { + nextDir, err = openatFile(currentDir, part, unix.O_NOFOLLOW|unix.O_DIRECTORY|unix.O_CLOEXEC, 0) + } + if err != nil { + return nil, err + } + _ = currentDir.Close() + currentDir = nextDir + + // It's possible that the directory we just opened was swapped by an + // attacker. Unfortunately there isn't much we can do to protect + // against this, and MkdirAll's behaviour is that we will reuse + // existing directories anyway so the need to protect against this is + // incredibly limited (and arguably doesn't even deserve mention here). + // + // Ideally we might want to check that the owner and mode match what we + // would've created -- unfortunately, it is non-trivial to verify that + // the owner and mode of the created directory match. While plain Unix + // DAC rules seem simple enough to emulate, there are a bunch of other + // factors that can change the mode or owner of created directories + // (default POSIX ACLs, mount options like uid=1,gid=2,umask=0 on + // filesystems like vfat, etc etc). We used to try to verify this but + // it just lead to a series of spurious errors. + // + // We could also check that the directory is non-empty, but + // unfortunately some pseduofilesystems (like cgroupfs) create + // non-empty directories, which would result in different spurious + // errors. + } + return currentDir, nil +} + +// MkdirAll is a race-safe alternative to the [os.MkdirAll] function, +// where the new directory is guaranteed to be within the root directory (if an +// attacker can move directories from inside the root to outside the root, the +// created directory tree might be outside of the root but the key constraint +// is that at no point will we walk outside of the directory tree we are +// creating). +// +// Effectively, MkdirAll(root, unsafePath, mode) is equivalent to +// +// path, _ := securejoin.SecureJoin(root, unsafePath) +// err := os.MkdirAll(path, mode) +// +// But is much safer. The above implementation is unsafe because if an attacker +// can modify the filesystem tree between [SecureJoin] and [os.MkdirAll], it is +// possible for MkdirAll to resolve unsafe symlink components and create +// directories outside of the root. +// +// If you plan to open the directory after you have created it or want to use +// an open directory handle as the root, you should use [MkdirAllHandle] instead. +// This function is a wrapper around [MkdirAllHandle]. +// +// NOTE: The mode argument must be set the unix mode bits (unix.S_I...), not +// the Go generic mode bits ([os.FileMode]...). +func MkdirAll(root, unsafePath string, mode int) error { + rootDir, err := os.OpenFile(root, unix.O_PATH|unix.O_DIRECTORY|unix.O_CLOEXEC, 0) + if err != nil { + return err + } + defer rootDir.Close() + + f, err := MkdirAllHandle(rootDir, unsafePath, mode) + if err != nil { + return err + } + _ = f.Close() + return nil +} diff --git a/vendor/github.com/cyphar/filepath-securejoin/open_linux.go b/vendor/github.com/cyphar/filepath-securejoin/open_linux.go new file mode 100644 index 0000000000..230be73f0e --- /dev/null +++ b/vendor/github.com/cyphar/filepath-securejoin/open_linux.go @@ -0,0 +1,103 @@ +//go:build linux + +// Copyright (C) 2024 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package securejoin + +import ( + "fmt" + "os" + "strconv" + + "golang.org/x/sys/unix" +) + +// OpenatInRoot is equivalent to [OpenInRoot], except that the root is provided +// using an *[os.File] handle, to ensure that the correct root directory is used. +func OpenatInRoot(root *os.File, unsafePath string) (*os.File, error) { + handle, err := completeLookupInRoot(root, unsafePath) + if err != nil { + return nil, &os.PathError{Op: "securejoin.OpenInRoot", Path: unsafePath, Err: err} + } + return handle, nil +} + +// OpenInRoot safely opens the provided unsafePath within the root. +// Effectively, OpenInRoot(root, unsafePath) is equivalent to +// +// path, _ := securejoin.SecureJoin(root, unsafePath) +// handle, err := os.OpenFile(path, unix.O_PATH|unix.O_CLOEXEC) +// +// But is much safer. The above implementation is unsafe because if an attacker +// can modify the filesystem tree between [SecureJoin] and [os.OpenFile], it is +// possible for the returned file to be outside of the root. +// +// Note that the returned handle is an O_PATH handle, meaning that only a very +// limited set of operations will work on the handle. This is done to avoid +// accidentally opening an untrusted file that could cause issues (such as a +// disconnected TTY that could cause a DoS, or some other issue). In order to +// use the returned handle, you can "upgrade" it to a proper handle using +// [Reopen]. +func OpenInRoot(root, unsafePath string) (*os.File, error) { + rootDir, err := os.OpenFile(root, unix.O_PATH|unix.O_DIRECTORY|unix.O_CLOEXEC, 0) + if err != nil { + return nil, err + } + defer rootDir.Close() + return OpenatInRoot(rootDir, unsafePath) +} + +// Reopen takes an *[os.File] handle and re-opens it through /proc/self/fd. +// Reopen(file, flags) is effectively equivalent to +// +// fdPath := fmt.Sprintf("/proc/self/fd/%d", file.Fd()) +// os.OpenFile(fdPath, flags|unix.O_CLOEXEC) +// +// But with some extra hardenings to ensure that we are not tricked by a +// maliciously-configured /proc mount. While this attack scenario is not +// common, in container runtimes it is possible for higher-level runtimes to be +// tricked into configuring an unsafe /proc that can be used to attack file +// operations. See [CVE-2019-19921] for more details. +// +// [CVE-2019-19921]: https://github.com/advisories/GHSA-fh74-hm69-rqjw +func Reopen(handle *os.File, flags int) (*os.File, error) { + procRoot, err := getProcRoot() + if err != nil { + return nil, err + } + + // We can't operate on /proc/thread-self/fd/$n directly when doing a + // re-open, so we need to open /proc/thread-self/fd and then open a single + // final component. + procFdDir, closer, err := procThreadSelf(procRoot, "fd/") + if err != nil { + return nil, fmt.Errorf("get safe /proc/thread-self/fd handle: %w", err) + } + defer procFdDir.Close() + defer closer() + + // Try to detect if there is a mount on top of the magic-link we are about + // to open. If we are using unsafeHostProcRoot(), this could change after + // we check it (and there's nothing we can do about that) but for + // privateProcRoot() this should be guaranteed to be safe (at least since + // Linux 5.12[1], when anonymous mount namespaces were completely isolated + // from external mounts including mount propagation events). + // + // [1]: Linux commit ee2e3f50629f ("mount: fix mounting of detached mounts + // onto targets that reside on shared mounts"). + fdStr := strconv.Itoa(int(handle.Fd())) + if err := checkSymlinkOvermount(procRoot, procFdDir, fdStr); err != nil { + return nil, fmt.Errorf("check safety of /proc/thread-self/fd/%s magiclink: %w", fdStr, err) + } + + flags |= unix.O_CLOEXEC + // Rather than just wrapping openatFile, open-code it so we can copy + // handle.Name(). + reopenFd, err := unix.Openat(int(procFdDir.Fd()), fdStr, flags, 0) + if err != nil { + return nil, fmt.Errorf("reopen fd %d: %w", handle.Fd(), err) + } + return os.NewFile(uintptr(reopenFd), handle.Name()), nil +} diff --git a/vendor/github.com/cyphar/filepath-securejoin/openat2_linux.go b/vendor/github.com/cyphar/filepath-securejoin/openat2_linux.go new file mode 100644 index 0000000000..f7a13e69ce --- /dev/null +++ b/vendor/github.com/cyphar/filepath-securejoin/openat2_linux.go @@ -0,0 +1,127 @@ +//go:build linux + +// Copyright (C) 2024 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package securejoin + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" + + "golang.org/x/sys/unix" +) + +var hasOpenat2 = sync_OnceValue(func() bool { + fd, err := unix.Openat2(unix.AT_FDCWD, ".", &unix.OpenHow{ + Flags: unix.O_PATH | unix.O_CLOEXEC, + Resolve: unix.RESOLVE_NO_SYMLINKS | unix.RESOLVE_IN_ROOT, + }) + if err != nil { + return false + } + _ = unix.Close(fd) + return true +}) + +func scopedLookupShouldRetry(how *unix.OpenHow, err error) bool { + // RESOLVE_IN_ROOT (and RESOLVE_BENEATH) can return -EAGAIN if we resolve + // ".." while a mount or rename occurs anywhere on the system. This could + // happen spuriously, or as the result of an attacker trying to mess with + // us during lookup. + // + // In addition, scoped lookups have a "safety check" at the end of + // complete_walk which will return -EXDEV if the final path is not in the + // root. + return how.Resolve&(unix.RESOLVE_IN_ROOT|unix.RESOLVE_BENEATH) != 0 && + (errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EXDEV)) +} + +const scopedLookupMaxRetries = 10 + +func openat2File(dir *os.File, path string, how *unix.OpenHow) (*os.File, error) { + fullPath := dir.Name() + "/" + path + // Make sure we always set O_CLOEXEC. + how.Flags |= unix.O_CLOEXEC + var tries int + for tries < scopedLookupMaxRetries { + fd, err := unix.Openat2(int(dir.Fd()), path, how) + if err != nil { + if scopedLookupShouldRetry(how, err) { + // We retry a couple of times to avoid the spurious errors, and + // if we are being attacked then returning -EAGAIN is the best + // we can do. + tries++ + continue + } + return nil, &os.PathError{Op: "openat2", Path: fullPath, Err: err} + } + // If we are using RESOLVE_IN_ROOT, the name we generated may be wrong. + // NOTE: The procRoot code MUST NOT use RESOLVE_IN_ROOT, otherwise + // you'll get infinite recursion here. + if how.Resolve&unix.RESOLVE_IN_ROOT == unix.RESOLVE_IN_ROOT { + if actualPath, err := rawProcSelfFdReadlink(fd); err == nil { + fullPath = actualPath + } + } + return os.NewFile(uintptr(fd), fullPath), nil + } + return nil, &os.PathError{Op: "openat2", Path: fullPath, Err: errPossibleAttack} +} + +func lookupOpenat2(root *os.File, unsafePath string, partial bool) (*os.File, string, error) { + if !partial { + file, err := openat2File(root, unsafePath, &unix.OpenHow{ + Flags: unix.O_PATH | unix.O_CLOEXEC, + Resolve: unix.RESOLVE_IN_ROOT | unix.RESOLVE_NO_MAGICLINKS, + }) + return file, "", err + } + return partialLookupOpenat2(root, unsafePath) +} + +// partialLookupOpenat2 is an alternative implementation of +// partialLookupInRoot, using openat2(RESOLVE_IN_ROOT) to more safely get a +// handle to the deepest existing child of the requested path within the root. +func partialLookupOpenat2(root *os.File, unsafePath string) (*os.File, string, error) { + // TODO: Implement this as a git-bisect-like binary search. + + unsafePath = filepath.ToSlash(unsafePath) // noop + endIdx := len(unsafePath) + var lastError error + for endIdx > 0 { + subpath := unsafePath[:endIdx] + + handle, err := openat2File(root, subpath, &unix.OpenHow{ + Flags: unix.O_PATH | unix.O_CLOEXEC, + Resolve: unix.RESOLVE_IN_ROOT | unix.RESOLVE_NO_MAGICLINKS, + }) + if err == nil { + // Jump over the slash if we have a non-"" remainingPath. + if endIdx < len(unsafePath) { + endIdx += 1 + } + // We found a subpath! + return handle, unsafePath[endIdx:], lastError + } + if errors.Is(err, unix.ENOENT) || errors.Is(err, unix.ENOTDIR) { + // That path doesn't exist, let's try the next directory up. + endIdx = strings.LastIndexByte(subpath, '/') + lastError = err + continue + } + return nil, "", fmt.Errorf("open subpath: %w", err) + } + // If we couldn't open anything, the whole subpath is missing. Return a + // copy of the root fd so that the caller doesn't close this one by + // accident. + rootClone, err := dupFile(root) + if err != nil { + return nil, "", err + } + return rootClone, unsafePath, lastError +} diff --git a/vendor/github.com/cyphar/filepath-securejoin/openat_linux.go b/vendor/github.com/cyphar/filepath-securejoin/openat_linux.go new file mode 100644 index 0000000000..949fb5f2d8 --- /dev/null +++ b/vendor/github.com/cyphar/filepath-securejoin/openat_linux.go @@ -0,0 +1,59 @@ +//go:build linux + +// Copyright (C) 2024 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package securejoin + +import ( + "os" + "path/filepath" + + "golang.org/x/sys/unix" +) + +func dupFile(f *os.File) (*os.File, error) { + fd, err := unix.FcntlInt(f.Fd(), unix.F_DUPFD_CLOEXEC, 0) + if err != nil { + return nil, os.NewSyscallError("fcntl(F_DUPFD_CLOEXEC)", err) + } + return os.NewFile(uintptr(fd), f.Name()), nil +} + +func openatFile(dir *os.File, path string, flags int, mode int) (*os.File, error) { + // Make sure we always set O_CLOEXEC. + flags |= unix.O_CLOEXEC + fd, err := unix.Openat(int(dir.Fd()), path, flags, uint32(mode)) + if err != nil { + return nil, &os.PathError{Op: "openat", Path: dir.Name() + "/" + path, Err: err} + } + // All of the paths we use with openatFile(2) are guaranteed to be + // lexically safe, so we can use path.Join here. + fullPath := filepath.Join(dir.Name(), path) + return os.NewFile(uintptr(fd), fullPath), nil +} + +func fstatatFile(dir *os.File, path string, flags int) (unix.Stat_t, error) { + var stat unix.Stat_t + if err := unix.Fstatat(int(dir.Fd()), path, &stat, flags); err != nil { + return stat, &os.PathError{Op: "fstatat", Path: dir.Name() + "/" + path, Err: err} + } + return stat, nil +} + +func readlinkatFile(dir *os.File, path string) (string, error) { + size := 4096 + for { + linkBuf := make([]byte, size) + n, err := unix.Readlinkat(int(dir.Fd()), path, linkBuf) + if err != nil { + return "", &os.PathError{Op: "readlinkat", Path: dir.Name() + "/" + path, Err: err} + } + if n != size { + return string(linkBuf[:n]), nil + } + // Possible truncation, resize the buffer. + size *= 2 + } +} diff --git a/vendor/github.com/cyphar/filepath-securejoin/procfs_linux.go b/vendor/github.com/cyphar/filepath-securejoin/procfs_linux.go new file mode 100644 index 0000000000..809a579cbd --- /dev/null +++ b/vendor/github.com/cyphar/filepath-securejoin/procfs_linux.go @@ -0,0 +1,452 @@ +//go:build linux + +// Copyright (C) 2024 SUSE LLC. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package securejoin + +import ( + "errors" + "fmt" + "os" + "runtime" + "strconv" + + "golang.org/x/sys/unix" +) + +func fstat(f *os.File) (unix.Stat_t, error) { + var stat unix.Stat_t + if err := unix.Fstat(int(f.Fd()), &stat); err != nil { + return stat, &os.PathError{Op: "fstat", Path: f.Name(), Err: err} + } + return stat, nil +} + +func fstatfs(f *os.File) (unix.Statfs_t, error) { + var statfs unix.Statfs_t + if err := unix.Fstatfs(int(f.Fd()), &statfs); err != nil { + return statfs, &os.PathError{Op: "fstatfs", Path: f.Name(), Err: err} + } + return statfs, nil +} + +// The kernel guarantees that the root inode of a procfs mount has an +// f_type of PROC_SUPER_MAGIC and st_ino of PROC_ROOT_INO. +const ( + procSuperMagic = 0x9fa0 // PROC_SUPER_MAGIC + procRootIno = 1 // PROC_ROOT_INO +) + +func verifyProcRoot(procRoot *os.File) error { + if statfs, err := fstatfs(procRoot); err != nil { + return err + } else if statfs.Type != procSuperMagic { + return fmt.Errorf("%w: incorrect procfs root filesystem type 0x%x", errUnsafeProcfs, statfs.Type) + } + if stat, err := fstat(procRoot); err != nil { + return err + } else if stat.Ino != procRootIno { + return fmt.Errorf("%w: incorrect procfs root inode number %d", errUnsafeProcfs, stat.Ino) + } + return nil +} + +var hasNewMountApi = sync_OnceValue(func() bool { + // All of the pieces of the new mount API we use (fsopen, fsconfig, + // fsmount, open_tree) were added together in Linux 5.1[1,2], so we can + // just check for one of the syscalls and the others should also be + // available. + // + // Just try to use open_tree(2) to open a file without OPEN_TREE_CLONE. + // This is equivalent to openat(2), but tells us if open_tree is + // available (and thus all of the other basic new mount API syscalls). + // open_tree(2) is most light-weight syscall to test here. + // + // [1]: merge commit 400913252d09 + // [2]: + fd, err := unix.OpenTree(-int(unix.EBADF), "/", unix.OPEN_TREE_CLOEXEC) + if err != nil { + return false + } + _ = unix.Close(fd) + return true +}) + +func fsopen(fsName string, flags int) (*os.File, error) { + // Make sure we always set O_CLOEXEC. + flags |= unix.FSOPEN_CLOEXEC + fd, err := unix.Fsopen(fsName, flags) + if err != nil { + return nil, os.NewSyscallError("fsopen "+fsName, err) + } + return os.NewFile(uintptr(fd), "fscontext:"+fsName), nil +} + +func fsmount(ctx *os.File, flags, mountAttrs int) (*os.File, error) { + // Make sure we always set O_CLOEXEC. + flags |= unix.FSMOUNT_CLOEXEC + fd, err := unix.Fsmount(int(ctx.Fd()), flags, mountAttrs) + if err != nil { + return nil, os.NewSyscallError("fsmount "+ctx.Name(), err) + } + return os.NewFile(uintptr(fd), "fsmount:"+ctx.Name()), nil +} + +func newPrivateProcMount() (*os.File, error) { + procfsCtx, err := fsopen("proc", unix.FSOPEN_CLOEXEC) + if err != nil { + return nil, err + } + defer procfsCtx.Close() + + // Try to configure hidepid=ptraceable,subset=pid if possible, but ignore errors. + _ = unix.FsconfigSetString(int(procfsCtx.Fd()), "hidepid", "ptraceable") + _ = unix.FsconfigSetString(int(procfsCtx.Fd()), "subset", "pid") + + // Get an actual handle. + if err := unix.FsconfigCreate(int(procfsCtx.Fd())); err != nil { + return nil, os.NewSyscallError("fsconfig create procfs", err) + } + return fsmount(procfsCtx, unix.FSMOUNT_CLOEXEC, unix.MS_RDONLY|unix.MS_NODEV|unix.MS_NOEXEC|unix.MS_NOSUID) +} + +func openTree(dir *os.File, path string, flags uint) (*os.File, error) { + dirFd := -int(unix.EBADF) + dirName := "." + if dir != nil { + dirFd = int(dir.Fd()) + dirName = dir.Name() + } + // Make sure we always set O_CLOEXEC. + flags |= unix.OPEN_TREE_CLOEXEC + fd, err := unix.OpenTree(dirFd, path, flags) + if err != nil { + return nil, &os.PathError{Op: "open_tree", Path: path, Err: err} + } + return os.NewFile(uintptr(fd), dirName+"/"+path), nil +} + +func clonePrivateProcMount() (_ *os.File, Err error) { + // Try to make a clone without using AT_RECURSIVE if we can. If this works, + // we can be sure there are no over-mounts and so if the root is valid then + // we're golden. Otherwise, we have to deal with over-mounts. + procfsHandle, err := openTree(nil, "/proc", unix.OPEN_TREE_CLONE) + if err != nil || hookForcePrivateProcRootOpenTreeAtRecursive(procfsHandle) { + procfsHandle, err = openTree(nil, "/proc", unix.OPEN_TREE_CLONE|unix.AT_RECURSIVE) + } + if err != nil { + return nil, fmt.Errorf("creating a detached procfs clone: %w", err) + } + defer func() { + if Err != nil { + _ = procfsHandle.Close() + } + }() + if err := verifyProcRoot(procfsHandle); err != nil { + return nil, err + } + return procfsHandle, nil +} + +func privateProcRoot() (*os.File, error) { + if !hasNewMountApi() || hookForceGetProcRootUnsafe() { + return nil, fmt.Errorf("new mount api: %w", unix.ENOTSUP) + } + // Try to create a new procfs mount from scratch if we can. This ensures we + // can get a procfs mount even if /proc is fake (for whatever reason). + procRoot, err := newPrivateProcMount() + if err != nil || hookForcePrivateProcRootOpenTree(procRoot) { + // Try to clone /proc then... + procRoot, err = clonePrivateProcMount() + } + return procRoot, err +} + +func unsafeHostProcRoot() (_ *os.File, Err error) { + procRoot, err := os.OpenFile("/proc", unix.O_PATH|unix.O_NOFOLLOW|unix.O_DIRECTORY|unix.O_CLOEXEC, 0) + if err != nil { + return nil, err + } + defer func() { + if Err != nil { + _ = procRoot.Close() + } + }() + if err := verifyProcRoot(procRoot); err != nil { + return nil, err + } + return procRoot, nil +} + +func doGetProcRoot() (*os.File, error) { + procRoot, err := privateProcRoot() + if err != nil { + // Fall back to using a /proc handle if making a private mount failed. + // If we have openat2, at least we can avoid some kinds of over-mount + // attacks, but without openat2 there's not much we can do. + procRoot, err = unsafeHostProcRoot() + } + return procRoot, err +} + +var getProcRoot = sync_OnceValues(func() (*os.File, error) { + return doGetProcRoot() +}) + +var hasProcThreadSelf = sync_OnceValue(func() bool { + return unix.Access("/proc/thread-self/", unix.F_OK) == nil +}) + +var errUnsafeProcfs = errors.New("unsafe procfs detected") + +type procThreadSelfCloser func() + +// procThreadSelf returns a handle to /proc/thread-self/ (or an +// equivalent handle on older kernels where /proc/thread-self doesn't exist). +// Once finished with the handle, you must call the returned closer function +// (runtime.UnlockOSThread). You must not pass the returned *os.File to other +// Go threads or use the handle after calling the closer. +// +// This is similar to ProcThreadSelf from runc, but with extra hardening +// applied and using *os.File. +func procThreadSelf(procRoot *os.File, subpath string) (_ *os.File, _ procThreadSelfCloser, Err error) { + // We need to lock our thread until the caller is done with the handle + // because between getting the handle and using it we could get interrupted + // by the Go runtime and hit the case where the underlying thread is + // swapped out and the original thread is killed, resulting in + // pull-your-hair-out-hard-to-debug issues in the caller. + runtime.LockOSThread() + defer func() { + if Err != nil { + runtime.UnlockOSThread() + } + }() + + // Figure out what prefix we want to use. + threadSelf := "thread-self/" + if !hasProcThreadSelf() || hookForceProcSelfTask() { + /// Pre-3.17 kernels don't have /proc/thread-self, so do it manually. + threadSelf = "self/task/" + strconv.Itoa(unix.Gettid()) + "/" + if _, err := fstatatFile(procRoot, threadSelf, unix.AT_SYMLINK_NOFOLLOW); err != nil || hookForceProcSelf() { + // In this case, we running in a pid namespace that doesn't match + // the /proc mount we have. This can happen inside runc. + // + // Unfortunately, there is no nice way to get the correct TID to + // use here because of the age of the kernel, so we have to just + // use /proc/self and hope that it works. + threadSelf = "self/" + } + } + + // Grab the handle. + var ( + handle *os.File + err error + ) + if hasOpenat2() { + // We prefer being able to use RESOLVE_NO_XDEV if we can, to be + // absolutely sure we are operating on a clean /proc handle that + // doesn't have any cheeky overmounts that could trick us (including + // symlink mounts on top of /proc/thread-self). RESOLVE_BENEATH isn't + // strictly needed, but just use it since we have it. + // + // NOTE: /proc/self is technically a magic-link (the contents of the + // symlink are generated dynamically), but it doesn't use + // nd_jump_link() so RESOLVE_NO_MAGICLINKS allows it. + // + // NOTE: We MUST NOT use RESOLVE_IN_ROOT here, as openat2File uses + // procSelfFdReadlink to clean up the returned f.Name() if we use + // RESOLVE_IN_ROOT (which would lead to an infinite recursion). + handle, err = openat2File(procRoot, threadSelf+subpath, &unix.OpenHow{ + Flags: unix.O_PATH | unix.O_NOFOLLOW | unix.O_CLOEXEC, + Resolve: unix.RESOLVE_BENEATH | unix.RESOLVE_NO_XDEV | unix.RESOLVE_NO_MAGICLINKS, + }) + if err != nil { + // TODO: Once we bump the minimum Go version to 1.20, we can use + // multiple %w verbs for this wrapping. For now we need to use a + // compatibility shim for older Go versions. + //err = fmt.Errorf("%w: %w", errUnsafeProcfs, err) + return nil, nil, wrapBaseError(err, errUnsafeProcfs) + } + } else { + handle, err = openatFile(procRoot, threadSelf+subpath, unix.O_PATH|unix.O_NOFOLLOW|unix.O_CLOEXEC, 0) + if err != nil { + // TODO: Once we bump the minimum Go version to 1.20, we can use + // multiple %w verbs for this wrapping. For now we need to use a + // compatibility shim for older Go versions. + //err = fmt.Errorf("%w: %w", errUnsafeProcfs, err) + return nil, nil, wrapBaseError(err, errUnsafeProcfs) + } + defer func() { + if Err != nil { + _ = handle.Close() + } + }() + // We can't detect bind-mounts of different parts of procfs on top of + // /proc (a-la RESOLVE_NO_XDEV), but we can at least be sure that we + // aren't on the wrong filesystem here. + if statfs, err := fstatfs(handle); err != nil { + return nil, nil, err + } else if statfs.Type != procSuperMagic { + return nil, nil, fmt.Errorf("%w: incorrect /proc/self/fd filesystem type 0x%x", errUnsafeProcfs, statfs.Type) + } + } + return handle, runtime.UnlockOSThread, nil +} + +// STATX_MNT_ID_UNIQUE is provided in golang.org/x/sys@v0.20.0, but in order to +// avoid bumping the requirement for a single constant we can just define it +// ourselves. +const STATX_MNT_ID_UNIQUE = 0x4000 + +var hasStatxMountId = sync_OnceValue(func() bool { + var ( + stx unix.Statx_t + // We don't care which mount ID we get. The kernel will give us the + // unique one if it is supported. + wantStxMask uint32 = STATX_MNT_ID_UNIQUE | unix.STATX_MNT_ID + ) + err := unix.Statx(-int(unix.EBADF), "/", 0, int(wantStxMask), &stx) + return err == nil && stx.Mask&wantStxMask != 0 +}) + +func getMountId(dir *os.File, path string) (uint64, error) { + // If we don't have statx(STATX_MNT_ID*) support, we can't do anything. + if !hasStatxMountId() { + return 0, nil + } + + var ( + stx unix.Statx_t + // We don't care which mount ID we get. The kernel will give us the + // unique one if it is supported. + wantStxMask uint32 = STATX_MNT_ID_UNIQUE | unix.STATX_MNT_ID + ) + + err := unix.Statx(int(dir.Fd()), path, unix.AT_EMPTY_PATH|unix.AT_SYMLINK_NOFOLLOW, int(wantStxMask), &stx) + if stx.Mask&wantStxMask == 0 { + // It's not a kernel limitation, for some reason we couldn't get a + // mount ID. Assume it's some kind of attack. + err = fmt.Errorf("%w: could not get mount id", errUnsafeProcfs) + } + if err != nil { + return 0, &os.PathError{Op: "statx(STATX_MNT_ID_...)", Path: dir.Name() + "/" + path, Err: err} + } + return stx.Mnt_id, nil +} + +func checkSymlinkOvermount(procRoot *os.File, dir *os.File, path string) error { + // Get the mntId of our procfs handle. + expectedMountId, err := getMountId(procRoot, "") + if err != nil { + return err + } + // Get the mntId of the target magic-link. + gotMountId, err := getMountId(dir, path) + if err != nil { + return err + } + // As long as the directory mount is alive, even with wrapping mount IDs, + // we would expect to see a different mount ID here. (Of course, if we're + // using unsafeHostProcRoot() then an attaker could change this after we + // did this check.) + if expectedMountId != gotMountId { + return fmt.Errorf("%w: symlink %s/%s has an overmount obscuring the real link (mount ids do not match %d != %d)", errUnsafeProcfs, dir.Name(), path, expectedMountId, gotMountId) + } + return nil +} + +func doRawProcSelfFdReadlink(procRoot *os.File, fd int) (string, error) { + fdPath := fmt.Sprintf("fd/%d", fd) + procFdLink, closer, err := procThreadSelf(procRoot, fdPath) + if err != nil { + return "", fmt.Errorf("get safe /proc/thread-self/%s handle: %w", fdPath, err) + } + defer procFdLink.Close() + defer closer() + + // Try to detect if there is a mount on top of the magic-link. Since we use the handle directly + // provide to the closure. If the closure uses the handle directly, this + // should be safe in general (a mount on top of the path afterwards would + // not affect the handle itself) and will definitely be safe if we are + // using privateProcRoot() (at least since Linux 5.12[1], when anonymous + // mount namespaces were completely isolated from external mounts including + // mount propagation events). + // + // [1]: Linux commit ee2e3f50629f ("mount: fix mounting of detached mounts + // onto targets that reside on shared mounts"). + if err := checkSymlinkOvermount(procRoot, procFdLink, ""); err != nil { + return "", fmt.Errorf("check safety of /proc/thread-self/fd/%d magiclink: %w", fd, err) + } + + // readlinkat implies AT_EMPTY_PATH since Linux 2.6.39. See Linux commit + // 65cfc6722361 ("readlinkat(), fchownat() and fstatat() with empty + // relative pathnames"). + return readlinkatFile(procFdLink, "") +} + +func rawProcSelfFdReadlink(fd int) (string, error) { + procRoot, err := getProcRoot() + if err != nil { + return "", err + } + return doRawProcSelfFdReadlink(procRoot, fd) +} + +func procSelfFdReadlink(f *os.File) (string, error) { + return rawProcSelfFdReadlink(int(f.Fd())) +} + +var ( + errPossibleBreakout = errors.New("possible breakout detected") + errInvalidDirectory = errors.New("wandered into deleted directory") + errDeletedInode = errors.New("cannot verify path of deleted inode") +) + +func isDeadInode(file *os.File) error { + // If the nlink of a file drops to 0, there is an attacker deleting + // directories during our walk, which could result in weird /proc values. + // It's better to error out in this case. + stat, err := fstat(file) + if err != nil { + return fmt.Errorf("check for dead inode: %w", err) + } + if stat.Nlink == 0 { + err := errDeletedInode + if stat.Mode&unix.S_IFMT == unix.S_IFDIR { + err = errInvalidDirectory + } + return fmt.Errorf("%w %q", err, file.Name()) + } + return nil +} + +func checkProcSelfFdPath(path string, file *os.File) error { + if err := isDeadInode(file); err != nil { + return err + } + actualPath, err := procSelfFdReadlink(file) + if err != nil { + return fmt.Errorf("get path of handle: %w", err) + } + if actualPath != path { + return fmt.Errorf("%w: handle path %q doesn't match expected path %q", errPossibleBreakout, actualPath, path) + } + return nil +} + +// Test hooks used in the procfs tests to verify that the fallback logic works. +// See testing_mocks_linux_test.go and procfs_linux_test.go for more details. +var ( + hookForcePrivateProcRootOpenTree = hookDummyFile + hookForcePrivateProcRootOpenTreeAtRecursive = hookDummyFile + hookForceGetProcRootUnsafe = hookDummy + + hookForceProcSelfTask = hookDummy + hookForceProcSelf = hookDummy +) + +func hookDummy() bool { return false } +func hookDummyFile(_ *os.File) bool { return false } diff --git a/vendor/github.com/cyphar/filepath-securejoin/vfs.go b/vendor/github.com/cyphar/filepath-securejoin/vfs.go index a82a5eae11..36373f8c51 100644 --- a/vendor/github.com/cyphar/filepath-securejoin/vfs.go +++ b/vendor/github.com/cyphar/filepath-securejoin/vfs.go @@ -1,4 +1,4 @@ -// Copyright (C) 2017 SUSE LLC. All rights reserved. +// Copyright (C) 2017-2024 SUSE LLC. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. @@ -10,19 +10,19 @@ import "os" // are several projects (umoci and go-mtree) that are using this sort of // interface. -// VFS is the minimal interface necessary to use SecureJoinVFS. A nil VFS is -// equivalent to using the standard os.* family of functions. This is mainly +// VFS is the minimal interface necessary to use [SecureJoinVFS]. A nil VFS is +// equivalent to using the standard [os].* family of functions. This is mainly // used for the purposes of mock testing, but also can be used to otherwise use -// SecureJoin with VFS-like system. +// [SecureJoinVFS] with VFS-like system. type VFS interface { - // Lstat returns a FileInfo describing the named file. If the file is a - // symbolic link, the returned FileInfo describes the symbolic link. Lstat - // makes no attempt to follow the link. These semantics are identical to - // os.Lstat. + // Lstat returns an [os.FileInfo] describing the named file. If the + // file is a symbolic link, the returned [os.FileInfo] describes the + // symbolic link. Lstat makes no attempt to follow the link. + // The semantics are identical to [os.Lstat]. Lstat(name string) (os.FileInfo, error) - // Readlink returns the destination of the named symbolic link. These - // semantics are identical to os.Readlink. + // Readlink returns the destination of the named symbolic link. + // The semantics are identical to [os.Readlink]. Readlink(name string) (string, error) } @@ -30,12 +30,6 @@ type VFS interface { // module. type osVFS struct{} -// Lstat returns a FileInfo describing the named file. If the file is a -// symbolic link, the returned FileInfo describes the symbolic link. Lstat -// makes no attempt to follow the link. These semantics are identical to -// os.Lstat. func (o osVFS) Lstat(name string) (os.FileInfo, error) { return os.Lstat(name) } -// Readlink returns the destination of the named symbolic link. These -// semantics are identical to os.Readlink. func (o osVFS) Readlink(name string) (string, error) { return os.Readlink(name) } diff --git a/vendor/github.com/go-git/go-billy/v5/Makefile b/vendor/github.com/go-git/go-billy/v5/Makefile index 74dad8b491..3c95ddeaac 100644 --- a/vendor/github.com/go-git/go-billy/v5/Makefile +++ b/vendor/github.com/go-git/go-billy/v5/Makefile @@ -1,6 +1,7 @@ # Go parameters GOCMD = go GOTEST = $(GOCMD) test +WASIRUN_WRAPPER := $(CURDIR)/scripts/wasirun-wrapper .PHONY: test test: @@ -9,3 +10,9 @@ test: test-coverage: echo "" > $(COVERAGE_REPORT); \ $(GOTEST) -coverprofile=$(COVERAGE_REPORT) -coverpkg=./... -covermode=$(COVERAGE_MODE) ./... + +.PHONY: wasitest +wasitest: export GOARCH=wasm +wasitest: export GOOS=wasip1 +wasitest: + $(GOTEST) -exec $(WASIRUN_WRAPPER) ./... diff --git a/vendor/github.com/go-git/go-billy/v5/fs.go b/vendor/github.com/go-git/go-billy/v5/fs.go index a9efccdeb2..d86f9d8236 100644 --- a/vendor/github.com/go-git/go-billy/v5/fs.go +++ b/vendor/github.com/go-git/go-billy/v5/fs.go @@ -164,6 +164,8 @@ type File interface { // Name returns the name of the file as presented to Open. Name() string io.Writer + // TODO: Add io.WriterAt for v6 + // io.WriterAt io.Reader io.ReaderAt io.Seeker diff --git a/vendor/github.com/go-git/go-billy/v5/memfs/memory.go b/vendor/github.com/go-git/go-billy/v5/memfs/memory.go index dab73968b6..6cbd7d08ca 100644 --- a/vendor/github.com/go-git/go-billy/v5/memfs/memory.go +++ b/vendor/github.com/go-git/go-billy/v5/memfs/memory.go @@ -9,6 +9,7 @@ import ( "path/filepath" "sort" "strings" + "syscall" "time" "github.com/go-git/go-billy/v5" @@ -18,16 +19,19 @@ import ( const separator = filepath.Separator -// Memory a very convenient filesystem based on memory files +var errNotLink = errors.New("not a link") + +// Memory a very convenient filesystem based on memory files. type Memory struct { s *storage tempCount int } -//New returns a new Memory filesystem. +// New returns a new Memory filesystem. func New() billy.Filesystem { fs := &Memory{s: newStorage()} + fs.s.New("/", 0755|os.ModeDir, 0) return chroot.New(fs, string(separator)) } @@ -57,7 +61,9 @@ func (fs *Memory) OpenFile(filename string, flag int, perm os.FileMode) (billy.F } if target, isLink := fs.resolveLink(filename, f); isLink { - return fs.OpenFile(target, flag, perm) + if target != filename { + return fs.OpenFile(target, flag, perm) + } } } @@ -68,8 +74,6 @@ func (fs *Memory) OpenFile(filename string, flag int, perm os.FileMode) (billy.F return f.Duplicate(filename, perm, flag), nil } -var errNotLink = errors.New("not a link") - func (fs *Memory) resolveLink(fullpath string, f *file) (target string, isLink bool) { if !isSymlink(f.mode) { return fullpath, false @@ -131,8 +135,12 @@ func (a ByName) Swap(i, j int) { a[i], a[j] = a[j], a[i] } func (fs *Memory) ReadDir(path string) ([]os.FileInfo, error) { if f, has := fs.s.Get(path); has { if target, isLink := fs.resolveLink(path, f); isLink { - return fs.ReadDir(target) + if target != path { + return fs.ReadDir(target) + } } + } else { + return nil, &os.PathError{Op: "open", Path: path, Err: syscall.ENOENT} } var entries []os.FileInfo @@ -169,17 +177,19 @@ func (fs *Memory) Remove(filename string) error { return fs.s.Remove(filename) } +// Falls back to Go's filepath.Join, which works differently depending on the +// OS where the code is being executed. func (fs *Memory) Join(elem ...string) string { return filepath.Join(elem...) } func (fs *Memory) Symlink(target, link string) error { - _, err := fs.Stat(link) + _, err := fs.Lstat(link) if err == nil { return os.ErrExist } - if !os.IsNotExist(err) { + if !errors.Is(err, os.ErrNotExist) { return err } @@ -230,7 +240,7 @@ func (f *file) Read(b []byte) (int, error) { n, err := f.ReadAt(b, f.position) f.position += int64(n) - if err == io.EOF && n != 0 { + if errors.Is(err, io.EOF) && n != 0 { err = nil } @@ -269,6 +279,10 @@ func (f *file) Seek(offset int64, whence int) (int64, error) { } func (f *file) Write(p []byte) (int, error) { + return f.WriteAt(p, f.position) +} + +func (f *file) WriteAt(p []byte, off int64) (int, error) { if f.isClosed { return 0, os.ErrClosed } @@ -277,8 +291,8 @@ func (f *file) Write(p []byte) (int, error) { return 0, errors.New("write not supported") } - n, err := f.content.WriteAt(p, f.position) - f.position += int64(n) + n, err := f.content.WriteAt(p, off) + f.position = off + int64(n) return n, err } diff --git a/vendor/github.com/go-git/go-billy/v5/memfs/storage.go b/vendor/github.com/go-git/go-billy/v5/memfs/storage.go index e3c4e38bff..16b48ce002 100644 --- a/vendor/github.com/go-git/go-billy/v5/memfs/storage.go +++ b/vendor/github.com/go-git/go-billy/v5/memfs/storage.go @@ -6,6 +6,7 @@ import ( "io" "os" "path/filepath" + "strings" "sync" ) @@ -112,7 +113,7 @@ func (s *storage) Rename(from, to string) error { move := [][2]string{{from, to}} for pathFrom := range s.files { - if pathFrom == from || !filepath.HasPrefix(pathFrom, from) { + if pathFrom == from || !strings.HasPrefix(pathFrom, from) { continue } diff --git a/vendor/github.com/go-git/go-billy/v5/osfs/os_bound.go b/vendor/github.com/go-git/go-billy/v5/osfs/os_bound.go index b4b6dbc07a..c0a6109901 100644 --- a/vendor/github.com/go-git/go-billy/v5/osfs/os_bound.go +++ b/vendor/github.com/go-git/go-billy/v5/osfs/os_bound.go @@ -246,6 +246,10 @@ func (fs *BoundOS) insideBaseDir(filename string) (bool, error) { // a dir that is within the fs.baseDir, by first evaluating any symlinks // that either filename or fs.baseDir may contain. func (fs *BoundOS) insideBaseDirEval(filename string) (bool, error) { + // "/" contains all others. + if fs.baseDir == "/" { + return true, nil + } dir, err := filepath.EvalSymlinks(filepath.Dir(filename)) if dir == "" || os.IsNotExist(err) { dir = filepath.Dir(filename) @@ -255,7 +259,7 @@ func (fs *BoundOS) insideBaseDirEval(filename string) (bool, error) { wd = fs.baseDir } if filename != wd && dir != wd && !strings.HasPrefix(dir, wd+string(filepath.Separator)) { - return false, fmt.Errorf("path outside base dir") + return false, fmt.Errorf("%q: path outside base dir %q: %w", filename, fs.baseDir, os.ErrNotExist) } return true, nil } diff --git a/vendor/github.com/go-git/go-billy/v5/osfs/os_posix.go b/vendor/github.com/go-git/go-billy/v5/osfs/os_posix.go index d834a1145a..6fb8273f17 100644 --- a/vendor/github.com/go-git/go-billy/v5/osfs/os_posix.go +++ b/vendor/github.com/go-git/go-billy/v5/osfs/os_posix.go @@ -1,5 +1,5 @@ -//go:build !plan9 && !windows && !js -// +build !plan9,!windows,!js +//go:build !plan9 && !windows && !wasm +// +build !plan9,!windows,!wasm package osfs diff --git a/vendor/github.com/go-git/go-billy/v5/osfs/os_wasip1.go b/vendor/github.com/go-git/go-billy/v5/osfs/os_wasip1.go new file mode 100644 index 0000000000..79e6e33192 --- /dev/null +++ b/vendor/github.com/go-git/go-billy/v5/osfs/os_wasip1.go @@ -0,0 +1,34 @@ +//go:build wasip1 +// +build wasip1 + +package osfs + +import ( + "os" + "syscall" +) + +func (f *file) Lock() error { + f.m.Lock() + defer f.m.Unlock() + return nil +} + +func (f *file) Unlock() error { + f.m.Lock() + defer f.m.Unlock() + return nil +} + +func rename(from, to string) error { + return os.Rename(from, to) +} + +// umask sets umask to a new value, and returns a func which allows the +// caller to reset it back to what it was originally. +func umask(new int) func() { + old := syscall.Umask(new) + return func() { + syscall.Umask(old) + } +} diff --git a/vendor/github.com/go-git/go-billy/v5/util/util.go b/vendor/github.com/go-git/go-billy/v5/util/util.go index 5c77128c3c..2cdd832c73 100644 --- a/vendor/github.com/go-git/go-billy/v5/util/util.go +++ b/vendor/github.com/go-git/go-billy/v5/util/util.go @@ -1,6 +1,7 @@ package util import ( + "errors" "io" "os" "path/filepath" @@ -33,14 +34,14 @@ func removeAll(fs billy.Basic, path string) error { // Simple case: if Remove works, we're done. err := fs.Remove(path) - if err == nil || os.IsNotExist(err) { + if err == nil || errors.Is(err, os.ErrNotExist) { return nil } // Otherwise, is this a directory we need to recurse into? dir, serr := fs.Stat(path) if serr != nil { - if os.IsNotExist(serr) { + if errors.Is(serr, os.ErrNotExist) { return nil } @@ -60,7 +61,7 @@ func removeAll(fs billy.Basic, path string) error { // Directory. fis, err := dirfs.ReadDir(path) if err != nil { - if os.IsNotExist(err) { + if errors.Is(err, os.ErrNotExist) { // Race. It was deleted between the Lstat and Open. // Return nil per RemoveAll's docs. return nil @@ -81,7 +82,7 @@ func removeAll(fs billy.Basic, path string) error { // Remove directory. err1 := fs.Remove(path) - if err1 == nil || os.IsNotExist(err1) { + if err1 == nil || errors.Is(err1, os.ErrNotExist) { return nil } @@ -96,22 +97,26 @@ func removeAll(fs billy.Basic, path string) error { // WriteFile writes data to a file named by filename in the given filesystem. // If the file does not exist, WriteFile creates it with permissions perm; // otherwise WriteFile truncates it before writing. -func WriteFile(fs billy.Basic, filename string, data []byte, perm os.FileMode) error { +func WriteFile(fs billy.Basic, filename string, data []byte, perm os.FileMode) (err error) { f, err := fs.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm) if err != nil { return err } + defer func() { + if f != nil { + err1 := f.Close() + if err == nil { + err = err1 + } + } + }() n, err := f.Write(data) if err == nil && n < len(data) { err = io.ErrShortWrite } - if err1 := f.Close(); err == nil { - err = err1 - } - - return err + return nil } // Random number state. @@ -154,7 +159,7 @@ func TempFile(fs billy.Basic, dir, prefix string) (f billy.File, err error) { for i := 0; i < 10000; i++ { name := filepath.Join(dir, prefix+nextSuffix()) f, err = fs.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0600) - if os.IsExist(err) { + if errors.Is(err, os.ErrExist) { if nconflict++; nconflict > 10 { randmu.Lock() rand = reseed() @@ -185,7 +190,7 @@ func TempDir(fs billy.Dir, dir, prefix string) (name string, err error) { for i := 0; i < 10000; i++ { try := filepath.Join(dir, prefix+nextSuffix()) err = fs.MkdirAll(try, 0700) - if os.IsExist(err) { + if errors.Is(err, os.ErrExist) { if nconflict++; nconflict > 10 { randmu.Lock() rand = reseed() @@ -193,8 +198,8 @@ func TempDir(fs billy.Dir, dir, prefix string) (name string, err error) { } continue } - if os.IsNotExist(err) { - if _, err := os.Stat(dir); os.IsNotExist(err) { + if errors.Is(err, os.ErrNotExist) { + if _, err := os.Stat(dir); errors.Is(err, os.ErrNotExist) { return "", err } } @@ -272,7 +277,7 @@ func ReadFile(fs billy.Basic, name string) ([]byte, error) { data = data[:len(data)+n] if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { err = nil } diff --git a/vendor/github.com/go-git/go-git/v5/COMPATIBILITY.md b/vendor/github.com/go-git/go-git/v5/COMPATIBILITY.md index ff0c22c896..ba1fb90ac5 100644 --- a/vendor/github.com/go-git/go-git/v5/COMPATIBILITY.md +++ b/vendor/github.com/go-git/go-git/v5/COMPATIBILITY.md @@ -11,7 +11,7 @@ compatibility status with go-git. | `init` | `--bare` | ✅ | | | | `init` | `--template`
`--separate-git-dir`
`--shared` | ❌ | | | | `clone` | | ✅ | | - [PlainClone](_examples/clone/main.go) | -| `clone` | Authentication:
- none
- access token
- username + password
- ssh | ✅ | | - [clone ssh](_examples/clone/auth/ssh/main.go)
- [clone access token](_examples/clone/auth/basic/access_token/main.go)
- [clone user + password](_examples/clone/auth/basic/username_password/main.go) | +| `clone` | Authentication:
- none
- access token
- username + password
- ssh | ✅ | | - [clone ssh (private_key)](_examples/clone/auth/ssh/private_key/main.go)
- [clone ssh (ssh_agent)](_examples/clone/auth/ssh/ssh_agent/main.go)
- [clone access token](_examples/clone/auth/basic/access_token/main.go)
- [clone user + password](_examples/clone/auth/basic/username_password/main.go) | | `clone` | `--progress`
`--single-branch`
`--depth`
`--origin`
`--recurse-submodules`
`--shared` | ✅ | | - [recurse submodules](_examples/clone/main.go)
- [progress](_examples/progress/main.go) | ## Basic snapshotting @@ -34,6 +34,7 @@ compatibility status with go-git. | `merge` | | ⚠️ (partial) | Fast-forward only | | | `mergetool` | | ❌ | | | | `stash` | | ❌ | | | +| `sparse-checkout` | | ✅ | | - [sparse-checkout](_examples/sparse-checkout/main.go) | | `tag` | | ✅ | | - [tag](_examples/tag/main.go)
- [tag create and push](_examples/tag-create-push/main.go) | ## Sharing and updating projects diff --git a/vendor/github.com/go-git/go-git/v5/CONTRIBUTING.md b/vendor/github.com/go-git/go-git/v5/CONTRIBUTING.md index fce25328a7..a5b01823bf 100644 --- a/vendor/github.com/go-git/go-git/v5/CONTRIBUTING.md +++ b/vendor/github.com/go-git/go-git/v5/CONTRIBUTING.md @@ -31,6 +31,13 @@ In order for a PR to be accepted it needs to pass a list of requirements: - If the PR is a new feature, it has to come with a suite of unit tests, that tests the new functionality. - In any case, all the PRs have to pass the personal evaluation of at least one of the maintainers of go-git. +### Branches + +The `master` branch is currently used for maintaining the `v5` major release only. The accepted changes would +be dependency bumps, bug fixes and small changes that aren't needed for `v6`. New development should target the +`v6-exp` branch, and if agreed with at least one go-git maintainer, it can be back ported to `v5` by creating +a new PR that targets `master`. + ### Format of the commit message Every commit message should describe what was changed, under which context and, if applicable, the GitHub issue it relates to: diff --git a/vendor/github.com/go-git/go-git/v5/blame.go b/vendor/github.com/go-git/go-git/v5/blame.go index 2a877dcdf9..e3cb39aec3 100644 --- a/vendor/github.com/go-git/go-git/v5/blame.go +++ b/vendor/github.com/go-git/go-git/v5/blame.go @@ -97,13 +97,10 @@ func Blame(c *object.Commit, path string) (*BlameResult, error) { if err != nil { return nil, err } - if finished == true { + if finished { break } } - if err != nil { - return nil, err - } b.lineToCommit = make([]*object.Commit, finalLength) for i := range needsMap { @@ -309,8 +306,8 @@ func (b *blame) addBlames(curItems []*queueItem) (bool, error) { for h := range hunks { hLines := countLines(hunks[h].Text) for hl := 0; hl < hLines; hl++ { - switch { - case hunks[h].Type == diffmatchpatch.DiffEqual: + switch hunks[h].Type { + case diffmatchpatch.DiffEqual: prevl++ curl++ if curl == curItem.NeedsMap[need].Cur { @@ -322,7 +319,7 @@ func (b *blame) addBlames(curItems []*queueItem) (bool, error) { break out } } - case hunks[h].Type == diffmatchpatch.DiffInsert: + case diffmatchpatch.DiffInsert: curl++ if curl == curItem.NeedsMap[need].Cur { // the line we want is added, it may have been added here (or by another parent), skip it for now @@ -331,7 +328,7 @@ func (b *blame) addBlames(curItems []*queueItem) (bool, error) { break out } } - case hunks[h].Type == diffmatchpatch.DiffDelete: + case diffmatchpatch.DiffDelete: prevl += hLines continue out default: diff --git a/vendor/github.com/go-git/go-git/v5/config/config.go b/vendor/github.com/go-git/go-git/v5/config/config.go index 6d41c15dcd..33f6e37d26 100644 --- a/vendor/github.com/go-git/go-git/v5/config/config.go +++ b/vendor/github.com/go-git/go-git/v5/config/config.go @@ -252,6 +252,7 @@ const ( extensionsSection = "extensions" fetchKey = "fetch" urlKey = "url" + pushurlKey = "pushurl" bareKey = "bare" worktreeKey = "worktree" commentCharKey = "commentChar" @@ -633,6 +634,7 @@ func (c *RemoteConfig) unmarshal(s *format.Subsection) error { c.Name = c.raw.Name c.URLs = append([]string(nil), c.raw.Options.GetAll(urlKey)...) + c.URLs = append(c.URLs, c.raw.Options.GetAll(pushurlKey)...) c.Fetch = fetch c.Mirror = c.raw.Options.Get(mirrorKey) == "true" diff --git a/vendor/github.com/go-git/go-git/v5/internal/revision/scanner.go b/vendor/github.com/go-git/go-git/v5/internal/revision/scanner.go index c46c21b795..2444f33ec2 100644 --- a/vendor/github.com/go-git/go-git/v5/internal/revision/scanner.go +++ b/vendor/github.com/go-git/go-git/v5/internal/revision/scanner.go @@ -43,6 +43,11 @@ func tokenizeExpression(ch rune, tokenType token, check runeCategoryValidator, r return tokenType, string(data), nil } +// maxRevisionLength holds the maximum length that will be parsed for a +// revision. Git itself doesn't enforce a max length, but rather leans on +// the OS to enforce it via its ARG_MAX. +const maxRevisionLength = 128 * 1024 // 128kb + var zeroRune = rune(0) // scanner represents a lexical scanner. @@ -52,7 +57,7 @@ type scanner struct { // newScanner returns a new instance of scanner. func newScanner(r io.Reader) *scanner { - return &scanner{r: bufio.NewReader(r)} + return &scanner{r: bufio.NewReader(io.LimitReader(r, maxRevisionLength))} } // Scan extracts tokens and their strings counterpart diff --git a/vendor/github.com/go-git/go-git/v5/options.go b/vendor/github.com/go-git/go-git/v5/options.go index d7776dad5e..3cd0f952c3 100644 --- a/vendor/github.com/go-git/go-git/v5/options.go +++ b/vendor/github.com/go-git/go-git/v5/options.go @@ -416,6 +416,9 @@ type ResetOptions struct { // the index (resetting it to the tree of Commit) and the working tree // depending on Mode. If empty MixedReset is used. Mode ResetMode + // Files, if not empty will constrain the reseting the index to only files + // specified in this list. + Files []string } // Validate validates the fields and sets the default values. @@ -790,3 +793,26 @@ type PlainInitOptions struct { // Validate validates the fields and sets the default values. func (o *PlainInitOptions) Validate() error { return nil } + +var ( + ErrNoRestorePaths = errors.New("you must specify path(s) to restore") +) + +// RestoreOptions describes how a restore should be performed. +type RestoreOptions struct { + // Marks to restore the content in the index + Staged bool + // Marks to restore the content of the working tree + Worktree bool + // List of file paths that will be restored + Files []string +} + +// Validate validates the fields and sets the default values. +func (o *RestoreOptions) Validate() error { + if len(o.Files) == 0 { + return ErrNoRestorePaths + } + + return nil +} diff --git a/vendor/github.com/go-git/go-git/v5/plumbing/format/gitignore/dir.go b/vendor/github.com/go-git/go-git/v5/plumbing/format/gitignore/dir.go index aca5d0dbd2..92df5a3de7 100644 --- a/vendor/github.com/go-git/go-git/v5/plumbing/format/gitignore/dir.go +++ b/vendor/github.com/go-git/go-git/v5/plumbing/format/gitignore/dir.go @@ -64,6 +64,10 @@ func ReadPatterns(fs billy.Filesystem, path []string) (ps []Pattern, err error) for _, fi := range fis { if fi.IsDir() && fi.Name() != gitDir { + if NewMatcher(ps).Match(append(path, fi.Name()), true) { + continue + } + var subps []Pattern subps, err = ReadPatterns(fs, append(path, fi.Name())) if err != nil { diff --git a/vendor/github.com/go-git/go-git/v5/plumbing/format/index/decoder.go b/vendor/github.com/go-git/go-git/v5/plumbing/format/index/decoder.go index 6778cf74ec..fc25d37022 100644 --- a/vendor/github.com/go-git/go-git/v5/plumbing/format/index/decoder.go +++ b/vendor/github.com/go-git/go-git/v5/plumbing/format/index/decoder.go @@ -24,8 +24,8 @@ var ( // ErrInvalidChecksum is returned by Decode if the SHA1 hash mismatch with // the read content ErrInvalidChecksum = errors.New("invalid checksum") - - errUnknownExtension = errors.New("unknown extension") + // ErrUnknownExtension is returned when an index extension is encountered that is considered mandatory + ErrUnknownExtension = errors.New("unknown extension") ) const ( @@ -39,6 +39,7 @@ const ( // A Decoder reads and decodes index files from an input stream. type Decoder struct { + buf *bufio.Reader r io.Reader hash hash.Hash lastEntry *Entry @@ -49,8 +50,10 @@ type Decoder struct { // NewDecoder returns a new decoder that reads from r. func NewDecoder(r io.Reader) *Decoder { h := hash.New(hash.CryptoType) + buf := bufio.NewReader(r) return &Decoder{ - r: io.TeeReader(r, h), + buf: buf, + r: io.TeeReader(buf, h), hash: h, extReader: bufio.NewReader(nil), } @@ -210,71 +213,75 @@ func (d *Decoder) readExtensions(idx *Index) error { // count that they are not supported by jgit or libgit var expected []byte + var peeked []byte var err error - var header [4]byte + // we should always be able to peek for 4 bytes (header) + 4 bytes (extlen) + final hash + // if this fails, we know that we're at the end of the index + peekLen := 4 + 4 + d.hash.Size() + for { expected = d.hash.Sum(nil) - - var n int - if n, err = io.ReadFull(d.r, header[:]); err != nil { - if n == 0 { - err = io.EOF - } - + peeked, err = d.buf.Peek(peekLen) + if len(peeked) < peekLen { + // there can't be an extension at this point, so let's bail out break } + if err != nil { + return err + } - err = d.readExtension(idx, header[:]) + err = d.readExtension(idx) if err != nil { - break + return err } } - if err != errUnknownExtension { + return d.readChecksum(expected) +} + +func (d *Decoder) readExtension(idx *Index) error { + var header [4]byte + + if _, err := io.ReadFull(d.r, header[:]); err != nil { return err } - return d.readChecksum(expected, header) -} + r, err := d.getExtensionReader() + if err != nil { + return err + } -func (d *Decoder) readExtension(idx *Index, header []byte) error { switch { - case bytes.Equal(header, treeExtSignature): - r, err := d.getExtensionReader() - if err != nil { - return err - } - + case bytes.Equal(header[:], treeExtSignature): idx.Cache = &Tree{} d := &treeExtensionDecoder{r} if err := d.Decode(idx.Cache); err != nil { return err } - case bytes.Equal(header, resolveUndoExtSignature): - r, err := d.getExtensionReader() - if err != nil { - return err - } - + case bytes.Equal(header[:], resolveUndoExtSignature): idx.ResolveUndo = &ResolveUndo{} d := &resolveUndoDecoder{r} if err := d.Decode(idx.ResolveUndo); err != nil { return err } - case bytes.Equal(header, endOfIndexEntryExtSignature): - r, err := d.getExtensionReader() - if err != nil { - return err - } - + case bytes.Equal(header[:], endOfIndexEntryExtSignature): idx.EndOfIndexEntry = &EndOfIndexEntry{} d := &endOfIndexEntryDecoder{r} if err := d.Decode(idx.EndOfIndexEntry); err != nil { return err } default: - return errUnknownExtension + // See https://git-scm.com/docs/index-format, which says: + // If the first byte is 'A'..'Z' the extension is optional and can be ignored. + if header[0] < 'A' || header[0] > 'Z' { + return ErrUnknownExtension + } + + d := &unknownExtensionDecoder{r} + if err := d.Decode(); err != nil { + return err + } } return nil @@ -290,11 +297,10 @@ func (d *Decoder) getExtensionReader() (*bufio.Reader, error) { return d.extReader, nil } -func (d *Decoder) readChecksum(expected []byte, alreadyRead [4]byte) error { +func (d *Decoder) readChecksum(expected []byte) error { var h plumbing.Hash - copy(h[:4], alreadyRead[:]) - if _, err := io.ReadFull(d.r, h[4:]); err != nil { + if _, err := io.ReadFull(d.r, h[:]); err != nil { return err } @@ -476,3 +482,22 @@ func (d *endOfIndexEntryDecoder) Decode(e *EndOfIndexEntry) error { _, err = io.ReadFull(d.r, e.Hash[:]) return err } + +type unknownExtensionDecoder struct { + r *bufio.Reader +} + +func (d *unknownExtensionDecoder) Decode() error { + var buf [1024]byte + + for { + _, err := d.r.Read(buf[:]) + if err == io.EOF { + break + } + if err != nil { + return err + } + } + return nil +} diff --git a/vendor/github.com/go-git/go-git/v5/plumbing/format/index/encoder.go b/vendor/github.com/go-git/go-git/v5/plumbing/format/index/encoder.go index fa2d814454..c232e03231 100644 --- a/vendor/github.com/go-git/go-git/v5/plumbing/format/index/encoder.go +++ b/vendor/github.com/go-git/go-git/v5/plumbing/format/index/encoder.go @@ -3,8 +3,11 @@ package index import ( "bytes" "errors" + "fmt" "io" + "path" "sort" + "strings" "time" "github.com/go-git/go-git/v5/plumbing/hash" @@ -13,7 +16,7 @@ import ( var ( // EncodeVersionSupported is the range of supported index versions - EncodeVersionSupported uint32 = 3 + EncodeVersionSupported uint32 = 4 // ErrInvalidTimestamp is returned by Encode if a Index with a Entry with // negative timestamp values @@ -22,20 +25,25 @@ var ( // An Encoder writes an Index to an output stream. type Encoder struct { - w io.Writer - hash hash.Hash + w io.Writer + hash hash.Hash + lastEntry *Entry } // NewEncoder returns a new encoder that writes to w. func NewEncoder(w io.Writer) *Encoder { h := hash.New(hash.CryptoType) mw := io.MultiWriter(w, h) - return &Encoder{mw, h} + return &Encoder{mw, h, nil} } // Encode writes the Index to the stream of the encoder. func (e *Encoder) Encode(idx *Index) error { - // TODO: support v4 + return e.encode(idx, true) +} + +func (e *Encoder) encode(idx *Index, footer bool) error { + // TODO: support extensions if idx.Version > EncodeVersionSupported { return ErrUnsupportedVersion @@ -49,7 +57,10 @@ func (e *Encoder) Encode(idx *Index) error { return err } - return e.encodeFooter() + if footer { + return e.encodeFooter() + } + return nil } func (e *Encoder) encodeHeader(idx *Index) error { @@ -64,7 +75,7 @@ func (e *Encoder) encodeEntries(idx *Index) error { sort.Sort(byName(idx.Entries)) for _, entry := range idx.Entries { - if err := e.encodeEntry(entry); err != nil { + if err := e.encodeEntry(idx, entry); err != nil { return err } entryLength := entryHeaderLength @@ -73,7 +84,7 @@ func (e *Encoder) encodeEntries(idx *Index) error { } wrote := entryLength + len(entry.Name) - if err := e.padEntry(wrote); err != nil { + if err := e.padEntry(idx, wrote); err != nil { return err } } @@ -81,7 +92,7 @@ func (e *Encoder) encodeEntries(idx *Index) error { return nil } -func (e *Encoder) encodeEntry(entry *Entry) error { +func (e *Encoder) encodeEntry(idx *Index, entry *Entry) error { sec, nsec, err := e.timeToUint32(&entry.CreatedAt) if err != nil { return err @@ -132,9 +143,68 @@ func (e *Encoder) encodeEntry(entry *Entry) error { return err } + switch idx.Version { + case 2, 3: + err = e.encodeEntryName(entry) + case 4: + err = e.encodeEntryNameV4(entry) + default: + err = ErrUnsupportedVersion + } + + return err +} + +func (e *Encoder) encodeEntryName(entry *Entry) error { return binary.Write(e.w, []byte(entry.Name)) } +func (e *Encoder) encodeEntryNameV4(entry *Entry) error { + name := entry.Name + l := 0 + if e.lastEntry != nil { + dir := path.Dir(e.lastEntry.Name) + "/" + if strings.HasPrefix(entry.Name, dir) { + l = len(e.lastEntry.Name) - len(dir) + name = strings.TrimPrefix(entry.Name, dir) + } else { + l = len(e.lastEntry.Name) + } + } + + e.lastEntry = entry + + err := binary.WriteVariableWidthInt(e.w, int64(l)) + if err != nil { + return err + } + + return binary.Write(e.w, []byte(name+string('\x00'))) +} + +func (e *Encoder) encodeRawExtension(signature string, data []byte) error { + if len(signature) != 4 { + return fmt.Errorf("invalid signature length") + } + + _, err := e.w.Write([]byte(signature)) + if err != nil { + return err + } + + err = binary.WriteUint32(e.w, uint32(len(data))) + if err != nil { + return err + } + + _, err = e.w.Write(data) + if err != nil { + return err + } + + return nil +} + func (e *Encoder) timeToUint32(t *time.Time) (uint32, uint32, error) { if t.IsZero() { return 0, 0, nil @@ -147,7 +217,11 @@ func (e *Encoder) timeToUint32(t *time.Time) (uint32, uint32, error) { return uint32(t.Unix()), uint32(t.Nanosecond()), nil } -func (e *Encoder) padEntry(wrote int) error { +func (e *Encoder) padEntry(idx *Index, wrote int) error { + if idx.Version == 4 { + return nil + } + padLen := 8 - wrote%8 _, err := e.w.Write(bytes.Repeat([]byte{'\x00'}, padLen)) diff --git a/vendor/github.com/go-git/go-git/v5/plumbing/format/packfile/delta_index.go b/vendor/github.com/go-git/go-git/v5/plumbing/format/packfile/delta_index.go index 07a61120e5..a60ec0b24d 100644 --- a/vendor/github.com/go-git/go-git/v5/plumbing/format/packfile/delta_index.go +++ b/vendor/github.com/go-git/go-git/v5/plumbing/format/packfile/delta_index.go @@ -32,19 +32,17 @@ func (idx *deltaIndex) findMatch(src, tgt []byte, tgtOffset int) (srcOffset, l i return 0, -1 } - if len(tgt) >= tgtOffset+s && len(src) >= blksz { - h := hashBlock(tgt, tgtOffset) - tIdx := h & idx.mask - eIdx := idx.table[tIdx] - if eIdx != 0 { - srcOffset = idx.entries[eIdx] - } else { - return - } - - l = matchLength(src, tgt, tgtOffset, srcOffset) + h := hashBlock(tgt, tgtOffset) + tIdx := h & idx.mask + eIdx := idx.table[tIdx] + if eIdx == 0 { + return } + srcOffset = idx.entries[eIdx] + + l = matchLength(src, tgt, tgtOffset, srcOffset) + return } diff --git a/vendor/github.com/go-git/go-git/v5/plumbing/format/packfile/patch_delta.go b/vendor/github.com/go-git/go-git/v5/plumbing/format/packfile/patch_delta.go index 960769c7c8..a9c6b9b56f 100644 --- a/vendor/github.com/go-git/go-git/v5/plumbing/format/packfile/patch_delta.go +++ b/vendor/github.com/go-git/go-git/v5/plumbing/format/packfile/patch_delta.go @@ -26,6 +26,13 @@ var ( const ( payload = 0x7f // 0111 1111 continuation = 0x80 // 1000 0000 + + // maxPatchPreemptionSize defines what is the max size of bytes to be + // premptively made available for a patch operation. + maxPatchPreemptionSize uint = 65536 + + // minDeltaSize defines the smallest size for a delta. + minDeltaSize = 4 ) type offset struct { @@ -86,9 +93,13 @@ func ApplyDelta(target, base plumbing.EncodedObject, delta []byte) (err error) { } // PatchDelta returns the result of applying the modification deltas in delta to src. -// An error will be returned if delta is corrupted (ErrDeltaLen) or an action command +// An error will be returned if delta is corrupted (ErrInvalidDelta) or an action command // is not copy from source or copy from delta (ErrDeltaCmd). func PatchDelta(src, delta []byte) ([]byte, error) { + if len(src) == 0 || len(delta) < minDeltaSize { + return nil, ErrInvalidDelta + } + b := &bytes.Buffer{} if err := patchDelta(b, src, delta); err != nil { return nil, err @@ -239,7 +250,9 @@ func patchDelta(dst *bytes.Buffer, src, delta []byte) error { remainingTargetSz := targetSz var cmd byte - dst.Grow(int(targetSz)) + + growSz := min(targetSz, maxPatchPreemptionSize) + dst.Grow(int(growSz)) for { if len(delta) == 0 { return ErrInvalidDelta @@ -403,6 +416,10 @@ func patchDeltaWriter(dst io.Writer, base io.ReaderAt, delta io.Reader, // This must be called twice on the delta data buffer, first to get the // expected source buffer size, and again to get the target buffer size. func decodeLEB128(input []byte) (uint, []byte) { + if len(input) == 0 { + return 0, input + } + var num, sz uint var b byte for { diff --git a/vendor/github.com/go-git/go-git/v5/plumbing/format/pktline/scanner.go b/vendor/github.com/go-git/go-git/v5/plumbing/format/pktline/scanner.go index fbb137de06..706d984ee0 100644 --- a/vendor/github.com/go-git/go-git/v5/plumbing/format/pktline/scanner.go +++ b/vendor/github.com/go-git/go-git/v5/plumbing/format/pktline/scanner.go @@ -140,6 +140,8 @@ func asciiHexToByte(b byte) (byte, error) { return b - '0', nil case b >= 'a' && b <= 'f': return b - 'a' + 10, nil + case b >= 'A' && b <= 'F': + return b - 'A' + 10, nil default: return 0, ErrInvalidPktLen } diff --git a/vendor/github.com/go-git/go-git/v5/plumbing/object/signature.go b/vendor/github.com/go-git/go-git/v5/plumbing/object/signature.go index 91cf371f0c..f9c3d306bd 100644 --- a/vendor/github.com/go-git/go-git/v5/plumbing/object/signature.go +++ b/vendor/github.com/go-git/go-git/v5/plumbing/object/signature.go @@ -19,6 +19,7 @@ var ( // a PKCS#7 (S/MIME) signature. x509SignatureFormat = signatureFormat{ []byte("-----BEGIN CERTIFICATE-----"), + []byte("-----BEGIN SIGNED MESSAGE-----"), } // sshSignatureFormat is the format of an SSH signature. diff --git a/vendor/github.com/go-git/go-git/v5/plumbing/object/tree.go b/vendor/github.com/go-git/go-git/v5/plumbing/object/tree.go index 0fd0e51398..2e1b789156 100644 --- a/vendor/github.com/go-git/go-git/v5/plumbing/object/tree.go +++ b/vendor/github.com/go-git/go-git/v5/plumbing/object/tree.go @@ -295,6 +295,7 @@ func (s TreeEntrySorter) Swap(i, j int) { } // Encode transforms a Tree into a plumbing.EncodedObject. +// The tree entries must be sorted by name. func (t *Tree) Encode(o plumbing.EncodedObject) (err error) { o.SetType(plumbing.TreeObject) w, err := o.Writer() diff --git a/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/filter.go b/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/filter.go new file mode 100644 index 0000000000..145fc711ca --- /dev/null +++ b/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/filter.go @@ -0,0 +1,76 @@ +package packp + +import ( + "errors" + "fmt" + "github.com/go-git/go-git/v5/plumbing" + "net/url" + "strings" +) + +var ErrUnsupportedObjectFilterType = errors.New("unsupported object filter type") + +// Filter values enable the partial clone capability which causes +// the server to omit objects that match the filter. +// +// See [Git's documentation] for more details. +// +// [Git's documentation]: https://github.com/git/git/blob/e02ecfcc534e2021aae29077a958dd11c3897e4c/Documentation/rev-list-options.txt#L948 +type Filter string + +type BlobLimitPrefix string + +const ( + BlobLimitPrefixNone BlobLimitPrefix = "" + BlobLimitPrefixKibi BlobLimitPrefix = "k" + BlobLimitPrefixMebi BlobLimitPrefix = "m" + BlobLimitPrefixGibi BlobLimitPrefix = "g" +) + +// FilterBlobNone omits all blobs. +func FilterBlobNone() Filter { + return "blob:none" +} + +// FilterBlobLimit omits blobs of size at least n bytes (when prefix is +// BlobLimitPrefixNone), n kibibytes (when prefix is BlobLimitPrefixKibi), +// n mebibytes (when prefix is BlobLimitPrefixMebi) or n gibibytes (when +// prefix is BlobLimitPrefixGibi). n can be zero, in which case all blobs +// will be omitted. +func FilterBlobLimit(n uint64, prefix BlobLimitPrefix) Filter { + return Filter(fmt.Sprintf("blob:limit=%d%s", n, prefix)) +} + +// FilterTreeDepth omits all blobs and trees whose depth from the root tree +// is larger or equal to depth. +func FilterTreeDepth(depth uint64) Filter { + return Filter(fmt.Sprintf("tree:%d", depth)) +} + +// FilterObjectType omits all objects which are not of the requested type t. +// Supported types are TagObject, CommitObject, TreeObject and BlobObject. +func FilterObjectType(t plumbing.ObjectType) (Filter, error) { + switch t { + case plumbing.TagObject: + fallthrough + case plumbing.CommitObject: + fallthrough + case plumbing.TreeObject: + fallthrough + case plumbing.BlobObject: + return Filter(fmt.Sprintf("object:type=%s", t.String())), nil + default: + return "", fmt.Errorf("%w: %s", ErrUnsupportedObjectFilterType, t.String()) + } +} + +// FilterCombine combines multiple Filter values together. +func FilterCombine(filters ...Filter) Filter { + var escapedFilters []string + + for _, filter := range filters { + escapedFilters = append(escapedFilters, url.QueryEscape(string(filter))) + } + + return Filter(fmt.Sprintf("combine:%s", strings.Join(escapedFilters, "+"))) +} diff --git a/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/sideband/demux.go b/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/sideband/demux.go index 0116f962ef..01d95a3aba 100644 --- a/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/sideband/demux.go +++ b/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/sideband/demux.go @@ -114,7 +114,7 @@ func (d *Demuxer) nextPackData() ([]byte, error) { size := len(content) if size == 0 { - return nil, nil + return nil, io.EOF } else if size > d.max { return nil, ErrMaxPackedExceeded } diff --git a/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/srvresp.go b/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/srvresp.go index a9ddb538b2..d760ad6609 100644 --- a/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/srvresp.go +++ b/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/srvresp.go @@ -120,6 +120,9 @@ func (r *ServerResponse) decodeACKLine(line []byte) error { } sp := bytes.Index(line, []byte(" ")) + if sp+41 > len(line) { + return fmt.Errorf("malformed ACK %q", line) + } h := plumbing.NewHash(string(line[sp+1 : sp+41])) r.ACKs = append(r.ACKs, h) return nil diff --git a/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/ulreq.go b/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/ulreq.go index 344f8c7e3a..ef4e08a10a 100644 --- a/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/ulreq.go +++ b/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/ulreq.go @@ -17,6 +17,7 @@ type UploadRequest struct { Wants []plumbing.Hash Shallows []plumbing.Hash Depth Depth + Filter Filter } // Depth values stores the desired depth of the requested packfile: see diff --git a/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/ulreq_encode.go b/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/ulreq_encode.go index c451e23164..8b19c0f674 100644 --- a/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/ulreq_encode.go +++ b/vendor/github.com/go-git/go-git/v5/plumbing/protocol/packp/ulreq_encode.go @@ -132,6 +132,17 @@ func (e *ulReqEncoder) encodeDepth() stateFn { return nil } + return e.encodeFilter +} + +func (e *ulReqEncoder) encodeFilter() stateFn { + if filter := e.data.Filter; filter != "" { + if err := e.pe.Encodef("filter %s\n", filter); err != nil { + e.err = fmt.Errorf("encoding filter %s: %s", filter, err) + return nil + } + } + return e.encodeFlush } diff --git a/vendor/github.com/go-git/go-git/v5/plumbing/reference.go b/vendor/github.com/go-git/go-git/v5/plumbing/reference.go index ddba930292..4daa341649 100644 --- a/vendor/github.com/go-git/go-git/v5/plumbing/reference.go +++ b/vendor/github.com/go-git/go-git/v5/plumbing/reference.go @@ -188,7 +188,7 @@ func (r ReferenceName) Validate() error { isBranch := r.IsBranch() isTag := r.IsTag() - for _, part := range parts { + for i, part := range parts { // rule 6 if len(part) == 0 { return ErrInvalidReferenceName @@ -205,7 +205,7 @@ func (r ReferenceName) Validate() error { return ErrInvalidReferenceName } - if (isBranch || isTag) && strings.HasPrefix(part, "-") { // branches & tags can't start with - + if (isBranch || isTag) && strings.HasPrefix(part, "-") && (i == 2) { // branches & tags can't start with - return ErrInvalidReferenceName } } diff --git a/vendor/github.com/go-git/go-git/v5/plumbing/transport/common.go b/vendor/github.com/go-git/go-git/v5/plumbing/transport/common.go index b05437fbfc..fae1aa98ca 100644 --- a/vendor/github.com/go-git/go-git/v5/plumbing/transport/common.go +++ b/vendor/github.com/go-git/go-git/v5/plumbing/transport/common.go @@ -19,6 +19,7 @@ import ( "fmt" "io" "net/url" + "path/filepath" "strconv" "strings" @@ -295,7 +296,11 @@ func parseFile(endpoint string) (*Endpoint, bool) { return nil, false } - path := endpoint + path, err := filepath.Abs(endpoint) + if err != nil { + return nil, false + } + return &Endpoint{ Protocol: "file", Path: path, diff --git a/vendor/github.com/go-git/go-git/v5/plumbing/transport/file/client.go b/vendor/github.com/go-git/go-git/v5/plumbing/transport/file/client.go index 38714e2ad1..d921d0a5a4 100644 --- a/vendor/github.com/go-git/go-git/v5/plumbing/transport/file/client.go +++ b/vendor/github.com/go-git/go-git/v5/plumbing/transport/file/client.go @@ -7,6 +7,7 @@ import ( "io" "os" "path/filepath" + "runtime" "strings" "github.com/go-git/go-git/v5/plumbing/transport" @@ -95,7 +96,23 @@ func (r *runner) Command(cmd string, ep *transport.Endpoint, auth transport.Auth } } - return &command{cmd: execabs.Command(cmd, ep.Path)}, nil + return &command{cmd: execabs.Command(cmd, adjustPathForWindows(ep.Path))}, nil +} + +func isDriveLetter(c byte) bool { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') +} + +// On Windows, the path that results from a file: URL has a leading slash. This +// has to be removed if there's a drive letter +func adjustPathForWindows(p string) string { + if runtime.GOOS != "windows" { + return p + } + if len(p) >= 3 && p[0] == '/' && isDriveLetter(p[1]) && p[2] == ':' { + return p[1:] + } + return p } type command struct { diff --git a/vendor/github.com/go-git/go-git/v5/plumbing/transport/http/common.go b/vendor/github.com/go-git/go-git/v5/plumbing/transport/http/common.go index 1c4ceee68d..120008db1c 100644 --- a/vendor/github.com/go-git/go-git/v5/plumbing/transport/http/common.go +++ b/vendor/github.com/go-git/go-git/v5/plumbing/transport/http/common.go @@ -430,11 +430,11 @@ func NewErr(r *http.Response) error { switch r.StatusCode { case http.StatusUnauthorized: - return transport.ErrAuthenticationRequired + return fmt.Errorf("%w: %s", transport.ErrAuthenticationRequired, reason) case http.StatusForbidden: - return transport.ErrAuthorizationFailed + return fmt.Errorf("%w: %s", transport.ErrAuthorizationFailed, reason) case http.StatusNotFound: - return transport.ErrRepositoryNotFound + return fmt.Errorf("%w: %s", transport.ErrRepositoryNotFound, reason) } return plumbing.NewUnexpectedError(&Err{r, reason}) diff --git a/vendor/github.com/go-git/go-git/v5/plumbing/transport/server/loader.go b/vendor/github.com/go-git/go-git/v5/plumbing/transport/server/loader.go index e7e2b075e5..f03a91c6d1 100644 --- a/vendor/github.com/go-git/go-git/v5/plumbing/transport/server/loader.go +++ b/vendor/github.com/go-git/go-git/v5/plumbing/transport/server/loader.go @@ -40,8 +40,16 @@ func (l *fsLoader) Load(ep *transport.Endpoint) (storer.Storer, error) { return nil, err } - if _, err := fs.Stat("config"); err != nil { - return nil, transport.ErrRepositoryNotFound + var bare bool + if _, err := fs.Stat("config"); err == nil { + bare = true + } + + if !bare { + // do not use git.GitDirName due to import cycle + if _, err := fs.Stat(".git"); err != nil { + return nil, transport.ErrRepositoryNotFound + } } return filesystem.NewStorage(fs, cache.NewObjectLRUDefault()), nil diff --git a/vendor/github.com/go-git/go-git/v5/remote.go b/vendor/github.com/go-git/go-git/v5/remote.go index 7cc0db9b7d..e2c734e751 100644 --- a/vendor/github.com/go-git/go-git/v5/remote.go +++ b/vendor/github.com/go-git/go-git/v5/remote.go @@ -9,6 +9,7 @@ import ( "time" "github.com/go-git/go-billy/v5/osfs" + "github.com/go-git/go-git/v5/config" "github.com/go-git/go-git/v5/internal/url" "github.com/go-git/go-git/v5/plumbing" @@ -82,7 +83,7 @@ func (r *Remote) String() string { var fetch, push string if len(r.c.URLs) > 0 { fetch = r.c.URLs[0] - push = r.c.URLs[0] + push = r.c.URLs[len(r.c.URLs)-1] } return fmt.Sprintf("%s\t%s (fetch)\n%[1]s\t%[3]s (push)", r.c.Name, fetch, push) @@ -109,8 +110,8 @@ func (r *Remote) PushContext(ctx context.Context, o *PushOptions) (err error) { return fmt.Errorf("remote names don't match: %s != %s", o.RemoteName, r.c.Name) } - if o.RemoteURL == "" { - o.RemoteURL = r.c.URLs[0] + if o.RemoteURL == "" && len(r.c.URLs) > 0 { + o.RemoteURL = r.c.URLs[len(r.c.URLs)-1] } s, err := newSendPackSession(o.RemoteURL, o.Auth, o.InsecureSkipTLS, o.CABundle, o.ProxyOptions) @@ -491,7 +492,18 @@ func (r *Remote) fetch(ctx context.Context, o *FetchOptions) (sto storer.Referen } if !updated && !updatedPrune { - return remoteRefs, NoErrAlreadyUpToDate + // No references updated, but may have fetched new objects, check if we now have any of our wants + for _, hash := range req.Wants { + exists, _ := objectExists(r.s, hash) + if exists { + updated = true + break + } + } + + if !updated { + return remoteRefs, NoErrAlreadyUpToDate + } } return remoteRefs, nil @@ -878,17 +890,12 @@ func getHavesFromRef( return nil } - // No need to load the commit if we know the remote already - // has this hash. - if remoteRefs[h] { - haves[h] = true - return nil - } - commit, err := object.GetCommit(s, h) if err != nil { - // Ignore the error if this isn't a commit. - haves[ref.Hash()] = true + if !errors.Is(err, plumbing.ErrObjectNotFound) { + // Ignore the error if this isn't a commit. + haves[ref.Hash()] = true + } return nil } diff --git a/vendor/github.com/go-git/go-git/v5/repository.go b/vendor/github.com/go-git/go-git/v5/repository.go index a57c7141f8..200098e7a0 100644 --- a/vendor/github.com/go-git/go-git/v5/repository.go +++ b/vendor/github.com/go-git/go-git/v5/repository.go @@ -956,7 +956,7 @@ func (r *Repository) clone(ctx context.Context, o *CloneOptions) error { } if o.RecurseSubmodules != NoRecurseSubmodules { - if err := w.updateSubmodules(&SubmoduleUpdateOptions{ + if err := w.updateSubmodules(ctx, &SubmoduleUpdateOptions{ RecurseSubmodules: o.RecurseSubmodules, Depth: func() int { if o.ShallowSubmodules { @@ -1037,7 +1037,7 @@ func (r *Repository) setIsBare(isBare bool) error { return r.Storer.SetConfig(cfg) } -func (r *Repository) updateRemoteConfigIfNeeded(o *CloneOptions, c *config.RemoteConfig, head *plumbing.Reference) error { +func (r *Repository) updateRemoteConfigIfNeeded(o *CloneOptions, c *config.RemoteConfig, _ *plumbing.Reference) error { if !o.SingleBranch { return nil } diff --git a/vendor/github.com/go-git/go-git/v5/status.go b/vendor/github.com/go-git/go-git/v5/status.go index 7f18e02278..d14f7e6572 100644 --- a/vendor/github.com/go-git/go-git/v5/status.go +++ b/vendor/github.com/go-git/go-git/v5/status.go @@ -4,6 +4,9 @@ import ( "bytes" "fmt" "path/filepath" + + mindex "github.com/go-git/go-git/v5/utils/merkletrie/index" + "github.com/go-git/go-git/v5/utils/merkletrie/noder" ) // Status represents the current status of a Worktree. @@ -77,3 +80,69 @@ const ( Copied StatusCode = 'C' UpdatedButUnmerged StatusCode = 'U' ) + +// StatusStrategy defines the different types of strategies when processing +// the worktree status. +type StatusStrategy int + +const ( + // TODO: (V6) Review the default status strategy. + // TODO: (V6) Review the type used to represent Status, to enable lazy + // processing of statuses going direct to the backing filesystem. + defaultStatusStrategy = Empty + + // Empty starts its status map from empty. Missing entries for a given + // path means that the file is untracked. This causes a known issue (#119) + // whereby unmodified files can be incorrectly reported as untracked. + // + // This can be used when returning the changed state within a modified Worktree. + // For example, to check whether the current worktree is clean. + Empty StatusStrategy = 0 + // Preload goes through all existing nodes from the index and add them to the + // status map as unmodified. This is currently the most reliable strategy + // although it comes at a performance cost in large repositories. + // + // This method is recommended when fetching the status of unmodified files. + // For example, to confirm the status of a specific file that is either + // untracked or unmodified. + Preload StatusStrategy = 1 +) + +func (s StatusStrategy) new(w *Worktree) (Status, error) { + switch s { + case Preload: + return preloadStatus(w) + case Empty: + return make(Status), nil + } + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedStatusStrategy, s) +} + +func preloadStatus(w *Worktree) (Status, error) { + idx, err := w.r.Storer.Index() + if err != nil { + return nil, err + } + + idxRoot := mindex.NewRootNode(idx) + nodes := []noder.Noder{idxRoot} + + status := make(Status) + for len(nodes) > 0 { + var node noder.Noder + node, nodes = nodes[0], nodes[1:] + if node.IsDir() { + children, err := node.Children() + if err != nil { + return nil, err + } + nodes = append(nodes, children...) + continue + } + fs := status.File(node.Name()) + fs.Worktree = Unmodified + fs.Staging = Unmodified + } + + return status, nil +} diff --git a/vendor/github.com/go-git/go-git/v5/storage/filesystem/dotgit/dotgit.go b/vendor/github.com/go-git/go-git/v5/storage/filesystem/dotgit/dotgit.go index 31c4694816..72c9ccfc14 100644 --- a/vendor/github.com/go-git/go-git/v5/storage/filesystem/dotgit/dotgit.go +++ b/vendor/github.com/go-git/go-git/v5/storage/filesystem/dotgit/dotgit.go @@ -72,6 +72,9 @@ var ( // ErrIsDir is returned when a reference file is attempting to be read, // but the path specified is a directory. ErrIsDir = errors.New("reference path is a directory") + // ErrEmptyRefFile is returned when a reference file is attempted to be read, + // but the file is empty + ErrEmptyRefFile = errors.New("ref file is empty") ) // Options holds configuration for the storage. @@ -249,7 +252,7 @@ func (d *DotGit) objectPacks() ([]plumbing.Hash, error) { continue } - h := plumbing.NewHash(n[5 : len(n)-5]) //pack-(hash).pack + h := plumbing.NewHash(n[5 : len(n)-5]) // pack-(hash).pack if h.IsZero() { // Ignore files with badly-formatted names. continue @@ -661,18 +664,33 @@ func (d *DotGit) readReferenceFrom(rd io.Reader, name string) (ref *plumbing.Ref return nil, err } + if len(b) == 0 { + return nil, ErrEmptyRefFile + } + line := strings.TrimSpace(string(b)) return plumbing.NewReferenceFromStrings(name, line), nil } +// checkReferenceAndTruncate reads the reference from the given file, or the `pack-refs` file if +// the file was empty. Then it checks that the old reference matches the stored reference and +// truncates the file. func (d *DotGit) checkReferenceAndTruncate(f billy.File, old *plumbing.Reference) error { if old == nil { return nil } + ref, err := d.readReferenceFrom(f, old.Name().String()) + if errors.Is(err, ErrEmptyRefFile) { + // This may happen if the reference is being read from a newly created file. + // In that case, try getting the reference from the packed refs file. + ref, err = d.packedRef(old.Name()) + } + if err != nil { return err } + if ref.Hash() != old.Hash() { return storage.ErrReferenceHasChanged } @@ -701,16 +719,16 @@ func (d *DotGit) SetRef(r, old *plumbing.Reference) error { // Symbolic references are resolved and included in the output. func (d *DotGit) Refs() ([]*plumbing.Reference, error) { var refs []*plumbing.Reference - var seen = make(map[plumbing.ReferenceName]bool) - if err := d.addRefsFromRefDir(&refs, seen); err != nil { + seen := make(map[plumbing.ReferenceName]bool) + if err := d.addRefFromHEAD(&refs); err != nil { return nil, err } - if err := d.addRefsFromPackedRefs(&refs, seen); err != nil { + if err := d.addRefsFromRefDir(&refs, seen); err != nil { return nil, err } - if err := d.addRefFromHEAD(&refs); err != nil { + if err := d.addRefsFromPackedRefs(&refs, seen); err != nil { return nil, err } @@ -815,7 +833,8 @@ func (d *DotGit) addRefsFromPackedRefsFile(refs *[]*plumbing.Reference, f billy. } func (d *DotGit) openAndLockPackedRefs(doCreate bool) ( - pr billy.File, err error) { + pr billy.File, err error, +) { var f billy.File defer func() { if err != nil && f != nil { @@ -1020,7 +1039,7 @@ func (d *DotGit) readReferenceFile(path, name string) (ref *plumbing.Reference, func (d *DotGit) CountLooseRefs() (int, error) { var refs []*plumbing.Reference - var seen = make(map[plumbing.ReferenceName]bool) + seen := make(map[plumbing.ReferenceName]bool) if err := d.addRefsFromRefDir(&refs, seen); err != nil { return 0, err } diff --git a/vendor/github.com/go-git/go-git/v5/storage/filesystem/index.go b/vendor/github.com/go-git/go-git/v5/storage/filesystem/index.go index a19176f83d..a86ef3e2e5 100644 --- a/vendor/github.com/go-git/go-git/v5/storage/filesystem/index.go +++ b/vendor/github.com/go-git/go-git/v5/storage/filesystem/index.go @@ -48,7 +48,7 @@ func (s *IndexStorage) Index() (i *index.Index, err error) { defer ioutil.CheckClose(f, &err) - d := index.NewDecoder(bufio.NewReader(f)) + d := index.NewDecoder(f) err = d.Decode(idx) return idx, err } diff --git a/vendor/github.com/go-git/go-git/v5/storage/filesystem/object.go b/vendor/github.com/go-git/go-git/v5/storage/filesystem/object.go index e812fe934d..91b4aceae1 100644 --- a/vendor/github.com/go-git/go-git/v5/storage/filesystem/object.go +++ b/vendor/github.com/go-git/go-git/v5/storage/filesystem/object.go @@ -431,13 +431,13 @@ func (s *ObjectStorage) getFromUnpacked(h plumbing.Hash) (obj plumbing.EncodedOb defer ioutil.CheckClose(w, &err) - s.objectCache.Put(obj) - bufp := copyBufferPool.Get().(*[]byte) buf := *bufp _, err = io.CopyBuffer(w, r, buf) copyBufferPool.Put(bufp) + s.objectCache.Put(obj) + return obj, err } diff --git a/vendor/github.com/go-git/go-git/v5/submodule.go b/vendor/github.com/go-git/go-git/v5/submodule.go index 84f020dc72..afabb6acad 100644 --- a/vendor/github.com/go-git/go-git/v5/submodule.go +++ b/vendor/github.com/go-git/go-git/v5/submodule.go @@ -214,10 +214,10 @@ func (s *Submodule) update(ctx context.Context, o *SubmoduleUpdateOptions, force return err } - return s.doRecursiveUpdate(r, o) + return s.doRecursiveUpdate(ctx, r, o) } -func (s *Submodule) doRecursiveUpdate(r *Repository, o *SubmoduleUpdateOptions) error { +func (s *Submodule) doRecursiveUpdate(ctx context.Context, r *Repository, o *SubmoduleUpdateOptions) error { if o.RecurseSubmodules == NoRecurseSubmodules { return nil } @@ -236,7 +236,7 @@ func (s *Submodule) doRecursiveUpdate(r *Repository, o *SubmoduleUpdateOptions) *new = *o new.RecurseSubmodules-- - return l.Update(new) + return l.UpdateContext(ctx, new) } func (s *Submodule) fetchAndCheckout( diff --git a/vendor/github.com/go-git/go-git/v5/utils/merkletrie/change.go b/vendor/github.com/go-git/go-git/v5/utils/merkletrie/change.go index cc6dc89071..450feb4bac 100644 --- a/vendor/github.com/go-git/go-git/v5/utils/merkletrie/change.go +++ b/vendor/github.com/go-git/go-git/v5/utils/merkletrie/change.go @@ -1,12 +1,17 @@ package merkletrie import ( + "errors" "fmt" "io" "github.com/go-git/go-git/v5/utils/merkletrie/noder" ) +var ( + ErrEmptyFileName = errors.New("empty filename in tree entry") +) + // Action values represent the kind of things a Change can represent: // insertion, deletions or modifications of files. type Action int @@ -121,6 +126,10 @@ func (l *Changes) AddRecursiveDelete(root noder.Path) error { type noderToChangeFn func(noder.Path) Change // NewInsert or NewDelete func (l *Changes) addRecursive(root noder.Path, ctor noderToChangeFn) error { + if root.String() == "" { + return ErrEmptyFileName + } + if !root.IsDir() { l.Add(ctor(root)) return nil diff --git a/vendor/github.com/go-git/go-git/v5/utils/merkletrie/difftree.go b/vendor/github.com/go-git/go-git/v5/utils/merkletrie/difftree.go index 8090942ddb..4ef2d9907a 100644 --- a/vendor/github.com/go-git/go-git/v5/utils/merkletrie/difftree.go +++ b/vendor/github.com/go-git/go-git/v5/utils/merkletrie/difftree.go @@ -11,7 +11,7 @@ package merkletrie // corresponding changes and move the iterators further over both // trees. // -// The table bellow show all the possible comparison results, along +// The table below shows all the possible comparison results, along // with what changes should we produce and how to advance the // iterators. // diff --git a/vendor/github.com/go-git/go-git/v5/utils/sync/bufio.go b/vendor/github.com/go-git/go-git/v5/utils/sync/bufio.go index 5009ea8047..42f60f7ea1 100644 --- a/vendor/github.com/go-git/go-git/v5/utils/sync/bufio.go +++ b/vendor/github.com/go-git/go-git/v5/utils/sync/bufio.go @@ -13,7 +13,7 @@ var bufioReader = sync.Pool{ } // GetBufioReader returns a *bufio.Reader that is managed by a sync.Pool. -// Returns a bufio.Reader that is resetted with reader and ready for use. +// Returns a bufio.Reader that is reset with reader and ready for use. // // After use, the *bufio.Reader should be put back into the sync.Pool // by calling PutBufioReader. diff --git a/vendor/github.com/go-git/go-git/v5/utils/sync/bytes.go b/vendor/github.com/go-git/go-git/v5/utils/sync/bytes.go index dd06fc0bc6..c67b978375 100644 --- a/vendor/github.com/go-git/go-git/v5/utils/sync/bytes.go +++ b/vendor/github.com/go-git/go-git/v5/utils/sync/bytes.go @@ -35,7 +35,7 @@ func PutByteSlice(buf *[]byte) { } // GetBytesBuffer returns a *bytes.Buffer that is managed by a sync.Pool. -// Returns a buffer that is resetted and ready for use. +// Returns a buffer that is reset and ready for use. // // After use, the *bytes.Buffer should be put back into the sync.Pool // by calling PutBytesBuffer. diff --git a/vendor/github.com/go-git/go-git/v5/utils/sync/zlib.go b/vendor/github.com/go-git/go-git/v5/utils/sync/zlib.go index c613885957..edf674d852 100644 --- a/vendor/github.com/go-git/go-git/v5/utils/sync/zlib.go +++ b/vendor/github.com/go-git/go-git/v5/utils/sync/zlib.go @@ -35,7 +35,7 @@ type ZLibReader struct { } // GetZlibReader returns a ZLibReader that is managed by a sync.Pool. -// Returns a ZLibReader that is resetted using a dictionary that is +// Returns a ZLibReader that is reset using a dictionary that is // also managed by a sync.Pool. // // After use, the ZLibReader should be put back into the sync.Pool @@ -58,7 +58,7 @@ func PutZlibReader(z ZLibReader) { } // GetZlibWriter returns a *zlib.Writer that is managed by a sync.Pool. -// Returns a writer that is resetted with w and ready for use. +// Returns a writer that is reset with w and ready for use. // // After use, the *zlib.Writer should be put back into the sync.Pool // by calling PutZlibWriter. diff --git a/vendor/github.com/go-git/go-git/v5/worktree.go b/vendor/github.com/go-git/go-git/v5/worktree.go index ab11d42db8..8dfa50b1b3 100644 --- a/vendor/github.com/go-git/go-git/v5/worktree.go +++ b/vendor/github.com/go-git/go-git/v5/worktree.go @@ -25,11 +25,12 @@ import ( ) var ( - ErrWorktreeNotClean = errors.New("worktree is not clean") - ErrSubmoduleNotFound = errors.New("submodule not found") - ErrUnstagedChanges = errors.New("worktree contains unstaged changes") - ErrGitModulesSymlink = errors.New(gitmodulesFile + " is a symlink") - ErrNonFastForwardUpdate = errors.New("non-fast-forward update") + ErrWorktreeNotClean = errors.New("worktree is not clean") + ErrSubmoduleNotFound = errors.New("submodule not found") + ErrUnstagedChanges = errors.New("worktree contains unstaged changes") + ErrGitModulesSymlink = errors.New(gitmodulesFile + " is a symlink") + ErrNonFastForwardUpdate = errors.New("non-fast-forward update") + ErrRestoreWorktreeOnlyNotSupported = errors.New("worktree only is not supported") ) // Worktree represents a git worktree. @@ -139,7 +140,7 @@ func (w *Worktree) PullContext(ctx context.Context, o *PullOptions) error { } if o.RecurseSubmodules != NoRecurseSubmodules { - return w.updateSubmodules(&SubmoduleUpdateOptions{ + return w.updateSubmodules(ctx, &SubmoduleUpdateOptions{ RecurseSubmodules: o.RecurseSubmodules, Auth: o.Auth, }) @@ -148,13 +149,13 @@ func (w *Worktree) PullContext(ctx context.Context, o *PullOptions) error { return nil } -func (w *Worktree) updateSubmodules(o *SubmoduleUpdateOptions) error { +func (w *Worktree) updateSubmodules(ctx context.Context, o *SubmoduleUpdateOptions) error { s, err := w.Submodules() if err != nil { return err } o.Init = true - return s.Update(o) + return s.UpdateContext(ctx, o) } // Checkout switch branches or restore working tree files. @@ -307,13 +308,13 @@ func (w *Worktree) ResetSparsely(opts *ResetOptions, dirs []string) error { } if opts.Mode == MixedReset || opts.Mode == MergeReset || opts.Mode == HardReset { - if err := w.resetIndex(t, dirs); err != nil { + if err := w.resetIndex(t, dirs, opts.Files); err != nil { return err } } if opts.Mode == MergeReset || opts.Mode == HardReset { - if err := w.resetWorktree(t); err != nil { + if err := w.resetWorktree(t, opts.Files); err != nil { return err } } @@ -321,20 +322,52 @@ func (w *Worktree) ResetSparsely(opts *ResetOptions, dirs []string) error { return nil } +// Restore restores specified files in the working tree or stage with contents from +// a restore source. If a path is tracked but does not exist in the restore, +// source, it will be removed to match the source. +// +// If Staged and Worktree are true, then the restore source will be the index. +// If only Staged is true, then the restore source will be HEAD. +// If only Worktree is true or neither Staged nor Worktree are true, will +// result in ErrRestoreWorktreeOnlyNotSupported because restoring the working +// tree while leaving the stage untouched is not currently supported. +// +// Restore with no files specified will return ErrNoRestorePaths. +func (w *Worktree) Restore(o *RestoreOptions) error { + if err := o.Validate(); err != nil { + return err + } + + if o.Staged { + opts := &ResetOptions{ + Files: o.Files, + } + + if o.Worktree { + // If we are doing both Worktree and Staging then it is a hard reset + opts.Mode = HardReset + } else { + // If we are doing just staging then it is a mixed reset + opts.Mode = MixedReset + } + + return w.Reset(opts) + } + + return ErrRestoreWorktreeOnlyNotSupported +} + // Reset the worktree to a specified state. func (w *Worktree) Reset(opts *ResetOptions) error { return w.ResetSparsely(opts, nil) } -func (w *Worktree) resetIndex(t *object.Tree, dirs []string) error { +func (w *Worktree) resetIndex(t *object.Tree, dirs []string, files []string) error { idx, err := w.r.Storer.Index() - if len(dirs) > 0 { - idx.SkipUnless(dirs) - } - if err != nil { return err } + b := newIndexBuilder(idx) changes, err := w.diffTreeWithStaging(t, true) @@ -362,6 +395,13 @@ func (w *Worktree) resetIndex(t *object.Tree, dirs []string) error { name = ch.From.String() } + if len(files) > 0 { + contains := inFiles(files, name) + if !contains { + continue + } + } + b.Remove(name) if e == nil { continue @@ -376,10 +416,25 @@ func (w *Worktree) resetIndex(t *object.Tree, dirs []string) error { } b.Write(idx) + + if len(dirs) > 0 { + idx.SkipUnless(dirs) + } + return w.r.Storer.SetIndex(idx) } -func (w *Worktree) resetWorktree(t *object.Tree) error { +func inFiles(files []string, v string) bool { + for _, s := range files { + if s == v { + return true + } + } + + return false +} + +func (w *Worktree) resetWorktree(t *object.Tree, files []string) error { changes, err := w.diffStagingWithWorktree(true, false) if err != nil { return err @@ -395,6 +450,25 @@ func (w *Worktree) resetWorktree(t *object.Tree) error { if err := w.validChange(ch); err != nil { return err } + + if len(files) > 0 { + file := "" + if ch.From != nil { + file = ch.From.String() + } else if ch.To != nil { + file = ch.To.String() + } + + if file == "" { + continue + } + + contains := inFiles(files, file) + if !contains { + continue + } + } + if err := w.checkoutChange(ch, t, b); err != nil { return err } @@ -642,7 +716,7 @@ func (w *Worktree) checkoutChangeRegularFile(name string, return err } - return w.addIndexFromFile(name, e.Hash, idx) + return w.addIndexFromFile(name, e.Hash, f.Mode, idx) } return nil @@ -725,18 +799,13 @@ func (w *Worktree) addIndexFromTreeEntry(name string, f *object.TreeEntry, idx * return nil } -func (w *Worktree) addIndexFromFile(name string, h plumbing.Hash, idx *indexBuilder) error { +func (w *Worktree) addIndexFromFile(name string, h plumbing.Hash, mode filemode.FileMode, idx *indexBuilder) error { idx.Remove(name) fi, err := w.Filesystem.Lstat(name) if err != nil { return err } - mode, err := filemode.NewFromOSFileMode(fi.Mode()) - if err != nil { - return err - } - e := &index.Entry{ Hash: h, Name: name, @@ -1058,7 +1127,7 @@ func rmFileAndDirsIfEmpty(fs billy.Filesystem, name string) error { dir := filepath.Dir(name) for { removed, err := removeDirIfEmpty(fs, dir) - if err != nil { + if err != nil && !os.IsNotExist(err) { return err } diff --git a/vendor/github.com/go-git/go-git/v5/worktree_commit.go b/vendor/github.com/go-git/go-git/v5/worktree_commit.go index f62054bcb4..9b1988ae6b 100644 --- a/vendor/github.com/go-git/go-git/v5/worktree_commit.go +++ b/vendor/github.com/go-git/go-git/v5/worktree_commit.go @@ -5,6 +5,7 @@ import ( "errors" "io" "path" + "regexp" "sort" "strings" @@ -23,6 +24,10 @@ var ( // ErrEmptyCommit occurs when a commit is attempted using a clean // working tree, with no changes to be committed. ErrEmptyCommit = errors.New("cannot create empty commit: clean working tree") + + // characters to be removed from user name and/or email before using them to build a commit object + // See https://git-scm.com/docs/git-commit#_commit_information + invalidCharactersRe = regexp.MustCompile(`[<>\n]`) ) // Commit stores the current contents of the index in a new commit along with @@ -38,8 +43,6 @@ func (w *Worktree) Commit(msg string, opts *CommitOptions) (plumbing.Hash, error } } - var treeHash plumbing.Hash - if opts.Amend { head, err := w.r.Head() if err != nil { @@ -61,16 +64,34 @@ func (w *Worktree) Commit(msg string, opts *CommitOptions) (plumbing.Hash, error return plumbing.ZeroHash, err } + // First handle the case of the first commit in the repository being empty. + if len(opts.Parents) == 0 && len(idx.Entries) == 0 && !opts.AllowEmptyCommits { + return plumbing.ZeroHash, ErrEmptyCommit + } + h := &buildTreeHelper{ fs: w.Filesystem, s: w.r.Storer, } - treeHash, err = h.BuildTree(idx, opts) + treeHash, err := h.BuildTree(idx, opts) if err != nil { return plumbing.ZeroHash, err } + previousTree := plumbing.ZeroHash + if len(opts.Parents) > 0 { + parentCommit, err := w.r.CommitObject(opts.Parents[0]) + if err != nil { + return plumbing.ZeroHash, err + } + previousTree = parentCommit.TreeHash + } + + if treeHash == previousTree && !opts.AllowEmptyCommits { + return plumbing.ZeroHash, ErrEmptyCommit + } + commit, err := w.buildCommitObject(msg, opts, treeHash) if err != nil { return plumbing.ZeroHash, err @@ -121,8 +142,8 @@ func (w *Worktree) updateHEAD(commit plumbing.Hash) error { func (w *Worktree) buildCommitObject(msg string, opts *CommitOptions, tree plumbing.Hash) (plumbing.Hash, error) { commit := &object.Commit{ - Author: *opts.Author, - Committer: *opts.Committer, + Author: w.sanitize(*opts.Author), + Committer: w.sanitize(*opts.Committer), Message: msg, TreeHash: tree, ParentHashes: opts.Parents, @@ -148,6 +169,14 @@ func (w *Worktree) buildCommitObject(msg string, opts *CommitOptions, tree plumb return w.r.Storer.SetEncodedObject(obj) } +func (w *Worktree) sanitize(signature object.Signature) object.Signature { + return object.Signature{ + Name: invalidCharactersRe.ReplaceAllString(signature.Name, ""), + Email: invalidCharactersRe.ReplaceAllString(signature.Email, ""), + When: signature.When, + } +} + type gpgSigner struct { key *openpgp.Entity cfg *packet.Config @@ -175,10 +204,6 @@ type buildTreeHelper struct { // BuildTree builds the tree objects and push its to the storer, the hash // of the root tree is returned. func (h *buildTreeHelper) BuildTree(idx *index.Index, opts *CommitOptions) (plumbing.Hash, error) { - if len(idx.Entries) == 0 && (opts == nil || !opts.AllowEmptyCommits) { - return plumbing.ZeroHash, ErrEmptyCommit - } - const rootNode = "" h.trees = map[string]*object.Tree{rootNode: {}} h.entries = map[string]*object.TreeEntry{} diff --git a/vendor/github.com/go-git/go-git/v5/worktree_linux.go b/vendor/github.com/go-git/go-git/v5/worktree_linux.go index 6fcace2f93..f6b85fe3df 100644 --- a/vendor/github.com/go-git/go-git/v5/worktree_linux.go +++ b/vendor/github.com/go-git/go-git/v5/worktree_linux.go @@ -1,3 +1,4 @@ +//go:build linux // +build linux package git @@ -21,6 +22,6 @@ func init() { } } -func isSymlinkWindowsNonAdmin(err error) bool { +func isSymlinkWindowsNonAdmin(_ error) bool { return false } diff --git a/vendor/github.com/go-git/go-git/v5/worktree_status.go b/vendor/github.com/go-git/go-git/v5/worktree_status.go index dd9b2439cf..6e72db9744 100644 --- a/vendor/github.com/go-git/go-git/v5/worktree_status.go +++ b/vendor/github.com/go-git/go-git/v5/worktree_status.go @@ -29,10 +29,23 @@ var ( // ErrGlobNoMatches in an AddGlob if the glob pattern does not match any // files in the worktree. ErrGlobNoMatches = errors.New("glob pattern did not match any files") + // ErrUnsupportedStatusStrategy occurs when an invalid StatusStrategy is used + // when processing the Worktree status. + ErrUnsupportedStatusStrategy = errors.New("unsupported status strategy") ) // Status returns the working tree status. func (w *Worktree) Status() (Status, error) { + return w.StatusWithOptions(StatusOptions{Strategy: defaultStatusStrategy}) +} + +// StatusOptions defines the options for Worktree.StatusWithOptions(). +type StatusOptions struct { + Strategy StatusStrategy +} + +// StatusWithOptions returns the working tree status. +func (w *Worktree) StatusWithOptions(o StatusOptions) (Status, error) { var hash plumbing.Hash ref, err := w.r.Head() @@ -44,11 +57,14 @@ func (w *Worktree) Status() (Status, error) { hash = ref.Hash() } - return w.status(hash) + return w.status(o.Strategy, hash) } -func (w *Worktree) status(commit plumbing.Hash) (Status, error) { - s := make(Status) +func (w *Worktree) status(ss StatusStrategy, commit plumbing.Hash) (Status, error) { + s, err := ss.new(w) + if err != nil { + return nil, err + } left, err := w.diffCommitWithStaging(commit, false) if err != nil { @@ -488,7 +504,7 @@ func (w *Worktree) copyFileToStorage(path string) (hash plumbing.Hash, err error return w.r.Storer.SetEncodedObject(obj) } -func (w *Worktree) fillEncodedObjectFromFile(dst io.Writer, path string, fi os.FileInfo) (err error) { +func (w *Worktree) fillEncodedObjectFromFile(dst io.Writer, path string, _ os.FileInfo) (err error) { src, err := w.Filesystem.Open(path) if err != nil { return err @@ -503,7 +519,7 @@ func (w *Worktree) fillEncodedObjectFromFile(dst io.Writer, path string, fi os.F return err } -func (w *Worktree) fillEncodedObjectFromSymlink(dst io.Writer, path string, fi os.FileInfo) error { +func (w *Worktree) fillEncodedObjectFromSymlink(dst io.Writer, path string, _ os.FileInfo) error { target, err := w.Filesystem.Readlink(path) if err != nil { return err @@ -543,9 +559,11 @@ func (w *Worktree) doUpdateFileToIndex(e *index.Entry, filename string, h plumbi return err } - if e.Mode.IsRegular() { - e.Size = uint32(info.Size()) - } + // The entry size must always reflect the current state, otherwise + // it will cause go-git's Worktree.Status() to divert from "git status". + // The size of a symlink is the length of the path to the target. + // The size of Regular and Executable files is the size of the files. + e.Size = uint32(info.Size()) fillSystemInfo(e, info.Sys()) return nil diff --git a/vendor/github.com/skeema/knownhosts/CONTRIBUTING.md b/vendor/github.com/skeema/knownhosts/CONTRIBUTING.md new file mode 100644 index 0000000000..9624f82760 --- /dev/null +++ b/vendor/github.com/skeema/knownhosts/CONTRIBUTING.md @@ -0,0 +1,36 @@ +# Contributing to skeema/knownhosts + +Thank you for your interest in contributing! This document provides guidelines for submitting pull requests. + +### Link to an issue + +Before starting the pull request process, initial discussion should take place on a GitHub issue first. For bug reports, the issue should track the open bug and confirm it is reproducible. For feature requests, the issue should cover why the feature is necessary. + +In the issue comments, discuss your suggested approach for a fix/implementation, and please wait to get feedback before opening a pull request. + +### Test coverage + +In general, please provide reasonably thorough test coverage. Whenever possible, your PR should aim to match or improve the overall test coverage percentage of the package. You can run tests and check coverage locally using `go test -cover`. We also have CI automation in GitHub Actions which will comment on each pull request with a coverage percentage. + +That said, it is fine to submit an initial draft / work-in-progress PR without coverage, if you are waiting on implementation feedback before writing the tests. + +We intentionally avoid hard-coding SSH keys or known_hosts files into the test logic. Instead, the tests generate new keys and then use them to generate a known_hosts file, which is then cached/reused for that overall test run, in order to keep performance reasonable. + +### Documentation + +Exported types require doc comments. The linter CI step will catch this if missing. + +### Backwards compatibility + +Because this package is imported by [nearly 7000 repos on GitHub](https://github.com/skeema/knownhosts/network/dependents), we must be very strict about backwards compatibility of exported symbols and function signatures. + +Backwards compatibility can be very tricky in some situations. In this case, a maintainer may need to add additional commits to your branch to adjust the approach. Please do not take offense if this occurs; it is sometimes simply faster to implement a refactor on our end directly. When the PR/branch is merged, a merge commit will be used, to ensure your commits appear as-is in the repo history and are still properly credited to you. + +### Avoid rewriting core x/crypto/ssh/knownhosts logic + +skeema/knownhosts is intended to be a relatively thin *wrapper* around x/crypto/ssh/knownhosts, without duplicating or re-implementing the core known_hosts file parsing and host key handling logic. Importers of this package should be confident that it can be used as a nearly-drop-in replacement for x/crypto/ssh/knownhosts without introducing substantial risk, security flaws, parser differentials, or unexpected behavior changes. + +To solve shortcomings in x/crypto/ssh/knownhosts, we try to come up with workarounds that still utilize x/crypto/ssh/knownhosts functionality whenever possible. + +Some bugs in x/crypto/ssh/knownhosts do require re-reading the known_hosts file here to solve, but we make that *optional* by offering separate constructors/types with and without that behavior. + diff --git a/vendor/github.com/skeema/knownhosts/README.md b/vendor/github.com/skeema/knownhosts/README.md index 36b847614c..046bc0edcb 100644 --- a/vendor/github.com/skeema/knownhosts/README.md +++ b/vendor/github.com/skeema/knownhosts/README.md @@ -1,31 +1,33 @@ # knownhosts: enhanced Golang SSH known_hosts management [![build status](https://img.shields.io/github/actions/workflow/status/skeema/knownhosts/tests.yml?branch=main)](https://github.com/skeema/knownhosts/actions) +[![code coverage](https://img.shields.io/coveralls/skeema/knownhosts.svg)](https://coveralls.io/r/skeema/knownhosts) [![godoc](https://img.shields.io/badge/godoc-reference-blue.svg)](https://pkg.go.dev/github.com/skeema/knownhosts) > This repo is brought to you by [Skeema](https://github.com/skeema/skeema), a > declarative pure-SQL schema management system for MySQL and MariaDB. Our -> premium products include extensive [SSH tunnel](https://www.skeema.io/docs/options/#ssh) +> premium products include extensive [SSH tunnel](https://www.skeema.io/docs/features/ssh/) > functionality, which internally makes use of this package. Go provides excellent functionality for OpenSSH known_hosts files in its external package [golang.org/x/crypto/ssh/knownhosts](https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts). -However, that package is somewhat low-level, making it difficult to implement full known_hosts management similar to command-line `ssh`'s behavior for `StrictHostKeyChecking=no` configuration. +However, that package is somewhat low-level, making it difficult to implement full known_hosts management similar to OpenSSH's command-line behavior. Additionally, [golang.org/x/crypto/ssh/knownhosts](https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts) has several known issues in edge cases, some of which have remained open for multiple years. -This repo ([github.com/skeema/knownhosts](https://github.com/skeema/knownhosts)) is a thin wrapper package around [golang.org/x/crypto/ssh/knownhosts](https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts), adding the following functionality: +Package [github.com/skeema/knownhosts](https://github.com/skeema/knownhosts) provides a *thin wrapper* around [golang.org/x/crypto/ssh/knownhosts](https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts), adding the following improvements and fixes without duplicating its core logic: * Look up known_hosts public keys for any given host -* Auto-populate ssh.ClientConfig.HostKeyAlgorithms easily based on known_hosts, providing a solution for [golang/go#29286](https://github.com/golang/go/issues/29286) +* Auto-populate ssh.ClientConfig.HostKeyAlgorithms easily based on known_hosts, providing a solution for [golang/go#29286](https://github.com/golang/go/issues/29286). (This also properly handles cert algorithms for hosts using CA keys when [using the NewDB constructor](#enhancements-requiring-extra-parsing) added in skeema/knownhosts v1.3.0.) +* Properly match wildcard hostname known_hosts entries regardless of port number, providing a solution for [golang/go#52056](https://github.com/golang/go/issues/52056). (Added in v1.3.0; requires [using the NewDB constructor](#enhancements-requiring-extra-parsing)) * Write new known_hosts entries to an io.Writer * Properly format/normalize new known_hosts entries containing ipv6 addresses, providing a solution for [golang/go#53463](https://github.com/golang/go/issues/53463) -* Determine if an ssh.HostKeyCallback's error corresponds to a host whose key has changed (indicating potential MitM attack) vs a host that just isn't known yet +* Easily determine if an ssh.HostKeyCallback's error corresponds to a host whose key has changed (indicating potential MitM attack) vs a host that just isn't known yet ## How host key lookup works Although [golang.org/x/crypto/ssh/knownhosts](https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts) doesn't directly expose a way to query its known_host map, we use a subtle trick to do so: invoke the HostKeyCallback with a valid host but a bogus key. The resulting KeyError allows us to determine which public keys are actually present for that host. -By using this technique, [github.com/skeema/knownhosts](https://github.com/skeema/knownhosts) doesn't need to duplicate or re-implement any of the actual known_hosts management from [golang.org/x/crypto/ssh/knownhosts](https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts). +By using this technique, [github.com/skeema/knownhosts](https://github.com/skeema/knownhosts) doesn't need to duplicate any of the core known_hosts host-lookup logic from [golang.org/x/crypto/ssh/knownhosts](https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts). ## Populating ssh.ClientConfig.HostKeyAlgorithms based on known_hosts @@ -42,20 +44,33 @@ import ( ) func sshConfigForHost(hostWithPort string) (*ssh.ClientConfig, error) { - kh, err := knownhosts.New("/home/myuser/.ssh/known_hosts") + kh, err := knownhosts.NewDB("/home/myuser/.ssh/known_hosts") if err != nil { return nil, err } config := &ssh.ClientConfig{ User: "myuser", Auth: []ssh.AuthMethod{ /* ... */ }, - HostKeyCallback: kh.HostKeyCallback(), // or, equivalently, use ssh.HostKeyCallback(kh) + HostKeyCallback: kh.HostKeyCallback(), HostKeyAlgorithms: kh.HostKeyAlgorithms(hostWithPort), } return config, nil } ``` +## Enhancements requiring extra parsing + +Originally, this package did not re-read/re-parse the known_hosts files at all, relying entirely on [golang.org/x/crypto/ssh/knownhosts](https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts) for all known_hosts file reading and processing. This package only offered a constructor called `New`, returning a host key callback, identical to the call pattern of [golang.org/x/crypto/ssh/knownhosts](https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts) but with extra methods available on the callback type. + +However, a couple shortcomings in [golang.org/x/crypto/ssh/knownhosts](https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts) cannot possibly be solved without re-reading the known_hosts file. Therefore, as of v1.3.0 of this package, we now offer an alternative constructor `NewDB`, which does an additional read of the known_hosts file (after the one from [golang.org/x/crypto/ssh/knownhosts](https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts)), in order to detect: + +* @cert-authority lines, so that we can correctly return cert key algorithms instead of normal host key algorithms when appropriate +* host pattern wildcards, so that we can match OpenSSH's behavior for non-standard port numbers, unlike how [golang.org/x/crypto/ssh/knownhosts](https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts) normally treats them + +Aside from *detecting* these special cases, this package otherwise still directly uses [golang.org/x/crypto/ssh/knownhosts](https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts) for host lookups and all other known_hosts file processing. We do **not** fork or re-implement those core behaviors of [golang.org/x/crypto/ssh/knownhosts](https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts). + +The performance impact of this extra known_hosts read should be minimal, as the file should typically be in the filesystem cache already from the original read by [golang.org/x/crypto/ssh/knownhosts](https://pkg.go.dev/golang.org/x/crypto/ssh/knownhosts). That said, users who wish to avoid the extra read can stay with the `New` constructor, which intentionally retains its pre-v1.3.0 behavior as-is. However, the extra fixes for @cert-authority and host pattern wildcards will not be enabled in that case. + ## Writing new known_hosts entries If you wish to mimic the behavior of OpenSSH's `StrictHostKeyChecking=no` or `StrictHostKeyChecking=ask`, this package provides a few functions to simplify this task. For example: @@ -63,7 +78,7 @@ If you wish to mimic the behavior of OpenSSH's `StrictHostKeyChecking=no` or `St ```golang sshHost := "yourserver.com:22" khPath := "/home/myuser/.ssh/known_hosts" -kh, err := knownhosts.New(khPath) +kh, err := knownhosts.NewDB(khPath) if err != nil { log.Fatal("Failed to read known_hosts: ", err) } @@ -71,7 +86,8 @@ if err != nil { // Create a custom permissive hostkey callback which still errors on hosts // with changed keys, but allows unknown hosts and adds them to known_hosts cb := ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { - err := kh(hostname, remote, key) + innerCallback := kh.HostKeyCallback() + err := innerCallback(hostname, remote, key) if knownhosts.IsHostKeyChanged(err) { return fmt.Errorf("REMOTE HOST IDENTIFICATION HAS CHANGED for host %s! This may indicate a MitM attack.", hostname) } else if knownhosts.IsHostUnknown(err) { diff --git a/vendor/github.com/skeema/knownhosts/knownhosts.go b/vendor/github.com/skeema/knownhosts/knownhosts.go index 4dad7771b8..2b7536e0da 100644 --- a/vendor/github.com/skeema/knownhosts/knownhosts.go +++ b/vendor/github.com/skeema/knownhosts/knownhosts.go @@ -3,11 +3,14 @@ package knownhosts import ( + "bufio" + "bytes" "encoding/base64" "errors" "fmt" "io" "net" + "os" "sort" "strings" @@ -15,23 +18,133 @@ import ( xknownhosts "golang.org/x/crypto/ssh/knownhosts" ) -// HostKeyCallback wraps ssh.HostKeyCallback with an additional method to -// perform host key algorithm lookups from the known_hosts entries. -type HostKeyCallback ssh.HostKeyCallback +// HostKeyDB wraps logic in golang.org/x/crypto/ssh/knownhosts with additional +// behaviors, such as the ability to perform host key/algorithm lookups from +// known_hosts entries. +type HostKeyDB struct { + callback ssh.HostKeyCallback + isCert map[string]bool // keyed by "filename:line" + isWildcard map[string]bool // keyed by "filename:line" +} -// New creates a host key callback from the given OpenSSH host key files. The -// returned value may be used in ssh.ClientConfig.HostKeyCallback by casting it -// to ssh.HostKeyCallback, or using its HostKeyCallback method. Otherwise, it -// operates the same as the New function in golang.org/x/crypto/ssh/knownhosts. -func New(files ...string) (HostKeyCallback, error) { +// NewDB creates a HostKeyDB from the given OpenSSH known_hosts file(s). It +// reads and parses the provided files one additional time (beyond logic in +// golang.org/x/crypto/ssh/knownhosts) in order to: +// +// - Handle CA lines properly and return ssh.CertAlgo* values when calling the +// HostKeyAlgorithms method, for use in ssh.ClientConfig.HostKeyAlgorithms +// - Allow * wildcards in hostnames to match on non-standard ports, providing +// a workaround for https://github.com/golang/go/issues/52056 in order to +// align with OpenSSH's wildcard behavior +// +// When supplying multiple files, their order does not matter. +func NewDB(files ...string) (*HostKeyDB, error) { cb, err := xknownhosts.New(files...) - return HostKeyCallback(cb), err + if err != nil { + return nil, err + } + hkdb := &HostKeyDB{ + callback: cb, + isCert: make(map[string]bool), + isWildcard: make(map[string]bool), + } + + // Re-read each file a single time, looking for @cert-authority lines. The + // logic for reading the file is designed to mimic hostKeyDB.Read from + // golang.org/x/crypto/ssh/knownhosts + for _, filename := range files { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + scanner := bufio.NewScanner(f) + lineNum := 0 + for scanner.Scan() { + lineNum++ + line := scanner.Bytes() + line = bytes.TrimSpace(line) + // Does the line start with "@cert-authority" followed by whitespace? + if len(line) > 15 && bytes.HasPrefix(line, []byte("@cert-authority")) && (line[15] == ' ' || line[15] == '\t') { + mapKey := fmt.Sprintf("%s:%d", filename, lineNum) + hkdb.isCert[mapKey] = true + line = bytes.TrimSpace(line[16:]) + } + // truncate line to just the host pattern field + if i := bytes.IndexAny(line, "\t "); i >= 0 { + line = line[:i] + } + // Does the host pattern contain a * wildcard and no specific port? + if i := bytes.IndexRune(line, '*'); i >= 0 && !bytes.Contains(line[i:], []byte("]:")) { + mapKey := fmt.Sprintf("%s:%d", filename, lineNum) + hkdb.isWildcard[mapKey] = true + } + } + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("knownhosts: %s:%d: %w", filename, lineNum, err) + } + } + return hkdb, nil } -// HostKeyCallback simply casts the receiver back to ssh.HostKeyCallback, for -// use in ssh.ClientConfig.HostKeyCallback. -func (hkcb HostKeyCallback) HostKeyCallback() ssh.HostKeyCallback { - return ssh.HostKeyCallback(hkcb) +// HostKeyCallback returns an ssh.HostKeyCallback. This can be used directly in +// ssh.ClientConfig.HostKeyCallback, as shown in the example for NewDB. +// Alternatively, you can wrap it with an outer callback to potentially handle +// appending a new entry to the known_hosts file; see example in WriteKnownHost. +func (hkdb *HostKeyDB) HostKeyCallback() ssh.HostKeyCallback { + // Either NewDB found no wildcard host patterns, or hkdb was created from + // HostKeyCallback.ToDB in which case we didn't scan known_hosts for them: + // return the callback (which came from x/crypto/ssh/knownhosts) as-is + if len(hkdb.isWildcard) == 0 { + return hkdb.callback + } + + // If we scanned for wildcards and found at least one, return a wrapped + // callback with extra behavior: if the host lookup found no matches, and the + // host arg had a non-standard port, re-do the lookup on standard port 22. If + // that second call returns a *xknownhosts.KeyError, filter down any resulting + // Want keys to known wildcard entries. + f := func(hostname string, remote net.Addr, key ssh.PublicKey) error { + callbackErr := hkdb.callback(hostname, remote, key) + if callbackErr == nil || IsHostKeyChanged(callbackErr) { // hostname has known_host entries as-is + return callbackErr + } + justHost, port, splitErr := net.SplitHostPort(hostname) + if splitErr != nil || port == "" || port == "22" { // hostname already using standard port + return callbackErr + } + // If we reach here, the port was non-standard and no known_host entries + // were found for the non-standard port. Try again with standard port. + if tcpAddr, ok := remote.(*net.TCPAddr); ok && tcpAddr.Port != 22 { + remote = &net.TCPAddr{ + IP: tcpAddr.IP, + Port: 22, + Zone: tcpAddr.Zone, + } + } + callbackErr = hkdb.callback(justHost+":22", remote, key) + var keyErr *xknownhosts.KeyError + if errors.As(callbackErr, &keyErr) && len(keyErr.Want) > 0 { + wildcardKeys := make([]xknownhosts.KnownKey, 0, len(keyErr.Want)) + for _, wantKey := range keyErr.Want { + if hkdb.isWildcard[fmt.Sprintf("%s:%d", wantKey.Filename, wantKey.Line)] { + wildcardKeys = append(wildcardKeys, wantKey) + } + } + callbackErr = &xknownhosts.KeyError{ + Want: wildcardKeys, + } + } + return callbackErr + } + return ssh.HostKeyCallback(f) +} + +// PublicKey wraps ssh.PublicKey with an additional field, to identify +// whether the key corresponds to a certificate authority. +type PublicKey struct { + ssh.PublicKey + Cert bool } // HostKeys returns a slice of known host public keys for the supplied host:port @@ -39,12 +152,16 @@ func (hkcb HostKeyCallback) HostKeyCallback() ssh.HostKeyCallback { // already known. For hosts that have multiple known_hosts entries (for // different key types), the result will be sorted by known_hosts filename and // line number. -func (hkcb HostKeyCallback) HostKeys(hostWithPort string) (keys []ssh.PublicKey) { +// If hkdb was originally created by calling NewDB, the Cert boolean field of +// each result entry reports whether the key corresponded to a @cert-authority +// line. If hkdb was NOT obtained from NewDB, then Cert will always be false. +func (hkdb *HostKeyDB) HostKeys(hostWithPort string) (keys []PublicKey) { var keyErr *xknownhosts.KeyError placeholderAddr := &net.TCPAddr{IP: []byte{0, 0, 0, 0}} placeholderPubKey := &fakePublicKey{} var kkeys []xknownhosts.KnownKey - if hkcbErr := hkcb(hostWithPort, placeholderAddr, placeholderPubKey); errors.As(hkcbErr, &keyErr) { + callback := hkdb.HostKeyCallback() + if hkcbErr := callback(hostWithPort, placeholderAddr, placeholderPubKey); errors.As(hkcbErr, &keyErr) { kkeys = append(kkeys, keyErr.Want...) knownKeyLess := func(i, j int) bool { if kkeys[i].Filename < kkeys[j].Filename { @@ -53,9 +170,14 @@ func (hkcb HostKeyCallback) HostKeys(hostWithPort string) (keys []ssh.PublicKey) return (kkeys[i].Filename == kkeys[j].Filename && kkeys[i].Line < kkeys[j].Line) } sort.Slice(kkeys, knownKeyLess) - keys = make([]ssh.PublicKey, len(kkeys)) + keys = make([]PublicKey, len(kkeys)) for n := range kkeys { - keys[n] = kkeys[n].Key + keys[n] = PublicKey{ + PublicKey: kkeys[n].Key, + } + if len(hkdb.isCert) > 0 { + keys[n].Cert = hkdb.isCert[fmt.Sprintf("%s:%d", kkeys[n].Filename, kkeys[n].Line)] + } } } return keys @@ -66,17 +188,23 @@ func (hkcb HostKeyCallback) HostKeys(hostWithPort string) (keys []ssh.PublicKey) // is not already known. The result may be used in ssh.ClientConfig's // HostKeyAlgorithms field, either as-is or after filtering (if you wish to // ignore or prefer particular algorithms). For hosts that have multiple -// known_hosts entries (for different key types), the result will be sorted by +// known_hosts entries (of different key types), the result will be sorted by // known_hosts filename and line number. -func (hkcb HostKeyCallback) HostKeyAlgorithms(hostWithPort string) (algos []string) { +// If hkdb was originally created by calling NewDB, any @cert-authority lines +// in the known_hosts file will properly be converted to the corresponding +// ssh.CertAlgo* values. +func (hkdb *HostKeyDB) HostKeyAlgorithms(hostWithPort string) (algos []string) { // We ensure that algos never contains duplicates. This is done for robustness // even though currently golang.org/x/crypto/ssh/knownhosts never exposes // multiple keys of the same type. This way our behavior here is unaffected // even if https://github.com/golang/go/issues/28870 is implemented, for // example by https://github.com/golang/crypto/pull/254. - hostKeys := hkcb.HostKeys(hostWithPort) + hostKeys := hkdb.HostKeys(hostWithPort) seen := make(map[string]struct{}, len(hostKeys)) - addAlgo := func(typ string) { + addAlgo := func(typ string, cert bool) { + if cert { + typ = keyTypeToCertAlgo(typ) + } if _, already := seen[typ]; !already { algos = append(algos, typ) seen[typ] = struct{}{} @@ -88,25 +216,143 @@ func (hkcb HostKeyCallback) HostKeyAlgorithms(hostWithPort string) (algos []stri // KeyAlgoRSASHA256 and KeyAlgoRSASHA512 are only public key algorithms, // not public key formats, so they can't appear as a PublicKey.Type. // The corresponding PublicKey.Type is KeyAlgoRSA. See RFC 8332, Section 2. - addAlgo(ssh.KeyAlgoRSASHA512) - addAlgo(ssh.KeyAlgoRSASHA256) + addAlgo(ssh.KeyAlgoRSASHA512, key.Cert) + addAlgo(ssh.KeyAlgoRSASHA256, key.Cert) } - addAlgo(typ) + addAlgo(typ, key.Cert) } return algos } +func keyTypeToCertAlgo(keyType string) string { + switch keyType { + case ssh.KeyAlgoRSA: + return ssh.CertAlgoRSAv01 + case ssh.KeyAlgoRSASHA256: + return ssh.CertAlgoRSASHA256v01 + case ssh.KeyAlgoRSASHA512: + return ssh.CertAlgoRSASHA512v01 + case ssh.KeyAlgoDSA: + return ssh.CertAlgoDSAv01 + case ssh.KeyAlgoECDSA256: + return ssh.CertAlgoECDSA256v01 + case ssh.KeyAlgoSKECDSA256: + return ssh.CertAlgoSKECDSA256v01 + case ssh.KeyAlgoECDSA384: + return ssh.CertAlgoECDSA384v01 + case ssh.KeyAlgoECDSA521: + return ssh.CertAlgoECDSA521v01 + case ssh.KeyAlgoED25519: + return ssh.CertAlgoED25519v01 + case ssh.KeyAlgoSKED25519: + return ssh.CertAlgoSKED25519v01 + } + return "" +} + +// HostKeyCallback wraps ssh.HostKeyCallback with additional methods to +// perform host key and algorithm lookups from the known_hosts entries. It is +// otherwise identical to ssh.HostKeyCallback, and does not introduce any file- +// parsing behavior beyond what is in golang.org/x/crypto/ssh/knownhosts. +// +// In most situations, use HostKeyDB and its constructor NewDB instead of using +// the HostKeyCallback type. The HostKeyCallback type is only provided for +// backwards compatibility with older versions of this package, as well as for +// very strict situations where any extra known_hosts file-parsing is +// undesirable. +// +// Methods of HostKeyCallback do not provide any special treatment for +// @cert-authority lines, which will (incorrectly) look like normal non-CA host +// keys. Additionally, HostKeyCallback lacks the fix for applying * wildcard +// known_host entries to all ports, like OpenSSH's behavior. +type HostKeyCallback ssh.HostKeyCallback + +// New creates a HostKeyCallback from the given OpenSSH known_hosts file(s). The +// returned value may be used in ssh.ClientConfig.HostKeyCallback by casting it +// to ssh.HostKeyCallback, or using its HostKeyCallback method. Otherwise, it +// operates the same as the New function in golang.org/x/crypto/ssh/knownhosts. +// When supplying multiple files, their order does not matter. +// +// In most situations, you should avoid this function, as the returned value +// lacks several enhanced behaviors. See doc comment for HostKeyCallback for +// more information. Instead, most callers should use NewDB to create a +// HostKeyDB, which includes these enhancements. +func New(files ...string) (HostKeyCallback, error) { + cb, err := xknownhosts.New(files...) + return HostKeyCallback(cb), err +} + +// HostKeyCallback simply casts the receiver back to ssh.HostKeyCallback, for +// use in ssh.ClientConfig.HostKeyCallback. +func (hkcb HostKeyCallback) HostKeyCallback() ssh.HostKeyCallback { + return ssh.HostKeyCallback(hkcb) +} + +// ToDB converts the receiver into a HostKeyDB. However, the returned HostKeyDB +// lacks the enhanced behaviors described in the doc comment for NewDB: proper +// CA support, and wildcard matching on nonstandard ports. +// +// It is generally preferable to create a HostKeyDB by using NewDB. The ToDB +// method is only provided for situations in which the calling code needs to +// make the extra NewDB behaviors optional / user-configurable, perhaps for +// reasons of performance or code trust (since NewDB reads the known_host file +// an extra time, which may be undesirable in some strict situations). This way, +// callers can conditionally create a non-enhanced HostKeyDB by using New and +// ToDB. See code example. +func (hkcb HostKeyCallback) ToDB() *HostKeyDB { + // This intentionally leaves the isCert and isWildcard map fields as nil, as + // there is no way to retroactively populate them from just a HostKeyCallback. + // Methods of HostKeyDB will skip any related enhanced behaviors accordingly. + return &HostKeyDB{callback: ssh.HostKeyCallback(hkcb)} +} + +// HostKeys returns a slice of known host public keys for the supplied host:port +// found in the known_hosts file(s), or an empty slice if the host is not +// already known. For hosts that have multiple known_hosts entries (for +// different key types), the result will be sorted by known_hosts filename and +// line number. +// In the returned values, there is no way to distinguish between CA keys +// (known_hosts lines beginning with @cert-authority) and regular keys. To do +// so, see NewDB and HostKeyDB.HostKeys instead. +func (hkcb HostKeyCallback) HostKeys(hostWithPort string) []ssh.PublicKey { + annotatedKeys := hkcb.ToDB().HostKeys(hostWithPort) + rawKeys := make([]ssh.PublicKey, len(annotatedKeys)) + for n, ak := range annotatedKeys { + rawKeys[n] = ak.PublicKey + } + return rawKeys +} + +// HostKeyAlgorithms returns a slice of host key algorithms for the supplied +// host:port found in the known_hosts file(s), or an empty slice if the host +// is not already known. The result may be used in ssh.ClientConfig's +// HostKeyAlgorithms field, either as-is or after filtering (if you wish to +// ignore or prefer particular algorithms). For hosts that have multiple +// known_hosts entries (for different key types), the result will be sorted by +// known_hosts filename and line number. +// The returned values will not include ssh.CertAlgo* values. If any +// known_hosts lines had @cert-authority prefixes, their original key algo will +// be returned instead. For proper CA support, see NewDB and +// HostKeyDB.HostKeyAlgorithms instead. +func (hkcb HostKeyCallback) HostKeyAlgorithms(hostWithPort string) (algos []string) { + return hkcb.ToDB().HostKeyAlgorithms(hostWithPort) +} + // HostKeyAlgorithms is a convenience function for performing host key algorithm // lookups on an ssh.HostKeyCallback directly. It is intended for use in code // paths that stay with the New method of golang.org/x/crypto/ssh/knownhosts -// rather than this package's New method. +// rather than this package's New or NewDB methods. +// The returned values will not include ssh.CertAlgo* values. If any +// known_hosts lines had @cert-authority prefixes, their original key algo will +// be returned instead. For proper CA support, see NewDB and +// HostKeyDB.HostKeyAlgorithms instead. func HostKeyAlgorithms(cb ssh.HostKeyCallback, hostWithPort string) []string { return HostKeyCallback(cb).HostKeyAlgorithms(hostWithPort) } // IsHostKeyChanged returns a boolean indicating whether the error indicates // the host key has changed. It is intended to be called on the error returned -// from invoking a HostKeyCallback to check whether an SSH host is known. +// from invoking a host key callback, to check whether an SSH host is known. func IsHostKeyChanged(err error) bool { var keyErr *xknownhosts.KeyError return errors.As(err, &keyErr) && len(keyErr.Want) > 0 @@ -114,7 +360,7 @@ func IsHostKeyChanged(err error) bool { // IsHostUnknown returns a boolean indicating whether the error represents an // unknown host. It is intended to be called on the error returned from invoking -// a HostKeyCallback to check whether an SSH host is known. +// a host key callback to check whether an SSH host is known. func IsHostUnknown(err error) bool { var keyErr *xknownhosts.KeyError return errors.As(err, &keyErr) && len(keyErr.Want) == 0 @@ -154,11 +400,12 @@ func Line(addresses []string, key ssh.PublicKey) string { }, " ") } -// WriteKnownHost writes a known_hosts line to writer for the supplied hostname, +// WriteKnownHost writes a known_hosts line to w for the supplied hostname, // remote, and key. This is useful when writing a custom hostkey callback which -// wraps a callback obtained from knownhosts.New to provide additional -// known_hosts management functionality. The hostname, remote, and key typically -// correspond to the callback's args. +// wraps a callback obtained from this package to provide additional known_hosts +// management functionality. The hostname, remote, and key typically correspond +// to the callback's args. This function does not support writing +// @cert-authority lines. func WriteKnownHost(w io.Writer, hostname string, remote net.Addr, key ssh.PublicKey) error { // Always include hostname; only also include remote if it isn't a zero value // and doesn't normalize to the same string as hostname. @@ -177,6 +424,14 @@ func WriteKnownHost(w io.Writer, hostname string, remote net.Addr, key ssh.Publi return err } +// WriteKnownHostCA writes a @cert-authority line to w for the supplied host +// name/pattern and key. +func WriteKnownHostCA(w io.Writer, hostPattern string, key ssh.PublicKey) error { + encodedKey := base64.StdEncoding.EncodeToString(key.Marshal()) + _, err := fmt.Fprintf(w, "@cert-authority %s %s %s\n", hostPattern, key.Type(), encodedKey) + return err +} + // fakePublicKey is used as part of the work-around for // https://github.com/golang/go/issues/29286 type fakePublicKey struct{} diff --git a/vendor/github.com/stretchr/testify/assert/assertion_compare.go b/vendor/github.com/stretchr/testify/assert/assertion_compare.go index 4d4b4aad6f..7e19eba090 100644 --- a/vendor/github.com/stretchr/testify/assert/assertion_compare.go +++ b/vendor/github.com/stretchr/testify/assert/assertion_compare.go @@ -7,10 +7,13 @@ import ( "time" ) -type CompareType int +// Deprecated: CompareType has only ever been for internal use and has accidentally been published since v1.6.0. Do not use it. +type CompareType = compareResult + +type compareResult int const ( - compareLess CompareType = iota - 1 + compareLess compareResult = iota - 1 compareEqual compareGreater ) @@ -39,7 +42,7 @@ var ( bytesType = reflect.TypeOf([]byte{}) ) -func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) { +func compare(obj1, obj2 interface{}, kind reflect.Kind) (compareResult, bool) { obj1Value := reflect.ValueOf(obj1) obj2Value := reflect.ValueOf(obj2) @@ -325,7 +328,13 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) { timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time) } - return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64) + if timeObj1.Before(timeObj2) { + return compareLess, true + } + if timeObj1.Equal(timeObj2) { + return compareEqual, true + } + return compareGreater, true } case reflect.Slice: { @@ -345,7 +354,7 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) { bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte) } - return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true + return compareResult(bytes.Compare(bytesObj1, bytesObj2)), true } case reflect.Uintptr: { @@ -381,7 +390,7 @@ func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface if h, ok := t.(tHelper); ok { h.Helper() } - return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...) + return compareTwoValues(t, e1, e2, []compareResult{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...) } // GreaterOrEqual asserts that the first element is greater than or equal to the second @@ -394,7 +403,7 @@ func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...in if h, ok := t.(tHelper); ok { h.Helper() } - return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...) + return compareTwoValues(t, e1, e2, []compareResult{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...) } // Less asserts that the first element is less than the second @@ -406,7 +415,7 @@ func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) if h, ok := t.(tHelper); ok { h.Helper() } - return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...) + return compareTwoValues(t, e1, e2, []compareResult{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...) } // LessOrEqual asserts that the first element is less than or equal to the second @@ -419,7 +428,7 @@ func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...inter if h, ok := t.(tHelper); ok { h.Helper() } - return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...) + return compareTwoValues(t, e1, e2, []compareResult{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...) } // Positive asserts that the specified element is positive @@ -431,7 +440,7 @@ func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool { h.Helper() } zero := reflect.Zero(reflect.TypeOf(e)) - return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...) + return compareTwoValues(t, e, zero.Interface(), []compareResult{compareGreater}, "\"%v\" is not positive", msgAndArgs...) } // Negative asserts that the specified element is negative @@ -443,10 +452,10 @@ func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool { h.Helper() } zero := reflect.Zero(reflect.TypeOf(e)) - return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...) + return compareTwoValues(t, e, zero.Interface(), []compareResult{compareLess}, "\"%v\" is not negative", msgAndArgs...) } -func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool { +func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []compareResult, failMessage string, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() } @@ -469,7 +478,7 @@ func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedCompare return true } -func containsValue(values []CompareType, value CompareType) bool { +func containsValue(values []compareResult, value compareResult) bool { for _, v := range values { if v == value { return true diff --git a/vendor/github.com/stretchr/testify/assert/assertion_format.go b/vendor/github.com/stretchr/testify/assert/assertion_format.go index 3ddab109ad..1906341657 100644 --- a/vendor/github.com/stretchr/testify/assert/assertion_format.go +++ b/vendor/github.com/stretchr/testify/assert/assertion_format.go @@ -104,8 +104,8 @@ func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{}, return EqualExportedValues(t, expected, actual, append([]interface{}{msg}, args...)...) } -// EqualValuesf asserts that two objects are equal or convertible to the same types -// and equal. +// EqualValuesf asserts that two objects are equal or convertible to the larger +// type and equal. // // assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted") func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { @@ -186,7 +186,7 @@ func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick // assert.EventuallyWithTf(t, func(c *assert.CollectT, "error message %s", "formatted") { // // add assertions as needed; any assertion failure will fail the current tick // assert.True(c, externalValue, "expected 'externalValue' to be true") -// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") func EventuallyWithTf(t TestingT, condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -568,6 +568,23 @@ func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, a return NotContains(t, s, contains, append([]interface{}{msg}, args...)...) } +// NotElementsMatchf asserts that the specified listA(array, slice...) is NOT equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should not match. +// This is an inverse of ElementsMatch. +// +// assert.NotElementsMatchf(t, [1, 1, 2, 3], [1, 1, 2, 3], "error message %s", "formatted") -> false +// +// assert.NotElementsMatchf(t, [1, 1, 2, 3], [1, 2, 3], "error message %s", "formatted") -> true +// +// assert.NotElementsMatchf(t, [1, 2, 3], [1, 2, 4], "error message %s", "formatted") -> true +func NotElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotElementsMatch(t, listA, listB, append([]interface{}{msg}, args...)...) +} + // NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // a slice or a channel with len == 0. // @@ -604,7 +621,16 @@ func NotEqualValuesf(t TestingT, expected interface{}, actual interface{}, msg s return NotEqualValues(t, expected, actual, append([]interface{}{msg}, args...)...) } -// NotErrorIsf asserts that at none of the errors in err's chain matches target. +// NotErrorAsf asserts that none of the errors in err's chain matches target, +// but if so, sets target to that error value. +func NotErrorAsf(t TestingT, err error, target interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotErrorAs(t, err, target, append([]interface{}{msg}, args...)...) +} + +// NotErrorIsf asserts that none of the errors in err's chain matches target. // This is a wrapper for errors.Is. func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { diff --git a/vendor/github.com/stretchr/testify/assert/assertion_forward.go b/vendor/github.com/stretchr/testify/assert/assertion_forward.go index a84e09bd40..21629087ba 100644 --- a/vendor/github.com/stretchr/testify/assert/assertion_forward.go +++ b/vendor/github.com/stretchr/testify/assert/assertion_forward.go @@ -186,8 +186,8 @@ func (a *Assertions) EqualExportedValuesf(expected interface{}, actual interface return EqualExportedValuesf(a.t, expected, actual, msg, args...) } -// EqualValues asserts that two objects are equal or convertible to the same types -// and equal. +// EqualValues asserts that two objects are equal or convertible to the larger +// type and equal. // // a.EqualValues(uint32(123), int32(123)) func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { @@ -197,8 +197,8 @@ func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAn return EqualValues(a.t, expected, actual, msgAndArgs...) } -// EqualValuesf asserts that two objects are equal or convertible to the same types -// and equal. +// EqualValuesf asserts that two objects are equal or convertible to the larger +// type and equal. // // a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted") func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { @@ -336,7 +336,7 @@ func (a *Assertions) Eventually(condition func() bool, waitFor time.Duration, ti // a.EventuallyWithT(func(c *assert.CollectT) { // // add assertions as needed; any assertion failure will fail the current tick // assert.True(c, externalValue, "expected 'externalValue' to be true") -// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") func (a *Assertions) EventuallyWithT(condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -361,7 +361,7 @@ func (a *Assertions) EventuallyWithT(condition func(collect *CollectT), waitFor // a.EventuallyWithTf(func(c *assert.CollectT, "error message %s", "formatted") { // // add assertions as needed; any assertion failure will fail the current tick // assert.True(c, externalValue, "expected 'externalValue' to be true") -// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") func (a *Assertions) EventuallyWithTf(condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1128,6 +1128,40 @@ func (a *Assertions) NotContainsf(s interface{}, contains interface{}, msg strin return NotContainsf(a.t, s, contains, msg, args...) } +// NotElementsMatch asserts that the specified listA(array, slice...) is NOT equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should not match. +// This is an inverse of ElementsMatch. +// +// a.NotElementsMatch([1, 1, 2, 3], [1, 1, 2, 3]) -> false +// +// a.NotElementsMatch([1, 1, 2, 3], [1, 2, 3]) -> true +// +// a.NotElementsMatch([1, 2, 3], [1, 2, 4]) -> true +func (a *Assertions) NotElementsMatch(listA interface{}, listB interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotElementsMatch(a.t, listA, listB, msgAndArgs...) +} + +// NotElementsMatchf asserts that the specified listA(array, slice...) is NOT equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should not match. +// This is an inverse of ElementsMatch. +// +// a.NotElementsMatchf([1, 1, 2, 3], [1, 1, 2, 3], "error message %s", "formatted") -> false +// +// a.NotElementsMatchf([1, 1, 2, 3], [1, 2, 3], "error message %s", "formatted") -> true +// +// a.NotElementsMatchf([1, 2, 3], [1, 2, 4], "error message %s", "formatted") -> true +func (a *Assertions) NotElementsMatchf(listA interface{}, listB interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotElementsMatchf(a.t, listA, listB, msg, args...) +} + // NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // a slice or a channel with len == 0. // @@ -1200,7 +1234,25 @@ func (a *Assertions) NotEqualf(expected interface{}, actual interface{}, msg str return NotEqualf(a.t, expected, actual, msg, args...) } -// NotErrorIs asserts that at none of the errors in err's chain matches target. +// NotErrorAs asserts that none of the errors in err's chain matches target, +// but if so, sets target to that error value. +func (a *Assertions) NotErrorAs(err error, target interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotErrorAs(a.t, err, target, msgAndArgs...) +} + +// NotErrorAsf asserts that none of the errors in err's chain matches target, +// but if so, sets target to that error value. +func (a *Assertions) NotErrorAsf(err error, target interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotErrorAsf(a.t, err, target, msg, args...) +} + +// NotErrorIs asserts that none of the errors in err's chain matches target. // This is a wrapper for errors.Is. func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { @@ -1209,7 +1261,7 @@ func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface return NotErrorIs(a.t, err, target, msgAndArgs...) } -// NotErrorIsf asserts that at none of the errors in err's chain matches target. +// NotErrorIsf asserts that none of the errors in err's chain matches target. // This is a wrapper for errors.Is. func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { diff --git a/vendor/github.com/stretchr/testify/assert/assertion_order.go b/vendor/github.com/stretchr/testify/assert/assertion_order.go index 00df62a059..1d2f71824a 100644 --- a/vendor/github.com/stretchr/testify/assert/assertion_order.go +++ b/vendor/github.com/stretchr/testify/assert/assertion_order.go @@ -6,7 +6,7 @@ import ( ) // isOrdered checks that collection contains orderable elements. -func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool { +func isOrdered(t TestingT, object interface{}, allowedComparesResults []compareResult, failMessage string, msgAndArgs ...interface{}) bool { objKind := reflect.TypeOf(object).Kind() if objKind != reflect.Slice && objKind != reflect.Array { return false @@ -50,7 +50,7 @@ func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareT // assert.IsIncreasing(t, []float{1, 2}) // assert.IsIncreasing(t, []string{"a", "b"}) func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { - return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...) + return isOrdered(t, object, []compareResult{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...) } // IsNonIncreasing asserts that the collection is not increasing @@ -59,7 +59,7 @@ func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo // assert.IsNonIncreasing(t, []float{2, 1}) // assert.IsNonIncreasing(t, []string{"b", "a"}) func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { - return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...) + return isOrdered(t, object, []compareResult{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...) } // IsDecreasing asserts that the collection is decreasing @@ -68,7 +68,7 @@ func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) // assert.IsDecreasing(t, []float{2, 1}) // assert.IsDecreasing(t, []string{"b", "a"}) func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { - return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...) + return isOrdered(t, object, []compareResult{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...) } // IsNonDecreasing asserts that the collection is not decreasing @@ -77,5 +77,5 @@ func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo // assert.IsNonDecreasing(t, []float{1, 2}) // assert.IsNonDecreasing(t, []string{"a", "b"}) func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { - return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...) + return isOrdered(t, object, []compareResult{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...) } diff --git a/vendor/github.com/stretchr/testify/assert/assertions.go b/vendor/github.com/stretchr/testify/assert/assertions.go index 0b7570f21c..4e91332bb5 100644 --- a/vendor/github.com/stretchr/testify/assert/assertions.go +++ b/vendor/github.com/stretchr/testify/assert/assertions.go @@ -19,7 +19,9 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/pmezard/go-difflib/difflib" - "gopkg.in/yaml.v3" + + // Wrapper around gopkg.in/yaml.v3 + "github.com/stretchr/testify/assert/yaml" ) //go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=assert -template=assertion_format.go.tmpl" @@ -45,6 +47,10 @@ type BoolAssertionFunc func(TestingT, bool, ...interface{}) bool // for table driven tests. type ErrorAssertionFunc func(TestingT, error, ...interface{}) bool +// PanicAssertionFunc is a common function prototype when validating a panic value. Can be useful +// for table driven tests. +type PanicAssertionFunc = func(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool + // Comparison is a custom function that returns true on success and false on failure type Comparison func() (success bool) @@ -496,7 +502,13 @@ func Same(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) b h.Helper() } - if !samePointers(expected, actual) { + same, ok := samePointers(expected, actual) + if !ok { + return Fail(t, "Both arguments must be pointers", msgAndArgs...) + } + + if !same { + // both are pointers but not the same type & pointing to the same address return Fail(t, fmt.Sprintf("Not same: \n"+ "expected: %p %#v\n"+ "actual : %p %#v", expected, expected, actual, actual), msgAndArgs...) @@ -516,7 +528,13 @@ func NotSame(t TestingT, expected, actual interface{}, msgAndArgs ...interface{} h.Helper() } - if samePointers(expected, actual) { + same, ok := samePointers(expected, actual) + if !ok { + //fails when the arguments are not pointers + return !(Fail(t, "Both arguments must be pointers", msgAndArgs...)) + } + + if same { return Fail(t, fmt.Sprintf( "Expected and actual point to the same object: %p %#v", expected, expected), msgAndArgs...) @@ -524,21 +542,23 @@ func NotSame(t TestingT, expected, actual interface{}, msgAndArgs ...interface{} return true } -// samePointers compares two generic interface objects and returns whether -// they point to the same object -func samePointers(first, second interface{}) bool { +// samePointers checks if two generic interface objects are pointers of the same +// type pointing to the same object. It returns two values: same indicating if +// they are the same type and point to the same object, and ok indicating that +// both inputs are pointers. +func samePointers(first, second interface{}) (same bool, ok bool) { firstPtr, secondPtr := reflect.ValueOf(first), reflect.ValueOf(second) if firstPtr.Kind() != reflect.Ptr || secondPtr.Kind() != reflect.Ptr { - return false + return false, false //not both are pointers } firstType, secondType := reflect.TypeOf(first), reflect.TypeOf(second) if firstType != secondType { - return false + return false, true // both are pointers, but of different types } // compare pointer addresses - return first == second + return first == second, true } // formatUnequalValues takes two values of arbitrary types and returns string @@ -572,8 +592,8 @@ func truncatingFormat(data interface{}) string { return value } -// EqualValues asserts that two objects are equal or convertible to the same types -// and equal. +// EqualValues asserts that two objects are equal or convertible to the larger +// type and equal. // // assert.EqualValues(t, uint32(123), int32(123)) func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { @@ -615,21 +635,6 @@ func EqualExportedValues(t TestingT, expected, actual interface{}, msgAndArgs .. return Fail(t, fmt.Sprintf("Types expected to match exactly\n\t%v != %v", aType, bType), msgAndArgs...) } - if aType.Kind() == reflect.Ptr { - aType = aType.Elem() - } - if bType.Kind() == reflect.Ptr { - bType = bType.Elem() - } - - if aType.Kind() != reflect.Struct { - return Fail(t, fmt.Sprintf("Types expected to both be struct or pointer to struct \n\t%v != %v", aType.Kind(), reflect.Struct), msgAndArgs...) - } - - if bType.Kind() != reflect.Struct { - return Fail(t, fmt.Sprintf("Types expected to both be struct or pointer to struct \n\t%v != %v", bType.Kind(), reflect.Struct), msgAndArgs...) - } - expected = copyExportedFields(expected) actual = copyExportedFields(actual) @@ -1170,6 +1175,39 @@ func formatListDiff(listA, listB interface{}, extraA, extraB []interface{}) stri return msg.String() } +// NotElementsMatch asserts that the specified listA(array, slice...) is NOT equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should not match. +// This is an inverse of ElementsMatch. +// +// assert.NotElementsMatch(t, [1, 1, 2, 3], [1, 1, 2, 3]) -> false +// +// assert.NotElementsMatch(t, [1, 1, 2, 3], [1, 2, 3]) -> true +// +// assert.NotElementsMatch(t, [1, 2, 3], [1, 2, 4]) -> true +func NotElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if isEmpty(listA) && isEmpty(listB) { + return Fail(t, "listA and listB contain the same elements", msgAndArgs) + } + + if !isList(t, listA, msgAndArgs...) { + return Fail(t, "listA is not a list type", msgAndArgs...) + } + if !isList(t, listB, msgAndArgs...) { + return Fail(t, "listB is not a list type", msgAndArgs...) + } + + extraA, extraB := diffLists(listA, listB) + if len(extraA) == 0 && len(extraB) == 0 { + return Fail(t, "listA and listB contain the same elements", msgAndArgs) + } + + return true +} + // Condition uses a Comparison to assert a complex condition. func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { @@ -1488,6 +1526,9 @@ func InEpsilon(t TestingT, expected, actual interface{}, epsilon float64, msgAnd if err != nil { return Fail(t, err.Error(), msgAndArgs...) } + if math.IsNaN(actualEpsilon) { + return Fail(t, "relative error is NaN", msgAndArgs...) + } if actualEpsilon > epsilon { return Fail(t, fmt.Sprintf("Relative error is too high: %#v (expected)\n"+ " < %#v (actual)", epsilon, actualEpsilon), msgAndArgs...) @@ -1611,7 +1652,6 @@ func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...in // matchRegexp return true if a specified regexp matches a string. func matchRegexp(rx interface{}, str interface{}) bool { - var r *regexp.Regexp if rr, ok := rx.(*regexp.Regexp); ok { r = rr @@ -1619,7 +1659,14 @@ func matchRegexp(rx interface{}, str interface{}) bool { r = regexp.MustCompile(fmt.Sprint(rx)) } - return (r.FindStringIndex(fmt.Sprint(str)) != nil) + switch v := str.(type) { + case []byte: + return r.Match(v) + case string: + return r.MatchString(v) + default: + return r.MatchString(fmt.Sprint(v)) + } } @@ -1872,7 +1919,7 @@ var spewConfigStringerEnabled = spew.ConfigState{ MaxDepth: 10, } -type tHelper interface { +type tHelper = interface { Helper() } @@ -1911,6 +1958,9 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t // CollectT implements the TestingT interface and collects all errors. type CollectT struct { + // A slice of errors. Non-nil slice denotes a failure. + // If it's non-nil but len(c.errors) == 0, this is also a failure + // obtained by direct c.FailNow() call. errors []error } @@ -1919,9 +1969,10 @@ func (c *CollectT) Errorf(format string, args ...interface{}) { c.errors = append(c.errors, fmt.Errorf(format, args...)) } -// FailNow panics. -func (*CollectT) FailNow() { - panic("Assertion failed") +// FailNow stops execution by calling runtime.Goexit. +func (c *CollectT) FailNow() { + c.fail() + runtime.Goexit() } // Deprecated: That was a method for internal usage that should not have been published. Now just panics. @@ -1934,6 +1985,16 @@ func (*CollectT) Copy(TestingT) { panic("Copy() is deprecated") } +func (c *CollectT) fail() { + if !c.failed() { + c.errors = []error{} // Make it non-nil to mark a failure. + } +} + +func (c *CollectT) failed() bool { + return c.errors != nil +} + // EventuallyWithT asserts that given condition will be met in waitFor time, // periodically checking target function each tick. In contrast to Eventually, // it supplies a CollectT to the condition function, so that the condition @@ -1951,14 +2012,14 @@ func (*CollectT) Copy(TestingT) { // assert.EventuallyWithT(t, func(c *assert.CollectT) { // // add assertions as needed; any assertion failure will fail the current tick // assert.True(c, externalValue, "expected 'externalValue' to be true") -// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() } var lastFinishedTickErrs []error - ch := make(chan []error, 1) + ch := make(chan *CollectT, 1) timer := time.NewTimer(waitFor) defer timer.Stop() @@ -1978,16 +2039,16 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time go func() { collect := new(CollectT) defer func() { - ch <- collect.errors + ch <- collect }() condition(collect) }() - case errs := <-ch: - if len(errs) == 0 { + case collect := <-ch: + if !collect.failed() { return true } // Keep the errors from the last ended condition, so that they can be copied to t if timeout is reached. - lastFinishedTickErrs = errs + lastFinishedTickErrs = collect.errors tick = ticker.C } } @@ -2049,7 +2110,7 @@ func ErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool { ), msgAndArgs...) } -// NotErrorIs asserts that at none of the errors in err's chain matches target. +// NotErrorIs asserts that none of the errors in err's chain matches target. // This is a wrapper for errors.Is. func NotErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { @@ -2090,6 +2151,24 @@ func ErrorAs(t TestingT, err error, target interface{}, msgAndArgs ...interface{ ), msgAndArgs...) } +// NotErrorAs asserts that none of the errors in err's chain matches target, +// but if so, sets target to that error value. +func NotErrorAs(t TestingT, err error, target interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if !errors.As(err, target) { + return true + } + + chain := buildErrorChainString(err) + + return Fail(t, fmt.Sprintf("Target error should not be in err chain:\n"+ + "found: %q\n"+ + "in chain: %s", target, chain, + ), msgAndArgs...) +} + func buildErrorChainString(err error) string { if err == nil { return "" diff --git a/vendor/github.com/stretchr/testify/assert/yaml/yaml_custom.go b/vendor/github.com/stretchr/testify/assert/yaml/yaml_custom.go new file mode 100644 index 0000000000..baa0cc7d7f --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/yaml/yaml_custom.go @@ -0,0 +1,25 @@ +//go:build testify_yaml_custom && !testify_yaml_fail && !testify_yaml_default +// +build testify_yaml_custom,!testify_yaml_fail,!testify_yaml_default + +// Package yaml is an implementation of YAML functions that calls a pluggable implementation. +// +// This implementation is selected with the testify_yaml_custom build tag. +// +// go test -tags testify_yaml_custom +// +// This implementation can be used at build time to replace the default implementation +// to avoid linking with [gopkg.in/yaml.v3]. +// +// In your test package: +// +// import assertYaml "github.com/stretchr/testify/assert/yaml" +// +// func init() { +// assertYaml.Unmarshal = func (in []byte, out interface{}) error { +// // ... +// return nil +// } +// } +package yaml + +var Unmarshal func(in []byte, out interface{}) error diff --git a/vendor/github.com/stretchr/testify/assert/yaml/yaml_default.go b/vendor/github.com/stretchr/testify/assert/yaml/yaml_default.go new file mode 100644 index 0000000000..b83c6cf64c --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/yaml/yaml_default.go @@ -0,0 +1,37 @@ +//go:build !testify_yaml_fail && !testify_yaml_custom +// +build !testify_yaml_fail,!testify_yaml_custom + +// Package yaml is just an indirection to handle YAML deserialization. +// +// This package is just an indirection that allows the builder to override the +// indirection with an alternative implementation of this package that uses +// another implementation of YAML deserialization. This allows to not either not +// use YAML deserialization at all, or to use another implementation than +// [gopkg.in/yaml.v3] (for example for license compatibility reasons, see [PR #1120]). +// +// Alternative implementations are selected using build tags: +// +// - testify_yaml_fail: [Unmarshal] always fails with an error +// - testify_yaml_custom: [Unmarshal] is a variable. Caller must initialize it +// before calling any of [github.com/stretchr/testify/assert.YAMLEq] or +// [github.com/stretchr/testify/assert.YAMLEqf]. +// +// Usage: +// +// go test -tags testify_yaml_fail +// +// You can check with "go list" which implementation is linked: +// +// go list -f '{{.Imports}}' github.com/stretchr/testify/assert/yaml +// go list -tags testify_yaml_fail -f '{{.Imports}}' github.com/stretchr/testify/assert/yaml +// go list -tags testify_yaml_custom -f '{{.Imports}}' github.com/stretchr/testify/assert/yaml +// +// [PR #1120]: https://github.com/stretchr/testify/pull/1120 +package yaml + +import goyaml "gopkg.in/yaml.v3" + +// Unmarshal is just a wrapper of [gopkg.in/yaml.v3.Unmarshal]. +func Unmarshal(in []byte, out interface{}) error { + return goyaml.Unmarshal(in, out) +} diff --git a/vendor/github.com/stretchr/testify/assert/yaml/yaml_fail.go b/vendor/github.com/stretchr/testify/assert/yaml/yaml_fail.go new file mode 100644 index 0000000000..e78f7dfe69 --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/yaml/yaml_fail.go @@ -0,0 +1,18 @@ +//go:build testify_yaml_fail && !testify_yaml_custom && !testify_yaml_default +// +build testify_yaml_fail,!testify_yaml_custom,!testify_yaml_default + +// Package yaml is an implementation of YAML functions that always fail. +// +// This implementation can be used at build time to replace the default implementation +// to avoid linking with [gopkg.in/yaml.v3]: +// +// go test -tags testify_yaml_fail +package yaml + +import "errors" + +var errNotImplemented = errors.New("YAML functions are not available (see https://pkg.go.dev/github.com/stretchr/testify/assert/yaml)") + +func Unmarshal([]byte, interface{}) error { + return errNotImplemented +} diff --git a/vendor/github.com/stretchr/testify/require/require.go b/vendor/github.com/stretchr/testify/require/require.go index 506a82f807..d8921950d7 100644 --- a/vendor/github.com/stretchr/testify/require/require.go +++ b/vendor/github.com/stretchr/testify/require/require.go @@ -34,9 +34,9 @@ func Conditionf(t TestingT, comp assert.Comparison, msg string, args ...interfac // Contains asserts that the specified string, list(array, slice...) or map contains the // specified substring or element. // -// assert.Contains(t, "Hello World", "World") -// assert.Contains(t, ["Hello", "World"], "World") -// assert.Contains(t, {"Hello": "World"}, "Hello") +// require.Contains(t, "Hello World", "World") +// require.Contains(t, ["Hello", "World"], "World") +// require.Contains(t, {"Hello": "World"}, "Hello") func Contains(t TestingT, s interface{}, contains interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -50,9 +50,9 @@ func Contains(t TestingT, s interface{}, contains interface{}, msgAndArgs ...int // Containsf asserts that the specified string, list(array, slice...) or map contains the // specified substring or element. // -// assert.Containsf(t, "Hello World", "World", "error message %s", "formatted") -// assert.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted") -// assert.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted") +// require.Containsf(t, "Hello World", "World", "error message %s", "formatted") +// require.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted") +// require.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted") func Containsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -91,7 +91,7 @@ func DirExistsf(t TestingT, path string, msg string, args ...interface{}) { // listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, // the number of appearances of each of them in both lists should match. // -// assert.ElementsMatch(t, [1, 3, 2, 3], [1, 3, 3, 2]) +// require.ElementsMatch(t, [1, 3, 2, 3], [1, 3, 3, 2]) func ElementsMatch(t TestingT, listA interface{}, listB interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -106,7 +106,7 @@ func ElementsMatch(t TestingT, listA interface{}, listB interface{}, msgAndArgs // listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, // the number of appearances of each of them in both lists should match. // -// assert.ElementsMatchf(t, [1, 3, 2, 3], [1, 3, 3, 2], "error message %s", "formatted") +// require.ElementsMatchf(t, [1, 3, 2, 3], [1, 3, 3, 2], "error message %s", "formatted") func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -120,7 +120,7 @@ func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string // Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either // a slice or a channel with len == 0. // -// assert.Empty(t, obj) +// require.Empty(t, obj) func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -134,7 +134,7 @@ func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) { // Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either // a slice or a channel with len == 0. // -// assert.Emptyf(t, obj, "error message %s", "formatted") +// require.Emptyf(t, obj, "error message %s", "formatted") func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -147,7 +147,7 @@ func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) { // Equal asserts that two objects are equal. // -// assert.Equal(t, 123, 123) +// require.Equal(t, 123, 123) // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). Function equality @@ -166,7 +166,7 @@ func Equal(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...i // and that it is equal to the provided error. // // actualObj, err := SomeFunction() -// assert.EqualError(t, err, expectedErrorString) +// require.EqualError(t, err, expectedErrorString) func EqualError(t TestingT, theError error, errString string, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -181,7 +181,7 @@ func EqualError(t TestingT, theError error, errString string, msgAndArgs ...inte // and that it is equal to the provided error. // // actualObj, err := SomeFunction() -// assert.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted") +// require.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted") func EqualErrorf(t TestingT, theError error, errString string, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -200,8 +200,8 @@ func EqualErrorf(t TestingT, theError error, errString string, msg string, args // Exported int // notExported int // } -// assert.EqualExportedValues(t, S{1, 2}, S{1, 3}) => true -// assert.EqualExportedValues(t, S{1, 2}, S{2, 3}) => false +// require.EqualExportedValues(t, S{1, 2}, S{1, 3}) => true +// require.EqualExportedValues(t, S{1, 2}, S{2, 3}) => false func EqualExportedValues(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -220,8 +220,8 @@ func EqualExportedValues(t TestingT, expected interface{}, actual interface{}, m // Exported int // notExported int // } -// assert.EqualExportedValuesf(t, S{1, 2}, S{1, 3}, "error message %s", "formatted") => true -// assert.EqualExportedValuesf(t, S{1, 2}, S{2, 3}, "error message %s", "formatted") => false +// require.EqualExportedValuesf(t, S{1, 2}, S{1, 3}, "error message %s", "formatted") => true +// require.EqualExportedValuesf(t, S{1, 2}, S{2, 3}, "error message %s", "formatted") => false func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -232,10 +232,10 @@ func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{}, t.FailNow() } -// EqualValues asserts that two objects are equal or convertible to the same types -// and equal. +// EqualValues asserts that two objects are equal or convertible to the larger +// type and equal. // -// assert.EqualValues(t, uint32(123), int32(123)) +// require.EqualValues(t, uint32(123), int32(123)) func EqualValues(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -246,10 +246,10 @@ func EqualValues(t TestingT, expected interface{}, actual interface{}, msgAndArg t.FailNow() } -// EqualValuesf asserts that two objects are equal or convertible to the same types -// and equal. +// EqualValuesf asserts that two objects are equal or convertible to the larger +// type and equal. // -// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted") +// require.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted") func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -262,7 +262,7 @@ func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg stri // Equalf asserts that two objects are equal. // -// assert.Equalf(t, 123, 123, "error message %s", "formatted") +// require.Equalf(t, 123, 123, "error message %s", "formatted") // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). Function equality @@ -280,8 +280,8 @@ func Equalf(t TestingT, expected interface{}, actual interface{}, msg string, ar // Error asserts that a function returned an error (i.e. not `nil`). // // actualObj, err := SomeFunction() -// if assert.Error(t, err) { -// assert.Equal(t, expectedError, err) +// if require.Error(t, err) { +// require.Equal(t, expectedError, err) // } func Error(t TestingT, err error, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { @@ -321,7 +321,7 @@ func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...int // and that the error contains the specified substring. // // actualObj, err := SomeFunction() -// assert.ErrorContains(t, err, expectedErrorSubString) +// require.ErrorContains(t, err, expectedErrorSubString) func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -336,7 +336,7 @@ func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...in // and that the error contains the specified substring. // // actualObj, err := SomeFunction() -// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted") +// require.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted") func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -374,8 +374,8 @@ func ErrorIsf(t TestingT, err error, target error, msg string, args ...interface // Errorf asserts that a function returned an error (i.e. not `nil`). // // actualObj, err := SomeFunction() -// if assert.Errorf(t, err, "error message %s", "formatted") { -// assert.Equal(t, expectedErrorf, err) +// if require.Errorf(t, err, "error message %s", "formatted") { +// require.Equal(t, expectedErrorf, err) // } func Errorf(t TestingT, err error, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { @@ -390,7 +390,7 @@ func Errorf(t TestingT, err error, msg string, args ...interface{}) { // Eventually asserts that given condition will be met in waitFor time, // periodically checking target function each tick. // -// assert.Eventually(t, func() bool { return true; }, time.Second, 10*time.Millisecond) +// require.Eventually(t, func() bool { return true; }, time.Second, 10*time.Millisecond) func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -415,10 +415,10 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t // time.Sleep(8*time.Second) // externalValue = true // }() -// assert.EventuallyWithT(t, func(c *assert.CollectT) { +// require.EventuallyWithT(t, func(c *require.CollectT) { // // add assertions as needed; any assertion failure will fail the current tick -// assert.True(c, externalValue, "expected 'externalValue' to be true") -// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +// require.True(c, externalValue, "expected 'externalValue' to be true") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") func EventuallyWithT(t TestingT, condition func(collect *assert.CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -443,10 +443,10 @@ func EventuallyWithT(t TestingT, condition func(collect *assert.CollectT), waitF // time.Sleep(8*time.Second) // externalValue = true // }() -// assert.EventuallyWithTf(t, func(c *assert.CollectT, "error message %s", "formatted") { +// require.EventuallyWithTf(t, func(c *require.CollectT, "error message %s", "formatted") { // // add assertions as needed; any assertion failure will fail the current tick -// assert.True(c, externalValue, "expected 'externalValue' to be true") -// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +// require.True(c, externalValue, "expected 'externalValue' to be true") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") func EventuallyWithTf(t TestingT, condition func(collect *assert.CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -460,7 +460,7 @@ func EventuallyWithTf(t TestingT, condition func(collect *assert.CollectT), wait // Eventuallyf asserts that given condition will be met in waitFor time, // periodically checking target function each tick. // -// assert.Eventuallyf(t, func() bool { return true; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") +// require.Eventuallyf(t, func() bool { return true; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -473,7 +473,7 @@ func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick // Exactly asserts that two objects are equal in value and type. // -// assert.Exactly(t, int32(123), int64(123)) +// require.Exactly(t, int32(123), int64(123)) func Exactly(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -486,7 +486,7 @@ func Exactly(t TestingT, expected interface{}, actual interface{}, msgAndArgs .. // Exactlyf asserts that two objects are equal in value and type. // -// assert.Exactlyf(t, int32(123), int64(123), "error message %s", "formatted") +// require.Exactlyf(t, int32(123), int64(123), "error message %s", "formatted") func Exactlyf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -543,7 +543,7 @@ func Failf(t TestingT, failureMessage string, msg string, args ...interface{}) { // False asserts that the specified value is false. // -// assert.False(t, myBool) +// require.False(t, myBool) func False(t TestingT, value bool, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -556,7 +556,7 @@ func False(t TestingT, value bool, msgAndArgs ...interface{}) { // Falsef asserts that the specified value is false. // -// assert.Falsef(t, myBool, "error message %s", "formatted") +// require.Falsef(t, myBool, "error message %s", "formatted") func Falsef(t TestingT, value bool, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -593,9 +593,9 @@ func FileExistsf(t TestingT, path string, msg string, args ...interface{}) { // Greater asserts that the first element is greater than the second // -// assert.Greater(t, 2, 1) -// assert.Greater(t, float64(2), float64(1)) -// assert.Greater(t, "b", "a") +// require.Greater(t, 2, 1) +// require.Greater(t, float64(2), float64(1)) +// require.Greater(t, "b", "a") func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -608,10 +608,10 @@ func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface // GreaterOrEqual asserts that the first element is greater than or equal to the second // -// assert.GreaterOrEqual(t, 2, 1) -// assert.GreaterOrEqual(t, 2, 2) -// assert.GreaterOrEqual(t, "b", "a") -// assert.GreaterOrEqual(t, "b", "b") +// require.GreaterOrEqual(t, 2, 1) +// require.GreaterOrEqual(t, 2, 2) +// require.GreaterOrEqual(t, "b", "a") +// require.GreaterOrEqual(t, "b", "b") func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -624,10 +624,10 @@ func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...in // GreaterOrEqualf asserts that the first element is greater than or equal to the second // -// assert.GreaterOrEqualf(t, 2, 1, "error message %s", "formatted") -// assert.GreaterOrEqualf(t, 2, 2, "error message %s", "formatted") -// assert.GreaterOrEqualf(t, "b", "a", "error message %s", "formatted") -// assert.GreaterOrEqualf(t, "b", "b", "error message %s", "formatted") +// require.GreaterOrEqualf(t, 2, 1, "error message %s", "formatted") +// require.GreaterOrEqualf(t, 2, 2, "error message %s", "formatted") +// require.GreaterOrEqualf(t, "b", "a", "error message %s", "formatted") +// require.GreaterOrEqualf(t, "b", "b", "error message %s", "formatted") func GreaterOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -640,9 +640,9 @@ func GreaterOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, arg // Greaterf asserts that the first element is greater than the second // -// assert.Greaterf(t, 2, 1, "error message %s", "formatted") -// assert.Greaterf(t, float64(2), float64(1), "error message %s", "formatted") -// assert.Greaterf(t, "b", "a", "error message %s", "formatted") +// require.Greaterf(t, 2, 1, "error message %s", "formatted") +// require.Greaterf(t, float64(2), float64(1), "error message %s", "formatted") +// require.Greaterf(t, "b", "a", "error message %s", "formatted") func Greaterf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -656,7 +656,7 @@ func Greaterf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...in // HTTPBodyContains asserts that a specified handler returns a // body that contains a string. // -// assert.HTTPBodyContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") +// require.HTTPBodyContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") // // Returns whether the assertion was successful (true) or not (false). func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) { @@ -672,7 +672,7 @@ func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method string, url s // HTTPBodyContainsf asserts that a specified handler returns a // body that contains a string. // -// assert.HTTPBodyContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") +// require.HTTPBodyContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func HTTPBodyContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) { @@ -688,7 +688,7 @@ func HTTPBodyContainsf(t TestingT, handler http.HandlerFunc, method string, url // HTTPBodyNotContains asserts that a specified handler returns a // body that does not contain a string. // -// assert.HTTPBodyNotContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") +// require.HTTPBodyNotContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") // // Returns whether the assertion was successful (true) or not (false). func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) { @@ -704,7 +704,7 @@ func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method string, ur // HTTPBodyNotContainsf asserts that a specified handler returns a // body that does not contain a string. // -// assert.HTTPBodyNotContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") +// require.HTTPBodyNotContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func HTTPBodyNotContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) { @@ -719,7 +719,7 @@ func HTTPBodyNotContainsf(t TestingT, handler http.HandlerFunc, method string, u // HTTPError asserts that a specified handler returns an error status code. // -// assert.HTTPError(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// require.HTTPError(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). func HTTPError(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) { @@ -734,7 +734,7 @@ func HTTPError(t TestingT, handler http.HandlerFunc, method string, url string, // HTTPErrorf asserts that a specified handler returns an error status code. // -// assert.HTTPErrorf(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// require.HTTPErrorf(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). func HTTPErrorf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) { @@ -749,7 +749,7 @@ func HTTPErrorf(t TestingT, handler http.HandlerFunc, method string, url string, // HTTPRedirect asserts that a specified handler returns a redirect status code. // -// assert.HTTPRedirect(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// require.HTTPRedirect(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). func HTTPRedirect(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) { @@ -764,7 +764,7 @@ func HTTPRedirect(t TestingT, handler http.HandlerFunc, method string, url strin // HTTPRedirectf asserts that a specified handler returns a redirect status code. // -// assert.HTTPRedirectf(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// require.HTTPRedirectf(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). func HTTPRedirectf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) { @@ -779,7 +779,7 @@ func HTTPRedirectf(t TestingT, handler http.HandlerFunc, method string, url stri // HTTPStatusCode asserts that a specified handler returns a specified status code. // -// assert.HTTPStatusCode(t, myHandler, "GET", "/notImplemented", nil, 501) +// require.HTTPStatusCode(t, myHandler, "GET", "/notImplemented", nil, 501) // // Returns whether the assertion was successful (true) or not (false). func HTTPStatusCode(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, statuscode int, msgAndArgs ...interface{}) { @@ -794,7 +794,7 @@ func HTTPStatusCode(t TestingT, handler http.HandlerFunc, method string, url str // HTTPStatusCodef asserts that a specified handler returns a specified status code. // -// assert.HTTPStatusCodef(t, myHandler, "GET", "/notImplemented", nil, 501, "error message %s", "formatted") +// require.HTTPStatusCodef(t, myHandler, "GET", "/notImplemented", nil, 501, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func HTTPStatusCodef(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, statuscode int, msg string, args ...interface{}) { @@ -809,7 +809,7 @@ func HTTPStatusCodef(t TestingT, handler http.HandlerFunc, method string, url st // HTTPSuccess asserts that a specified handler returns a success status code. // -// assert.HTTPSuccess(t, myHandler, "POST", "http://www.google.com", nil) +// require.HTTPSuccess(t, myHandler, "POST", "http://www.google.com", nil) // // Returns whether the assertion was successful (true) or not (false). func HTTPSuccess(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) { @@ -824,7 +824,7 @@ func HTTPSuccess(t TestingT, handler http.HandlerFunc, method string, url string // HTTPSuccessf asserts that a specified handler returns a success status code. // -// assert.HTTPSuccessf(t, myHandler, "POST", "http://www.google.com", nil, "error message %s", "formatted") +// require.HTTPSuccessf(t, myHandler, "POST", "http://www.google.com", nil, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func HTTPSuccessf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) { @@ -839,7 +839,7 @@ func HTTPSuccessf(t TestingT, handler http.HandlerFunc, method string, url strin // Implements asserts that an object is implemented by the specified interface. // -// assert.Implements(t, (*MyInterface)(nil), new(MyObject)) +// require.Implements(t, (*MyInterface)(nil), new(MyObject)) func Implements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -852,7 +852,7 @@ func Implements(t TestingT, interfaceObject interface{}, object interface{}, msg // Implementsf asserts that an object is implemented by the specified interface. // -// assert.Implementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted") +// require.Implementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted") func Implementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -865,7 +865,7 @@ func Implementsf(t TestingT, interfaceObject interface{}, object interface{}, ms // InDelta asserts that the two numerals are within delta of each other. // -// assert.InDelta(t, math.Pi, 22/7.0, 0.01) +// require.InDelta(t, math.Pi, 22/7.0, 0.01) func InDelta(t TestingT, expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -922,7 +922,7 @@ func InDeltaSlicef(t TestingT, expected interface{}, actual interface{}, delta f // InDeltaf asserts that the two numerals are within delta of each other. // -// assert.InDeltaf(t, math.Pi, 22/7.0, 0.01, "error message %s", "formatted") +// require.InDeltaf(t, math.Pi, 22/7.0, 0.01, "error message %s", "formatted") func InDeltaf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -979,9 +979,9 @@ func InEpsilonf(t TestingT, expected interface{}, actual interface{}, epsilon fl // IsDecreasing asserts that the collection is decreasing // -// assert.IsDecreasing(t, []int{2, 1, 0}) -// assert.IsDecreasing(t, []float{2, 1}) -// assert.IsDecreasing(t, []string{"b", "a"}) +// require.IsDecreasing(t, []int{2, 1, 0}) +// require.IsDecreasing(t, []float{2, 1}) +// require.IsDecreasing(t, []string{"b", "a"}) func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -994,9 +994,9 @@ func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) { // IsDecreasingf asserts that the collection is decreasing // -// assert.IsDecreasingf(t, []int{2, 1, 0}, "error message %s", "formatted") -// assert.IsDecreasingf(t, []float{2, 1}, "error message %s", "formatted") -// assert.IsDecreasingf(t, []string{"b", "a"}, "error message %s", "formatted") +// require.IsDecreasingf(t, []int{2, 1, 0}, "error message %s", "formatted") +// require.IsDecreasingf(t, []float{2, 1}, "error message %s", "formatted") +// require.IsDecreasingf(t, []string{"b", "a"}, "error message %s", "formatted") func IsDecreasingf(t TestingT, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1009,9 +1009,9 @@ func IsDecreasingf(t TestingT, object interface{}, msg string, args ...interface // IsIncreasing asserts that the collection is increasing // -// assert.IsIncreasing(t, []int{1, 2, 3}) -// assert.IsIncreasing(t, []float{1, 2}) -// assert.IsIncreasing(t, []string{"a", "b"}) +// require.IsIncreasing(t, []int{1, 2, 3}) +// require.IsIncreasing(t, []float{1, 2}) +// require.IsIncreasing(t, []string{"a", "b"}) func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1024,9 +1024,9 @@ func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) { // IsIncreasingf asserts that the collection is increasing // -// assert.IsIncreasingf(t, []int{1, 2, 3}, "error message %s", "formatted") -// assert.IsIncreasingf(t, []float{1, 2}, "error message %s", "formatted") -// assert.IsIncreasingf(t, []string{"a", "b"}, "error message %s", "formatted") +// require.IsIncreasingf(t, []int{1, 2, 3}, "error message %s", "formatted") +// require.IsIncreasingf(t, []float{1, 2}, "error message %s", "formatted") +// require.IsIncreasingf(t, []string{"a", "b"}, "error message %s", "formatted") func IsIncreasingf(t TestingT, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1039,9 +1039,9 @@ func IsIncreasingf(t TestingT, object interface{}, msg string, args ...interface // IsNonDecreasing asserts that the collection is not decreasing // -// assert.IsNonDecreasing(t, []int{1, 1, 2}) -// assert.IsNonDecreasing(t, []float{1, 2}) -// assert.IsNonDecreasing(t, []string{"a", "b"}) +// require.IsNonDecreasing(t, []int{1, 1, 2}) +// require.IsNonDecreasing(t, []float{1, 2}) +// require.IsNonDecreasing(t, []string{"a", "b"}) func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1054,9 +1054,9 @@ func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) // IsNonDecreasingf asserts that the collection is not decreasing // -// assert.IsNonDecreasingf(t, []int{1, 1, 2}, "error message %s", "formatted") -// assert.IsNonDecreasingf(t, []float{1, 2}, "error message %s", "formatted") -// assert.IsNonDecreasingf(t, []string{"a", "b"}, "error message %s", "formatted") +// require.IsNonDecreasingf(t, []int{1, 1, 2}, "error message %s", "formatted") +// require.IsNonDecreasingf(t, []float{1, 2}, "error message %s", "formatted") +// require.IsNonDecreasingf(t, []string{"a", "b"}, "error message %s", "formatted") func IsNonDecreasingf(t TestingT, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1069,9 +1069,9 @@ func IsNonDecreasingf(t TestingT, object interface{}, msg string, args ...interf // IsNonIncreasing asserts that the collection is not increasing // -// assert.IsNonIncreasing(t, []int{2, 1, 1}) -// assert.IsNonIncreasing(t, []float{2, 1}) -// assert.IsNonIncreasing(t, []string{"b", "a"}) +// require.IsNonIncreasing(t, []int{2, 1, 1}) +// require.IsNonIncreasing(t, []float{2, 1}) +// require.IsNonIncreasing(t, []string{"b", "a"}) func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1084,9 +1084,9 @@ func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) // IsNonIncreasingf asserts that the collection is not increasing // -// assert.IsNonIncreasingf(t, []int{2, 1, 1}, "error message %s", "formatted") -// assert.IsNonIncreasingf(t, []float{2, 1}, "error message %s", "formatted") -// assert.IsNonIncreasingf(t, []string{"b", "a"}, "error message %s", "formatted") +// require.IsNonIncreasingf(t, []int{2, 1, 1}, "error message %s", "formatted") +// require.IsNonIncreasingf(t, []float{2, 1}, "error message %s", "formatted") +// require.IsNonIncreasingf(t, []string{"b", "a"}, "error message %s", "formatted") func IsNonIncreasingf(t TestingT, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1121,7 +1121,7 @@ func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg strin // JSONEq asserts that two JSON strings are equivalent. // -// assert.JSONEq(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`) +// require.JSONEq(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`) func JSONEq(t TestingT, expected string, actual string, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1134,7 +1134,7 @@ func JSONEq(t TestingT, expected string, actual string, msgAndArgs ...interface{ // JSONEqf asserts that two JSON strings are equivalent. // -// assert.JSONEqf(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted") +// require.JSONEqf(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted") func JSONEqf(t TestingT, expected string, actual string, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1148,7 +1148,7 @@ func JSONEqf(t TestingT, expected string, actual string, msg string, args ...int // Len asserts that the specified object has specific length. // Len also fails if the object has a type that len() not accept. // -// assert.Len(t, mySlice, 3) +// require.Len(t, mySlice, 3) func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1162,7 +1162,7 @@ func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{}) // Lenf asserts that the specified object has specific length. // Lenf also fails if the object has a type that len() not accept. // -// assert.Lenf(t, mySlice, 3, "error message %s", "formatted") +// require.Lenf(t, mySlice, 3, "error message %s", "formatted") func Lenf(t TestingT, object interface{}, length int, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1175,9 +1175,9 @@ func Lenf(t TestingT, object interface{}, length int, msg string, args ...interf // Less asserts that the first element is less than the second // -// assert.Less(t, 1, 2) -// assert.Less(t, float64(1), float64(2)) -// assert.Less(t, "a", "b") +// require.Less(t, 1, 2) +// require.Less(t, float64(1), float64(2)) +// require.Less(t, "a", "b") func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1190,10 +1190,10 @@ func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) // LessOrEqual asserts that the first element is less than or equal to the second // -// assert.LessOrEqual(t, 1, 2) -// assert.LessOrEqual(t, 2, 2) -// assert.LessOrEqual(t, "a", "b") -// assert.LessOrEqual(t, "b", "b") +// require.LessOrEqual(t, 1, 2) +// require.LessOrEqual(t, 2, 2) +// require.LessOrEqual(t, "a", "b") +// require.LessOrEqual(t, "b", "b") func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1206,10 +1206,10 @@ func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...inter // LessOrEqualf asserts that the first element is less than or equal to the second // -// assert.LessOrEqualf(t, 1, 2, "error message %s", "formatted") -// assert.LessOrEqualf(t, 2, 2, "error message %s", "formatted") -// assert.LessOrEqualf(t, "a", "b", "error message %s", "formatted") -// assert.LessOrEqualf(t, "b", "b", "error message %s", "formatted") +// require.LessOrEqualf(t, 1, 2, "error message %s", "formatted") +// require.LessOrEqualf(t, 2, 2, "error message %s", "formatted") +// require.LessOrEqualf(t, "a", "b", "error message %s", "formatted") +// require.LessOrEqualf(t, "b", "b", "error message %s", "formatted") func LessOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1222,9 +1222,9 @@ func LessOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args . // Lessf asserts that the first element is less than the second // -// assert.Lessf(t, 1, 2, "error message %s", "formatted") -// assert.Lessf(t, float64(1), float64(2), "error message %s", "formatted") -// assert.Lessf(t, "a", "b", "error message %s", "formatted") +// require.Lessf(t, 1, 2, "error message %s", "formatted") +// require.Lessf(t, float64(1), float64(2), "error message %s", "formatted") +// require.Lessf(t, "a", "b", "error message %s", "formatted") func Lessf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1237,8 +1237,8 @@ func Lessf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...inter // Negative asserts that the specified element is negative // -// assert.Negative(t, -1) -// assert.Negative(t, -1.23) +// require.Negative(t, -1) +// require.Negative(t, -1.23) func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1251,8 +1251,8 @@ func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) { // Negativef asserts that the specified element is negative // -// assert.Negativef(t, -1, "error message %s", "formatted") -// assert.Negativef(t, -1.23, "error message %s", "formatted") +// require.Negativef(t, -1, "error message %s", "formatted") +// require.Negativef(t, -1.23, "error message %s", "formatted") func Negativef(t TestingT, e interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1266,7 +1266,7 @@ func Negativef(t TestingT, e interface{}, msg string, args ...interface{}) { // Never asserts that the given condition doesn't satisfy in waitFor time, // periodically checking the target function each tick. // -// assert.Never(t, func() bool { return false; }, time.Second, 10*time.Millisecond) +// require.Never(t, func() bool { return false; }, time.Second, 10*time.Millisecond) func Never(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1280,7 +1280,7 @@ func Never(t TestingT, condition func() bool, waitFor time.Duration, tick time.D // Neverf asserts that the given condition doesn't satisfy in waitFor time, // periodically checking the target function each tick. // -// assert.Neverf(t, func() bool { return false; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") +// require.Neverf(t, func() bool { return false; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") func Neverf(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1293,7 +1293,7 @@ func Neverf(t TestingT, condition func() bool, waitFor time.Duration, tick time. // Nil asserts that the specified object is nil. // -// assert.Nil(t, err) +// require.Nil(t, err) func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1306,7 +1306,7 @@ func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) { // Nilf asserts that the specified object is nil. // -// assert.Nilf(t, err, "error message %s", "formatted") +// require.Nilf(t, err, "error message %s", "formatted") func Nilf(t TestingT, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1344,8 +1344,8 @@ func NoDirExistsf(t TestingT, path string, msg string, args ...interface{}) { // NoError asserts that a function returned no error (i.e. `nil`). // // actualObj, err := SomeFunction() -// if assert.NoError(t, err) { -// assert.Equal(t, expectedObj, actualObj) +// if require.NoError(t, err) { +// require.Equal(t, expectedObj, actualObj) // } func NoError(t TestingT, err error, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { @@ -1360,8 +1360,8 @@ func NoError(t TestingT, err error, msgAndArgs ...interface{}) { // NoErrorf asserts that a function returned no error (i.e. `nil`). // // actualObj, err := SomeFunction() -// if assert.NoErrorf(t, err, "error message %s", "formatted") { -// assert.Equal(t, expectedObj, actualObj) +// if require.NoErrorf(t, err, "error message %s", "formatted") { +// require.Equal(t, expectedObj, actualObj) // } func NoErrorf(t TestingT, err error, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { @@ -1400,9 +1400,9 @@ func NoFileExistsf(t TestingT, path string, msg string, args ...interface{}) { // NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the // specified substring or element. // -// assert.NotContains(t, "Hello World", "Earth") -// assert.NotContains(t, ["Hello", "World"], "Earth") -// assert.NotContains(t, {"Hello": "World"}, "Earth") +// require.NotContains(t, "Hello World", "Earth") +// require.NotContains(t, ["Hello", "World"], "Earth") +// require.NotContains(t, {"Hello": "World"}, "Earth") func NotContains(t TestingT, s interface{}, contains interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1416,9 +1416,9 @@ func NotContains(t TestingT, s interface{}, contains interface{}, msgAndArgs ... // NotContainsf asserts that the specified string, list(array, slice...) or map does NOT contain the // specified substring or element. // -// assert.NotContainsf(t, "Hello World", "Earth", "error message %s", "formatted") -// assert.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted") -// assert.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted") +// require.NotContainsf(t, "Hello World", "Earth", "error message %s", "formatted") +// require.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted") +// require.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted") func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1429,11 +1429,51 @@ func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, a t.FailNow() } +// NotElementsMatch asserts that the specified listA(array, slice...) is NOT equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should not match. +// This is an inverse of ElementsMatch. +// +// require.NotElementsMatch(t, [1, 1, 2, 3], [1, 1, 2, 3]) -> false +// +// require.NotElementsMatch(t, [1, 1, 2, 3], [1, 2, 3]) -> true +// +// require.NotElementsMatch(t, [1, 2, 3], [1, 2, 4]) -> true +func NotElementsMatch(t TestingT, listA interface{}, listB interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NotElementsMatch(t, listA, listB, msgAndArgs...) { + return + } + t.FailNow() +} + +// NotElementsMatchf asserts that the specified listA(array, slice...) is NOT equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should not match. +// This is an inverse of ElementsMatch. +// +// require.NotElementsMatchf(t, [1, 1, 2, 3], [1, 1, 2, 3], "error message %s", "formatted") -> false +// +// require.NotElementsMatchf(t, [1, 1, 2, 3], [1, 2, 3], "error message %s", "formatted") -> true +// +// require.NotElementsMatchf(t, [1, 2, 3], [1, 2, 4], "error message %s", "formatted") -> true +func NotElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NotElementsMatchf(t, listA, listB, msg, args...) { + return + } + t.FailNow() +} + // NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // a slice or a channel with len == 0. // -// if assert.NotEmpty(t, obj) { -// assert.Equal(t, "two", obj[1]) +// if require.NotEmpty(t, obj) { +// require.Equal(t, "two", obj[1]) // } func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { @@ -1448,8 +1488,8 @@ func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) { // NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // a slice or a channel with len == 0. // -// if assert.NotEmptyf(t, obj, "error message %s", "formatted") { -// assert.Equal(t, "two", obj[1]) +// if require.NotEmptyf(t, obj, "error message %s", "formatted") { +// require.Equal(t, "two", obj[1]) // } func NotEmptyf(t TestingT, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { @@ -1463,7 +1503,7 @@ func NotEmptyf(t TestingT, object interface{}, msg string, args ...interface{}) // NotEqual asserts that the specified values are NOT equal. // -// assert.NotEqual(t, obj1, obj2) +// require.NotEqual(t, obj1, obj2) // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). @@ -1479,7 +1519,7 @@ func NotEqual(t TestingT, expected interface{}, actual interface{}, msgAndArgs . // NotEqualValues asserts that two objects are not equal even when converted to the same type // -// assert.NotEqualValues(t, obj1, obj2) +// require.NotEqualValues(t, obj1, obj2) func NotEqualValues(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1492,7 +1532,7 @@ func NotEqualValues(t TestingT, expected interface{}, actual interface{}, msgAnd // NotEqualValuesf asserts that two objects are not equal even when converted to the same type // -// assert.NotEqualValuesf(t, obj1, obj2, "error message %s", "formatted") +// require.NotEqualValuesf(t, obj1, obj2, "error message %s", "formatted") func NotEqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1505,7 +1545,7 @@ func NotEqualValuesf(t TestingT, expected interface{}, actual interface{}, msg s // NotEqualf asserts that the specified values are NOT equal. // -// assert.NotEqualf(t, obj1, obj2, "error message %s", "formatted") +// require.NotEqualf(t, obj1, obj2, "error message %s", "formatted") // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). @@ -1519,7 +1559,31 @@ func NotEqualf(t TestingT, expected interface{}, actual interface{}, msg string, t.FailNow() } -// NotErrorIs asserts that at none of the errors in err's chain matches target. +// NotErrorAs asserts that none of the errors in err's chain matches target, +// but if so, sets target to that error value. +func NotErrorAs(t TestingT, err error, target interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NotErrorAs(t, err, target, msgAndArgs...) { + return + } + t.FailNow() +} + +// NotErrorAsf asserts that none of the errors in err's chain matches target, +// but if so, sets target to that error value. +func NotErrorAsf(t TestingT, err error, target interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NotErrorAsf(t, err, target, msg, args...) { + return + } + t.FailNow() +} + +// NotErrorIs asserts that none of the errors in err's chain matches target. // This is a wrapper for errors.Is. func NotErrorIs(t TestingT, err error, target error, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { @@ -1531,7 +1595,7 @@ func NotErrorIs(t TestingT, err error, target error, msgAndArgs ...interface{}) t.FailNow() } -// NotErrorIsf asserts that at none of the errors in err's chain matches target. +// NotErrorIsf asserts that none of the errors in err's chain matches target. // This is a wrapper for errors.Is. func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { @@ -1545,7 +1609,7 @@ func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interf // NotImplements asserts that an object does not implement the specified interface. // -// assert.NotImplements(t, (*MyInterface)(nil), new(MyObject)) +// require.NotImplements(t, (*MyInterface)(nil), new(MyObject)) func NotImplements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1558,7 +1622,7 @@ func NotImplements(t TestingT, interfaceObject interface{}, object interface{}, // NotImplementsf asserts that an object does not implement the specified interface. // -// assert.NotImplementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted") +// require.NotImplementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted") func NotImplementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1571,7 +1635,7 @@ func NotImplementsf(t TestingT, interfaceObject interface{}, object interface{}, // NotNil asserts that the specified object is not nil. // -// assert.NotNil(t, err) +// require.NotNil(t, err) func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1584,7 +1648,7 @@ func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) { // NotNilf asserts that the specified object is not nil. // -// assert.NotNilf(t, err, "error message %s", "formatted") +// require.NotNilf(t, err, "error message %s", "formatted") func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1597,7 +1661,7 @@ func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) { // NotPanics asserts that the code inside the specified PanicTestFunc does NOT panic. // -// assert.NotPanics(t, func(){ RemainCalm() }) +// require.NotPanics(t, func(){ RemainCalm() }) func NotPanics(t TestingT, f assert.PanicTestFunc, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1610,7 +1674,7 @@ func NotPanics(t TestingT, f assert.PanicTestFunc, msgAndArgs ...interface{}) { // NotPanicsf asserts that the code inside the specified PanicTestFunc does NOT panic. // -// assert.NotPanicsf(t, func(){ RemainCalm() }, "error message %s", "formatted") +// require.NotPanicsf(t, func(){ RemainCalm() }, "error message %s", "formatted") func NotPanicsf(t TestingT, f assert.PanicTestFunc, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1623,8 +1687,8 @@ func NotPanicsf(t TestingT, f assert.PanicTestFunc, msg string, args ...interfac // NotRegexp asserts that a specified regexp does not match a string. // -// assert.NotRegexp(t, regexp.MustCompile("starts"), "it's starting") -// assert.NotRegexp(t, "^start", "it's not starting") +// require.NotRegexp(t, regexp.MustCompile("starts"), "it's starting") +// require.NotRegexp(t, "^start", "it's not starting") func NotRegexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1637,8 +1701,8 @@ func NotRegexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interf // NotRegexpf asserts that a specified regexp does not match a string. // -// assert.NotRegexpf(t, regexp.MustCompile("starts"), "it's starting", "error message %s", "formatted") -// assert.NotRegexpf(t, "^start", "it's not starting", "error message %s", "formatted") +// require.NotRegexpf(t, regexp.MustCompile("starts"), "it's starting", "error message %s", "formatted") +// require.NotRegexpf(t, "^start", "it's not starting", "error message %s", "formatted") func NotRegexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1651,7 +1715,7 @@ func NotRegexpf(t TestingT, rx interface{}, str interface{}, msg string, args .. // NotSame asserts that two pointers do not reference the same object. // -// assert.NotSame(t, ptr1, ptr2) +// require.NotSame(t, ptr1, ptr2) // // Both arguments must be pointer variables. Pointer variable sameness is // determined based on the equality of both type and value. @@ -1667,7 +1731,7 @@ func NotSame(t TestingT, expected interface{}, actual interface{}, msgAndArgs .. // NotSamef asserts that two pointers do not reference the same object. // -// assert.NotSamef(t, ptr1, ptr2, "error message %s", "formatted") +// require.NotSamef(t, ptr1, ptr2, "error message %s", "formatted") // // Both arguments must be pointer variables. Pointer variable sameness is // determined based on the equality of both type and value. @@ -1685,8 +1749,8 @@ func NotSamef(t TestingT, expected interface{}, actual interface{}, msg string, // contain all elements given in the specified subset list(array, slice...) or // map. // -// assert.NotSubset(t, [1, 3, 4], [1, 2]) -// assert.NotSubset(t, {"x": 1, "y": 2}, {"z": 3}) +// require.NotSubset(t, [1, 3, 4], [1, 2]) +// require.NotSubset(t, {"x": 1, "y": 2}, {"z": 3}) func NotSubset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1701,8 +1765,8 @@ func NotSubset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...i // contain all elements given in the specified subset list(array, slice...) or // map. // -// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "error message %s", "formatted") -// assert.NotSubsetf(t, {"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted") +// require.NotSubsetf(t, [1, 3, 4], [1, 2], "error message %s", "formatted") +// require.NotSubsetf(t, {"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted") func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1737,7 +1801,7 @@ func NotZerof(t TestingT, i interface{}, msg string, args ...interface{}) { // Panics asserts that the code inside the specified PanicTestFunc panics. // -// assert.Panics(t, func(){ GoCrazy() }) +// require.Panics(t, func(){ GoCrazy() }) func Panics(t TestingT, f assert.PanicTestFunc, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1752,7 +1816,7 @@ func Panics(t TestingT, f assert.PanicTestFunc, msgAndArgs ...interface{}) { // panics, and that the recovered panic value is an error that satisfies the // EqualError comparison. // -// assert.PanicsWithError(t, "crazy error", func(){ GoCrazy() }) +// require.PanicsWithError(t, "crazy error", func(){ GoCrazy() }) func PanicsWithError(t TestingT, errString string, f assert.PanicTestFunc, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1767,7 +1831,7 @@ func PanicsWithError(t TestingT, errString string, f assert.PanicTestFunc, msgAn // panics, and that the recovered panic value is an error that satisfies the // EqualError comparison. // -// assert.PanicsWithErrorf(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted") +// require.PanicsWithErrorf(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted") func PanicsWithErrorf(t TestingT, errString string, f assert.PanicTestFunc, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1781,7 +1845,7 @@ func PanicsWithErrorf(t TestingT, errString string, f assert.PanicTestFunc, msg // PanicsWithValue asserts that the code inside the specified PanicTestFunc panics, and that // the recovered panic value equals the expected panic value. // -// assert.PanicsWithValue(t, "crazy error", func(){ GoCrazy() }) +// require.PanicsWithValue(t, "crazy error", func(){ GoCrazy() }) func PanicsWithValue(t TestingT, expected interface{}, f assert.PanicTestFunc, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1795,7 +1859,7 @@ func PanicsWithValue(t TestingT, expected interface{}, f assert.PanicTestFunc, m // PanicsWithValuef asserts that the code inside the specified PanicTestFunc panics, and that // the recovered panic value equals the expected panic value. // -// assert.PanicsWithValuef(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted") +// require.PanicsWithValuef(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted") func PanicsWithValuef(t TestingT, expected interface{}, f assert.PanicTestFunc, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1808,7 +1872,7 @@ func PanicsWithValuef(t TestingT, expected interface{}, f assert.PanicTestFunc, // Panicsf asserts that the code inside the specified PanicTestFunc panics. // -// assert.Panicsf(t, func(){ GoCrazy() }, "error message %s", "formatted") +// require.Panicsf(t, func(){ GoCrazy() }, "error message %s", "formatted") func Panicsf(t TestingT, f assert.PanicTestFunc, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1821,8 +1885,8 @@ func Panicsf(t TestingT, f assert.PanicTestFunc, msg string, args ...interface{} // Positive asserts that the specified element is positive // -// assert.Positive(t, 1) -// assert.Positive(t, 1.23) +// require.Positive(t, 1) +// require.Positive(t, 1.23) func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1835,8 +1899,8 @@ func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) { // Positivef asserts that the specified element is positive // -// assert.Positivef(t, 1, "error message %s", "formatted") -// assert.Positivef(t, 1.23, "error message %s", "formatted") +// require.Positivef(t, 1, "error message %s", "formatted") +// require.Positivef(t, 1.23, "error message %s", "formatted") func Positivef(t TestingT, e interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1849,8 +1913,8 @@ func Positivef(t TestingT, e interface{}, msg string, args ...interface{}) { // Regexp asserts that a specified regexp matches a string. // -// assert.Regexp(t, regexp.MustCompile("start"), "it's starting") -// assert.Regexp(t, "start...$", "it's not starting") +// require.Regexp(t, regexp.MustCompile("start"), "it's starting") +// require.Regexp(t, "start...$", "it's not starting") func Regexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1863,8 +1927,8 @@ func Regexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface // Regexpf asserts that a specified regexp matches a string. // -// assert.Regexpf(t, regexp.MustCompile("start"), "it's starting", "error message %s", "formatted") -// assert.Regexpf(t, "start...$", "it's not starting", "error message %s", "formatted") +// require.Regexpf(t, regexp.MustCompile("start"), "it's starting", "error message %s", "formatted") +// require.Regexpf(t, "start...$", "it's not starting", "error message %s", "formatted") func Regexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1877,7 +1941,7 @@ func Regexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...in // Same asserts that two pointers reference the same object. // -// assert.Same(t, ptr1, ptr2) +// require.Same(t, ptr1, ptr2) // // Both arguments must be pointer variables. Pointer variable sameness is // determined based on the equality of both type and value. @@ -1893,7 +1957,7 @@ func Same(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...in // Samef asserts that two pointers reference the same object. // -// assert.Samef(t, ptr1, ptr2, "error message %s", "formatted") +// require.Samef(t, ptr1, ptr2, "error message %s", "formatted") // // Both arguments must be pointer variables. Pointer variable sameness is // determined based on the equality of both type and value. @@ -1910,8 +1974,8 @@ func Samef(t TestingT, expected interface{}, actual interface{}, msg string, arg // Subset asserts that the specified list(array, slice...) or map contains all // elements given in the specified subset list(array, slice...) or map. // -// assert.Subset(t, [1, 2, 3], [1, 2]) -// assert.Subset(t, {"x": 1, "y": 2}, {"x": 1}) +// require.Subset(t, [1, 2, 3], [1, 2]) +// require.Subset(t, {"x": 1, "y": 2}, {"x": 1}) func Subset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1925,8 +1989,8 @@ func Subset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...inte // Subsetf asserts that the specified list(array, slice...) or map contains all // elements given in the specified subset list(array, slice...) or map. // -// assert.Subsetf(t, [1, 2, 3], [1, 2], "error message %s", "formatted") -// assert.Subsetf(t, {"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted") +// require.Subsetf(t, [1, 2, 3], [1, 2], "error message %s", "formatted") +// require.Subsetf(t, {"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted") func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1939,7 +2003,7 @@ func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args // True asserts that the specified value is true. // -// assert.True(t, myBool) +// require.True(t, myBool) func True(t TestingT, value bool, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1952,7 +2016,7 @@ func True(t TestingT, value bool, msgAndArgs ...interface{}) { // Truef asserts that the specified value is true. // -// assert.Truef(t, myBool, "error message %s", "formatted") +// require.Truef(t, myBool, "error message %s", "formatted") func Truef(t TestingT, value bool, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1965,7 +2029,7 @@ func Truef(t TestingT, value bool, msg string, args ...interface{}) { // WithinDuration asserts that the two times are within duration delta of each other. // -// assert.WithinDuration(t, time.Now(), time.Now(), 10*time.Second) +// require.WithinDuration(t, time.Now(), time.Now(), 10*time.Second) func WithinDuration(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1978,7 +2042,7 @@ func WithinDuration(t TestingT, expected time.Time, actual time.Time, delta time // WithinDurationf asserts that the two times are within duration delta of each other. // -// assert.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted") +// require.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted") func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1991,7 +2055,7 @@ func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta tim // WithinRange asserts that a time is within a time range (inclusive). // -// assert.WithinRange(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) +// require.WithinRange(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) func WithinRange(t TestingT, actual time.Time, start time.Time, end time.Time, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -2004,7 +2068,7 @@ func WithinRange(t TestingT, actual time.Time, start time.Time, end time.Time, m // WithinRangef asserts that a time is within a time range (inclusive). // -// assert.WithinRangef(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted") +// require.WithinRangef(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted") func WithinRangef(t TestingT, actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() diff --git a/vendor/github.com/stretchr/testify/require/require.go.tmpl b/vendor/github.com/stretchr/testify/require/require.go.tmpl index 55e42ddebd..8b32836850 100644 --- a/vendor/github.com/stretchr/testify/require/require.go.tmpl +++ b/vendor/github.com/stretchr/testify/require/require.go.tmpl @@ -1,4 +1,4 @@ -{{.Comment}} +{{ replace .Comment "assert." "require."}} func {{.DocInfo.Name}}(t TestingT, {{.Params}}) { if h, ok := t.(tHelper); ok { h.Helper() } if assert.{{.DocInfo.Name}}(t, {{.ForwardedParams}}) { return } diff --git a/vendor/github.com/stretchr/testify/require/require_forward.go b/vendor/github.com/stretchr/testify/require/require_forward.go index eee8310a5f..1bd87304f4 100644 --- a/vendor/github.com/stretchr/testify/require/require_forward.go +++ b/vendor/github.com/stretchr/testify/require/require_forward.go @@ -187,8 +187,8 @@ func (a *Assertions) EqualExportedValuesf(expected interface{}, actual interface EqualExportedValuesf(a.t, expected, actual, msg, args...) } -// EqualValues asserts that two objects are equal or convertible to the same types -// and equal. +// EqualValues asserts that two objects are equal or convertible to the larger +// type and equal. // // a.EqualValues(uint32(123), int32(123)) func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) { @@ -198,8 +198,8 @@ func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAn EqualValues(a.t, expected, actual, msgAndArgs...) } -// EqualValuesf asserts that two objects are equal or convertible to the same types -// and equal. +// EqualValuesf asserts that two objects are equal or convertible to the larger +// type and equal. // // a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted") func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) { @@ -337,7 +337,7 @@ func (a *Assertions) Eventually(condition func() bool, waitFor time.Duration, ti // a.EventuallyWithT(func(c *assert.CollectT) { // // add assertions as needed; any assertion failure will fail the current tick // assert.True(c, externalValue, "expected 'externalValue' to be true") -// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") func (a *Assertions) EventuallyWithT(condition func(collect *assert.CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -362,7 +362,7 @@ func (a *Assertions) EventuallyWithT(condition func(collect *assert.CollectT), w // a.EventuallyWithTf(func(c *assert.CollectT, "error message %s", "formatted") { // // add assertions as needed; any assertion failure will fail the current tick // assert.True(c, externalValue, "expected 'externalValue' to be true") -// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") func (a *Assertions) EventuallyWithTf(condition func(collect *assert.CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1129,6 +1129,40 @@ func (a *Assertions) NotContainsf(s interface{}, contains interface{}, msg strin NotContainsf(a.t, s, contains, msg, args...) } +// NotElementsMatch asserts that the specified listA(array, slice...) is NOT equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should not match. +// This is an inverse of ElementsMatch. +// +// a.NotElementsMatch([1, 1, 2, 3], [1, 1, 2, 3]) -> false +// +// a.NotElementsMatch([1, 1, 2, 3], [1, 2, 3]) -> true +// +// a.NotElementsMatch([1, 2, 3], [1, 2, 4]) -> true +func (a *Assertions) NotElementsMatch(listA interface{}, listB interface{}, msgAndArgs ...interface{}) { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + NotElementsMatch(a.t, listA, listB, msgAndArgs...) +} + +// NotElementsMatchf asserts that the specified listA(array, slice...) is NOT equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should not match. +// This is an inverse of ElementsMatch. +// +// a.NotElementsMatchf([1, 1, 2, 3], [1, 1, 2, 3], "error message %s", "formatted") -> false +// +// a.NotElementsMatchf([1, 1, 2, 3], [1, 2, 3], "error message %s", "formatted") -> true +// +// a.NotElementsMatchf([1, 2, 3], [1, 2, 4], "error message %s", "formatted") -> true +func (a *Assertions) NotElementsMatchf(listA interface{}, listB interface{}, msg string, args ...interface{}) { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + NotElementsMatchf(a.t, listA, listB, msg, args...) +} + // NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either // a slice or a channel with len == 0. // @@ -1201,7 +1235,25 @@ func (a *Assertions) NotEqualf(expected interface{}, actual interface{}, msg str NotEqualf(a.t, expected, actual, msg, args...) } -// NotErrorIs asserts that at none of the errors in err's chain matches target. +// NotErrorAs asserts that none of the errors in err's chain matches target, +// but if so, sets target to that error value. +func (a *Assertions) NotErrorAs(err error, target interface{}, msgAndArgs ...interface{}) { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + NotErrorAs(a.t, err, target, msgAndArgs...) +} + +// NotErrorAsf asserts that none of the errors in err's chain matches target, +// but if so, sets target to that error value. +func (a *Assertions) NotErrorAsf(err error, target interface{}, msg string, args ...interface{}) { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + NotErrorAsf(a.t, err, target, msg, args...) +} + +// NotErrorIs asserts that none of the errors in err's chain matches target. // This is a wrapper for errors.Is. func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface{}) { if h, ok := a.t.(tHelper); ok { @@ -1210,7 +1262,7 @@ func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface NotErrorIs(a.t, err, target, msgAndArgs...) } -// NotErrorIsf asserts that at none of the errors in err's chain matches target. +// NotErrorIsf asserts that none of the errors in err's chain matches target. // This is a wrapper for errors.Is. func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...interface{}) { if h, ok := a.t.(tHelper); ok { diff --git a/vendor/github.com/stretchr/testify/require/requirements.go b/vendor/github.com/stretchr/testify/require/requirements.go index 91772dfeb9..6b7ce929eb 100644 --- a/vendor/github.com/stretchr/testify/require/requirements.go +++ b/vendor/github.com/stretchr/testify/require/requirements.go @@ -6,7 +6,7 @@ type TestingT interface { FailNow() } -type tHelper interface { +type tHelper = interface { Helper() } diff --git a/vendor/github.com/stretchr/testify/suite/doc.go b/vendor/github.com/stretchr/testify/suite/doc.go index 8d55a3aa89..05a562f721 100644 --- a/vendor/github.com/stretchr/testify/suite/doc.go +++ b/vendor/github.com/stretchr/testify/suite/doc.go @@ -5,6 +5,8 @@ // or individual tests (depending on which interface(s) you // implement). // +// The suite package does not support parallel tests. See [issue 934]. +// // A testing suite is usually built by first extending the built-in // suite functionality from suite.Suite in testify. Alternatively, // you could reproduce that logic on your own if you wanted (you @@ -63,4 +65,6 @@ // func TestExampleTestSuite(t *testing.T) { // suite.Run(t, new(ExampleTestSuite)) // } +// +// [issue 934]: https://github.com/stretchr/testify/issues/934 package suite diff --git a/vendor/modules.txt b/vendor/modules.txt index f8700bbe01..d0b2b635d8 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -37,8 +37,8 @@ github.com/Microsoft/go-winio/internal/stringbuffer github.com/Microsoft/go-winio/pkg/guid # github.com/Netflix/go-expect v0.0.0-20201125194554-85d881c3777e ## explicit; go 1.13 -# github.com/ProtonMail/go-crypto v1.0.0 -## explicit; go 1.13 +# github.com/ProtonMail/go-crypto v1.1.3 +## explicit; go 1.17 github.com/ProtonMail/go-crypto/bitcurves github.com/ProtonMail/go-crypto/brainpool github.com/ProtonMail/go-crypto/eax @@ -49,6 +49,8 @@ github.com/ProtonMail/go-crypto/openpgp/aes/keywrap github.com/ProtonMail/go-crypto/openpgp/armor github.com/ProtonMail/go-crypto/openpgp/ecdh github.com/ProtonMail/go-crypto/openpgp/ecdsa +github.com/ProtonMail/go-crypto/openpgp/ed25519 +github.com/ProtonMail/go-crypto/openpgp/ed448 github.com/ProtonMail/go-crypto/openpgp/eddsa github.com/ProtonMail/go-crypto/openpgp/elgamal github.com/ProtonMail/go-crypto/openpgp/errors @@ -57,6 +59,8 @@ github.com/ProtonMail/go-crypto/openpgp/internal/ecc github.com/ProtonMail/go-crypto/openpgp/internal/encoding github.com/ProtonMail/go-crypto/openpgp/packet github.com/ProtonMail/go-crypto/openpgp/s2k +github.com/ProtonMail/go-crypto/openpgp/x25519 +github.com/ProtonMail/go-crypto/openpgp/x448 # github.com/PuerkitoBio/purell v1.1.1 ## explicit github.com/PuerkitoBio/purell @@ -187,8 +191,8 @@ github.com/containerd/console # github.com/creack/pty v1.1.11 ## explicit; go 1.13 github.com/creack/pty -# github.com/cyphar/filepath-securejoin v0.2.4 -## explicit; go 1.13 +# github.com/cyphar/filepath-securejoin v0.3.6 +## explicit; go 1.18 github.com/cyphar/filepath-securejoin # github.com/dave/jennifer v0.18.0 ## explicit @@ -236,16 +240,16 @@ github.com/go-git/gcfg github.com/go-git/gcfg/scanner github.com/go-git/gcfg/token github.com/go-git/gcfg/types -# github.com/go-git/go-billy/v5 v5.5.0 -## explicit; go 1.19 +# github.com/go-git/go-billy/v5 v5.6.1 +## explicit; go 1.21 github.com/go-git/go-billy/v5 github.com/go-git/go-billy/v5/helper/chroot github.com/go-git/go-billy/v5/helper/polyfill github.com/go-git/go-billy/v5/memfs github.com/go-git/go-billy/v5/osfs github.com/go-git/go-billy/v5/util -# github.com/go-git/go-git/v5 v5.12.0 -## explicit; go 1.19 +# github.com/go-git/go-git/v5 v5.13.1 +## explicit; go 1.21 github.com/go-git/go-git/v5 github.com/go-git/go-git/v5/config github.com/go-git/go-git/v5/internal/path_util @@ -588,7 +592,7 @@ github.com/shirou/gopsutil/v3/process # github.com/shoenig/go-m1cpu v0.1.6 ## explicit; go 1.20 github.com/shoenig/go-m1cpu -# github.com/skeema/knownhosts v1.2.2 +# github.com/skeema/knownhosts v1.3.0 ## explicit; go 1.17 github.com/skeema/knownhosts # github.com/skratchdot/open-golang v0.0.0-20190104022628-a2dfa6d0dab6 @@ -608,9 +612,10 @@ github.com/spf13/cobra github.com/spf13/pflag # github.com/stretchr/objx v0.5.2 ## explicit; go 1.20 -# github.com/stretchr/testify v1.9.0 +# github.com/stretchr/testify v1.10.0 ## explicit; go 1.17 github.com/stretchr/testify/assert +github.com/stretchr/testify/assert/yaml github.com/stretchr/testify/require github.com/stretchr/testify/suite # github.com/thoas/go-funk v0.8.0