diff --git a/.gitattributes b/.gitattributes index adc4144ffa3..13825940056 100644 --- a/.gitattributes +++ b/.gitattributes @@ -3,3 +3,7 @@ go.sum linguist-generated text gnovm/stdlibs/generated.go linguist-generated gnovm/tests/stdlibs/generated.go linguist-generated +*.gen.gno linguist-generated +*.gen_test.gno linguist-generated +*.gen.go linguist-generated +*.gen_test.go linguist-generated \ No newline at end of file diff --git a/.github/golangci.yml b/.github/golangci.yml index ca85620b7e6..afc581d2ec5 100644 --- a/.github/golangci.yml +++ b/.github/golangci.yml @@ -85,6 +85,7 @@ issues: - gosec # Disabled linting of weak number generators - makezero # Disabled linting of intentional slice appends - goconst # Disabled linting of common mnemonics and test case strings + - unused # Disabled linting of unused mock methods - path: _\.gno linters: - errorlint # Disabled linting of error comparisons, because of lacking std lib support diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index e933db9ea19..5d606a2a663 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -8,6 +8,7 @@ on: paths: - gnovm/**/*.gno - examples/**/*.gno + - examples/**/gno.mod concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -70,12 +71,20 @@ jobs: - run: make lint -C ./examples # TODO: consider running lint on every other directories, maybe in "warning" mode? # TODO: track coverage + fmt: name: Run gno fmt on examples uses: ./.github/workflows/gnofmt_template.yml with: path: "examples/..." 
+ generate: + name: Check generated files are up to date + uses: ./.github/workflows/build_template.yml + with: + modulepath: "examples" + go-version: "1.22.x" + mod-tidy: strategy: fail-fast: false diff --git a/contribs/gnodev/go.mod b/contribs/gnodev/go.mod index 6ca47408a75..92d8494fa40 100644 --- a/contribs/gnodev/go.mod +++ b/contribs/gnodev/go.mod @@ -77,10 +77,12 @@ require ( github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/rs/cors v1.11.1 // indirect github.com/rs/xid v1.6.0 // indirect + github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b // indirect github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yuin/goldmark v1.7.2 // indirect github.com/yuin/goldmark-emoji v1.0.2 // indirect + github.com/yuin/goldmark-highlighting/v2 v2.0.0-20230729083705-37449abec8cc // indirect github.com/zondax/hid v0.9.2 // indirect github.com/zondax/ledger-go v0.14.3 // indirect go.etcd.io/bbolt v1.3.11 // indirect diff --git a/contribs/gnodev/go.sum b/contribs/gnodev/go.sum index 912345d61a8..3f22e4f2f00 100644 --- a/contribs/gnodev/go.sum +++ b/contribs/gnodev/go.sum @@ -3,8 +3,10 @@ dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE= github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= +github.com/alecthomas/chroma/v2 v2.2.0/go.mod h1:vf4zrexSH54oEjJ7EdB65tGNHmH3pGZmVkgTP5RHvAs= github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46aU4V9E= github.com/alecthomas/chroma/v2 v2.14.0/go.mod h1:QolEbTfmUHIMVpBqxeDnNBj2uoeI4EbYP4i6n68SG4I= +github.com/alecthomas/repr v0.0.0-20220113201626-b1b626ac65ae/go.mod h1:2kn6fqh/zIyPLmm3ugklbEi5hg5wS435eygvNfaDQL8= github.com/alecthomas/repr 
v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= @@ -91,6 +93,8 @@ github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1/go.mod h1:hyedUtir6IdtD/7lIxGeC github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 h1:rpfIENRNNilwHwZeG5+P150SMrnNEcHYvcCuK6dPZSg= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= github.com/decred/dcrd/lru v1.0.0/go.mod h1:mxKOwFd7lFjN2GZYsiz/ecgqR6kkYAl+0pz0tEMk218= +github.com/dlclark/regexp2 v1.4.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc= +github.com/dlclark/regexp2 v1.7.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= @@ -209,6 +213,8 @@ github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA= github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= +github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b h1:oV47z+jotrLVvhiLRNzACVe7/qZ8DcRlMlDucR/FARo= +github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b/go.mod h1:JprPCeMgYyLKJoAy9nxpVScm7NwFSwpibdrUKm4kcw0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -223,10 +229,13 @@ 
github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45 github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yuin/goldmark v1.3.7/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.15/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.7.2 h1:NjGd7lO7zrUn/A7eKwn5PEOt4ONYGqpxSEeZuduvgxc= github.com/yuin/goldmark v1.7.2/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= github.com/yuin/goldmark-emoji v1.0.2 h1:c/RgTShNgHTtc6xdz2KKI74jJr6rWi7FPgnP9GAsO5s= github.com/yuin/goldmark-emoji v1.0.2/go.mod h1:RhP/RWpexdp+KHs7ghKnifRoIs/Bq4nDS7tRbCkOwKY= +github.com/yuin/goldmark-highlighting/v2 v2.0.0-20230729083705-37449abec8cc h1:+IAOyRda+RLrxa1WC7umKOZRsGq4QrFFMYApOeHzQwQ= +github.com/yuin/goldmark-highlighting/v2 v2.0.0-20230729083705-37449abec8cc/go.mod h1:ovIvrum6DQJA4QsJSovrkC4saKHQVs7TvcaeO8AIl5I= github.com/zondax/hid v0.9.2 h1:WCJFnEDMiqGF64nlZz28E9qLVZ0KSJ7xpc5DLEyma2U= github.com/zondax/hid v0.9.2/go.mod h1:l5wttcP0jwtdLjqjMMWFVEE7d1zO0jvSPA9OPZxWpEM= github.com/zondax/ledger-go v0.14.3 h1:wEpJt2CEcBJ428md/5MgSLsXLBos98sBOyxNmCjfUCw= diff --git a/contribs/gnodev/pkg/dev/node.go b/contribs/gnodev/pkg/dev/node.go index fa9e2d11e29..12a88490515 100644 --- a/contribs/gnodev/pkg/dev/node.go +++ b/contribs/gnodev/pkg/dev/node.go @@ -489,6 +489,8 @@ func (n *Node) rebuildNode(ctx context.Context, genesis gnoland.GnoGenesisState) // Speed up stdlib loading after first start (saves about 2-3 seconds on each reload). nodeConfig.CacheStdlibLoad = true nodeConfig.Genesis.ConsensusParams.Block.MaxGas = n.config.MaxGasPerBlock + // Genesis verification is always false with Gnodev + nodeConfig.SkipGenesisVerification = true // recoverFromError handles panics and converts them to errors. 
recoverFromError := func() { diff --git a/contribs/gnofaucet/go.mod b/contribs/gnofaucet/go.mod index 3abc189b86a..3d1e5f54c54 100644 --- a/contribs/gnofaucet/go.mod +++ b/contribs/gnofaucet/go.mod @@ -31,6 +31,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rs/cors v1.11.1 // indirect github.com/rs/xid v1.6.0 // indirect + github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b // indirect go.opentelemetry.io/otel v1.29.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.29.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.29.0 // indirect diff --git a/contribs/gnofaucet/go.sum b/contribs/gnofaucet/go.sum index fafbc4d1060..10e2c19b408 100644 --- a/contribs/gnofaucet/go.sum +++ b/contribs/gnofaucet/go.sum @@ -111,6 +111,8 @@ github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b h1:oV47z+jotrLVvhiLRNzACVe7/qZ8DcRlMlDucR/FARo= +github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b/go.mod h1:JprPCeMgYyLKJoAy9nxpVScm7NwFSwpibdrUKm4kcw0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/contribs/gnogenesis/go.mod b/contribs/gnogenesis/go.mod index f1b316c2bee..3056af1d4cc 100644 --- a/contribs/gnogenesis/go.mod +++ b/contribs/gnogenesis/go.mod @@ -32,6 +32,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rs/cors v1.11.1 // indirect github.com/rs/xid v1.6.0 // indirect + 
github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b // indirect github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 // indirect github.com/zondax/hid v0.9.2 // indirect github.com/zondax/ledger-go v0.14.3 // indirect @@ -49,6 +50,7 @@ require ( golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 // indirect golang.org/x/mod v0.20.0 // indirect golang.org/x/net v0.28.0 // indirect + golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.24.0 // indirect golang.org/x/term v0.23.0 // indirect golang.org/x/text v0.17.0 // indirect diff --git a/contribs/gnogenesis/go.sum b/contribs/gnogenesis/go.sum index 7ba3aede534..7e4a683cad1 100644 --- a/contribs/gnogenesis/go.sum +++ b/contribs/gnogenesis/go.sum @@ -120,6 +120,8 @@ github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b h1:oV47z+jotrLVvhiLRNzACVe7/qZ8DcRlMlDucR/FARo= +github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b/go.mod h1:JprPCeMgYyLKJoAy9nxpVScm7NwFSwpibdrUKm4kcw0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/contribs/gnogenesis/internal/txs/txs_add_packages.go b/contribs/gnogenesis/internal/txs/txs_add_packages.go index cf863c72116..0ab5724154e 100644 --- a/contribs/gnogenesis/internal/txs/txs_add_packages.go +++ b/contribs/gnogenesis/internal/txs/txs_add_packages.go @@ -5,8 +5,9 @@ import ( "errors" "flag" "fmt" + "os" - "github.com/gnolang/gno/tm2/pkg/crypto" + "github.com/gnolang/gno/tm2/pkg/crypto/keys" 
"github.com/gnolang/gno/gno.land/pkg/gnoland" "github.com/gnolang/gno/gno.land/pkg/gnoland/ugnot" @@ -15,28 +16,45 @@ import ( "github.com/gnolang/gno/tm2/pkg/std" ) -var ( - errInvalidPackageDir = errors.New("invalid package directory") - errInvalidDeployerAddr = errors.New("invalid deployer address") +const ( + defaultAccount_Name = "test1" + defaultAccount_Address = "g1jg8mtutu9khhfwc4nxmuhcpftf0pajdhfvsqf5" + defaultAccount_Seed = "source bonus chronic canvas draft south burst lottery vacant surface solve popular case indicate oppose farm nothing bullet exhibit title speed wink action roast" + defaultAccount_publicKey = "gpub1pgfj7ard9eg82cjtv4u4xetrwqer2dntxyfzxz3pq0skzdkmzu0r9h6gny6eg8c9dc303xrrudee6z4he4y7cs5rnjwmyf40yaj" ) +var errInvalidPackageDir = errors.New("invalid package directory") + // Keep in sync with gno.land/cmd/start.go -var ( - defaultCreator = crypto.MustAddressFromString("g1jg8mtutu9khhfwc4nxmuhcpftf0pajdhfvsqf5") // test1 - genesisDeployFee = std.NewFee(50000, std.MustParseCoin(ugnot.ValueString(1000000))) -) +var genesisDeployFee = std.NewFee(50000, std.MustParseCoin(ugnot.ValueString(1000000))) type addPkgCfg struct { - txsCfg *txsCfg - deployerAddress string + txsCfg *txsCfg + keyName string + gnoHome string // default GNOHOME env var, just here to ease testing with parallel tests + insecurePasswordStdin bool } func (c *addPkgCfg) RegisterFlags(fs *flag.FlagSet) { fs.StringVar( - &c.deployerAddress, - "deployer-address", - defaultCreator.String(), - "the address that will be used to deploy the package", + &c.keyName, + "key-name", + "", + "The package deployer key name or address contained on gnokey", + ) + + fs.StringVar( + &c.gnoHome, + "gno-home", + os.Getenv("GNOHOME"), + "the gno home directory", + ) + + fs.BoolVar( + &c.insecurePasswordStdin, + "insecure-password-stdin", + false, + "the gno home directory", ) } @@ -65,10 +83,15 @@ func execTxsAddPackages( io commands.IO, args []string, ) error { + var ( + keyname = 
defaultAccount_Name + keybase keys.Keybase + pass string + ) // Load the genesis - genesis, loadErr := types.GenesisDocFromFile(cfg.txsCfg.GenesisPath) - if loadErr != nil { - return fmt.Errorf("unable to load genesis, %w", loadErr) + genesis, err := types.GenesisDocFromFile(cfg.txsCfg.GenesisPath) + if err != nil { + return fmt.Errorf("unable to load genesis, %w", err) } // Make sure the package dir is set @@ -76,19 +99,30 @@ func execTxsAddPackages( return errInvalidPackageDir } - var ( - creator = defaultCreator - err error - ) - - // Check if the deployer address is set - if cfg.deployerAddress != defaultCreator.String() { - creator, err = crypto.AddressFromString(cfg.deployerAddress) + if cfg.keyName != "" { + keyname = cfg.keyName + keybase, err = keys.NewKeyBaseFromDir(cfg.gnoHome) + if err != nil { + return fmt.Errorf("unable to load keybase: %w", err) + } + pass, err = io.GetPassword("Enter password.", cfg.insecurePasswordStdin) + if err != nil { + return fmt.Errorf("cannot read password: %w", err) + } + } else { + keybase = keys.NewInMemory() + _, err := keybase.CreateAccount(defaultAccount_Name, defaultAccount_Seed, "", "", 0, 0) if err != nil { - return fmt.Errorf("%w, %w", errInvalidDeployerAddr, err) + return fmt.Errorf("unable to create account: %w", err) } } + info, err := keybase.GetByNameOrAddress(keyname) + if err != nil { + return fmt.Errorf("unable to find key in keybase: %w", err) + } + + creator := info.GetAddress() parsedTxs := make([]gnoland.TxWithMetadata, 0) for _, path := range args { // Generate transactions from the packages (recursively) @@ -97,6 +131,10 @@ func execTxsAddPackages( return fmt.Errorf("unable to load txs from directory, %w", err) } + if err := signTxs(txs, keybase, genesis.ChainID, keyname, pass); err != nil { + return fmt.Errorf("unable to sign txs, %w", err) + } + parsedTxs = append(parsedTxs, txs...) 
} @@ -117,3 +155,25 @@ func execTxsAddPackages( return nil } + +func signTxs(txs []gnoland.TxWithMetadata, keybase keys.Keybase, chainID, keyname string, password string) error { + for index, tx := range txs { + // Here accountNumber and sequenceNumber are set to 0 because they are considered as 0 on genesis transactions. + signBytes, err := tx.Tx.GetSignBytes(chainID, 0, 0) + if err != nil { + return fmt.Errorf("unable to load txs from directory, %w", err) + } + signature, publicKey, err := keybase.Sign(keyname, password, signBytes) + if err != nil { + return fmt.Errorf("unable sign tx %w", err) + } + txs[index].Tx.Signatures = []std.Signature{ + { + PubKey: publicKey, + Signature: signature, + }, + } + } + + return nil +} diff --git a/contribs/gnogenesis/internal/txs/txs_add_packages_test.go b/contribs/gnogenesis/internal/txs/txs_add_packages_test.go index c3405d6ff8d..38d930401e8 100644 --- a/contribs/gnogenesis/internal/txs/txs_add_packages_test.go +++ b/contribs/gnogenesis/internal/txs/txs_add_packages_test.go @@ -2,9 +2,11 @@ package txs import ( "context" + "encoding/hex" "fmt" "os" "path/filepath" + "strings" "testing" "github.com/gnolang/contribs/gnogenesis/internal/common" @@ -12,6 +14,8 @@ import ( vmm "github.com/gnolang/gno/gno.land/pkg/sdk/vm" "github.com/gnolang/gno/tm2/pkg/bft/types" "github.com/gnolang/gno/tm2/pkg/commands" + "github.com/gnolang/gno/tm2/pkg/crypto/keys" + "github.com/gnolang/gno/tm2/pkg/crypto/keys/client" "github.com/gnolang/gno/tm2/pkg/testutils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -19,6 +23,7 @@ import ( func TestGenesis_Txs_Add_Packages(t *testing.T) { t.Parallel() + const addPkgExpectedSignature = "cfe5a15d8def04cbdaf9d08e2511db7928152b26419c4577cbfa282c83118852411f3de5d045ce934555572c21bda8042ce5c64b793a01748e49cf2cff7c2983" t.Run("invalid genesis file", func(t *testing.T) { t.Parallel() @@ -60,8 +65,10 @@ func TestGenesis_Txs_Add_Packages(t *testing.T) { assert.ErrorContains(t, 
cmdErr, errInvalidPackageDir.Error()) }) - t.Run("invalid deployer address", func(t *testing.T) { + t.Run("non existent key", func(t *testing.T) { t.Parallel() + keybaseDir := t.TempDir() + keyname := "beep-boop" tempGenesis, cleanup := testutils.NewTestFile(t) t.Cleanup(cleanup) @@ -69,24 +76,36 @@ func TestGenesis_Txs_Add_Packages(t *testing.T) { genesis := common.GetDefaultGenesis() require.NoError(t, genesis.SaveAs(tempGenesis.Name())) + io := commands.NewTestIO() + io.SetIn( + strings.NewReader( + fmt.Sprintf( + "%s\n", + "password", + ), + ), + ) // Create the command - cmd := NewTxsCmd(commands.NewTestIO()) + cmd := NewTxsCmd(io) args := []string{ "add", "packages", "--genesis-path", tempGenesis.Name(), t.TempDir(), // package dir - "--deployer-address", - "beep-boop", // invalid address + "--key-name", + keyname, // non-existent key name + "--gno-home", + keybaseDir, // temporaryDir for keybase + "--insecure-password-stdin", } // Run the command cmdErr := cmd.ParseAndRun(context.Background(), args) - assert.ErrorIs(t, cmdErr, errInvalidDeployerAddr) + assert.ErrorContains(t, cmdErr, "Key "+keyname+" not found") }) - t.Run("valid package", func(t *testing.T) { + t.Run("existent key wrong password", func(t *testing.T) { t.Parallel() tempGenesis, cleanup := testutils.NewTestFile(t) @@ -94,32 +113,189 @@ func TestGenesis_Txs_Add_Packages(t *testing.T) { genesis := common.GetDefaultGenesis() require.NoError(t, genesis.SaveAs(tempGenesis.Name())) + // Prepare the package + var ( + packagePath = "gno.land/p/demo/cuttlas" + dir = t.TempDir() + keybaseDir = t.TempDir() + keyname = "beep-boop" + password = "somepass" + ) + createValidFile(t, dir, packagePath) + // Create key + kb, err := keys.NewKeyBaseFromDir(keybaseDir) + require.NoError(t, err) + mnemonic, err := client.GenerateMnemonic(256) + require.NoError(t, err) + _, err = kb.CreateAccount(keyname, mnemonic, "", password+"wrong", 0, 0) + require.NoError(t, err) + + io := commands.NewTestIO() + io.SetIn( + 
strings.NewReader( + fmt.Sprintf( + "%s\n", + password, + ), + ), + ) + + // Create the command + cmd := NewTxsCmd(io) + args := []string{ + "add", + "packages", + "--genesis-path", + tempGenesis.Name(), + "--key-name", + keyname, // non-existent key name + "--gno-home", + keybaseDir, // temporaryDir for keybase + "--insecure-password-stdin", + dir, + } + + // Run the command + cmdErr := cmd.ParseAndRun(context.Background(), args) + assert.ErrorContains(t, cmdErr, "unable to sign txs") + }) + + t.Run("existent key correct password", func(t *testing.T) { + t.Parallel() + tempGenesis, cleanup := testutils.NewTestFile(t) + t.Cleanup(cleanup) + + genesis := common.GetDefaultGenesis() + require.NoError(t, genesis.SaveAs(tempGenesis.Name())) // Prepare the package var ( packagePath = "gno.land/p/demo/cuttlas" dir = t.TempDir() + keybaseDir = t.TempDir() + keyname = "beep-boop" + password = "somepass" ) + createValidFile(t, dir, packagePath) + // Create key + kb, err := keys.NewKeyBaseFromDir(keybaseDir) + require.NoError(t, err) + info, err := kb.CreateAccount(keyname, defaultAccount_Seed, "", password, 0, 0) + require.NoError(t, err) - createFile := func(path, data string) { - file, err := os.Create(path) - require.NoError(t, err) + io := commands.NewTestIO() + io.SetIn( + strings.NewReader( + fmt.Sprintf( + "%s\n", + password, + ), + ), + ) - _, err = file.WriteString(data) - require.NoError(t, err) + // Create the command + cmd := NewTxsCmd(io) + args := []string{ + "add", + "packages", + "--genesis-path", + tempGenesis.Name(), + "--key-name", + keyname, // non-existent key name + "--gno-home", + keybaseDir, // temporaryDir for keybase + "--insecure-password-stdin", + dir, } - // Create the gno.mod file - createFile( - filepath.Join(dir, "gno.mod"), - fmt.Sprintf("module %s\n", packagePath), + // Run the command + cmdErr := cmd.ParseAndRun(context.Background(), args) + require.NoError(t, cmdErr) + + // Validate the transactions were written down + updatedGenesis, err 
:= types.GenesisDocFromFile(tempGenesis.Name()) + require.NoError(t, err) + require.NotNil(t, updatedGenesis.AppState) + + // Fetch the state + state := updatedGenesis.AppState.(gnoland.GnoGenesisState) + + require.Equal(t, 1, len(state.Txs)) + require.Equal(t, 1, len(state.Txs[0].Tx.Msgs)) + + msgAddPkg, ok := state.Txs[0].Tx.Msgs[0].(vmm.MsgAddPackage) + require.True(t, ok) + require.Equal(t, info.GetPubKey(), state.Txs[0].Tx.Signatures[0].PubKey) + require.Equal(t, addPkgExpectedSignature, hex.EncodeToString(state.Txs[0].Tx.Signatures[0].Signature)) + + assert.Equal(t, packagePath, msgAddPkg.Package.Path) + }) + + t.Run("ok default key", func(t *testing.T) { + t.Parallel() + + tempGenesis, cleanup := testutils.NewTestFile(t) + t.Cleanup(cleanup) + + genesis := common.GetDefaultGenesis() + require.NoError(t, genesis.SaveAs(tempGenesis.Name())) + // Prepare the package + var ( + packagePath = "gno.land/p/demo/cuttlas" + dir = t.TempDir() + keybaseDir = t.TempDir() ) + createValidFile(t, dir, packagePath) + + // Create the command + cmd := NewTxsCmd(commands.NewTestIO()) + args := []string{ + "add", + "packages", + "--genesis-path", + tempGenesis.Name(), + "--gno-home", + keybaseDir, // temporaryDir for keybase + dir, + } + + // Run the command + cmdErr := cmd.ParseAndRun(context.Background(), args) + require.NoError(t, cmdErr) + + // Validate the transactions were written down + updatedGenesis, err := types.GenesisDocFromFile(tempGenesis.Name()) + require.NoError(t, err) + require.NotNil(t, updatedGenesis.AppState) - // Create a simple main.gno - createFile( - filepath.Join(dir, "main.gno"), - "package cuttlas\n\nfunc Example() string {\nreturn \"Manos arriba!\"\n}", + // Fetch the state + state := updatedGenesis.AppState.(gnoland.GnoGenesisState) + + require.Equal(t, 1, len(state.Txs)) + require.Equal(t, 1, len(state.Txs[0].Tx.Msgs)) + + msgAddPkg, ok := state.Txs[0].Tx.Msgs[0].(vmm.MsgAddPackage) + require.True(t, ok) + require.Equal(t, defaultAccount_publicKey, 
state.Txs[0].Tx.Signatures[0].PubKey.String()) + require.Equal(t, addPkgExpectedSignature, hex.EncodeToString(state.Txs[0].Tx.Signatures[0].Signature)) + + assert.Equal(t, packagePath, msgAddPkg.Package.Path) + }) + + t.Run("valid package", func(t *testing.T) { + t.Parallel() + + tempGenesis, cleanup := testutils.NewTestFile(t) + t.Cleanup(cleanup) + + genesis := common.GetDefaultGenesis() + require.NoError(t, genesis.SaveAs(tempGenesis.Name())) + // Prepare the package + var ( + packagePath = "gno.land/p/demo/cuttlas" + dir = t.TempDir() ) + createValidFile(t, dir, packagePath) // Create the command cmd := NewTxsCmd(commands.NewTestIO()) @@ -148,7 +324,32 @@ func TestGenesis_Txs_Add_Packages(t *testing.T) { msgAddPkg, ok := state.Txs[0].Tx.Msgs[0].(vmm.MsgAddPackage) require.True(t, ok) + require.Equal(t, defaultAccount_publicKey, state.Txs[0].Tx.Signatures[0].PubKey.String()) + require.Equal(t, addPkgExpectedSignature, hex.EncodeToString(state.Txs[0].Tx.Signatures[0].Signature)) assert.Equal(t, packagePath, msgAddPkg.Package.Path) }) } + +func createValidFile(t *testing.T, dir string, packagePath string) { + t.Helper() + createFile := func(path, data string) { + file, err := os.Create(path) + require.NoError(t, err) + + _, err = file.WriteString(data) + require.NoError(t, err) + } + + // Create the gno.mod file + createFile( + filepath.Join(dir, "gno.mod"), + fmt.Sprintf("module %s\n", packagePath), + ) + + // Create a simple main.gno + createFile( + filepath.Join(dir, "main.gno"), + "package cuttlas\n\nfunc Example() string {\nreturn \"Manos arriba!\"\n}", + ) +} diff --git a/contribs/gnohealth/go.mod b/contribs/gnohealth/go.mod index 4f5862a0d2e..4a3f6392804 100644 --- a/contribs/gnohealth/go.mod +++ b/contribs/gnohealth/go.mod @@ -21,6 +21,7 @@ require ( github.com/peterbourgon/ff/v3 v3.4.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rs/xid v1.6.0 // indirect + github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b // 
indirect github.com/stretchr/testify v1.9.0 // indirect go.opentelemetry.io/otel v1.29.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.29.0 // indirect @@ -34,6 +35,7 @@ require ( golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 // indirect golang.org/x/mod v0.20.0 // indirect golang.org/x/net v0.28.0 // indirect + golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.24.0 // indirect golang.org/x/term v0.23.0 // indirect golang.org/x/text v0.17.0 // indirect diff --git a/contribs/gnohealth/go.sum b/contribs/gnohealth/go.sum index dd287d9ca84..02e8893406a 100644 --- a/contribs/gnohealth/go.sum +++ b/contribs/gnohealth/go.sum @@ -103,6 +103,8 @@ github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b h1:oV47z+jotrLVvhiLRNzACVe7/qZ8DcRlMlDucR/FARo= +github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b/go.mod h1:JprPCeMgYyLKJoAy9nxpVScm7NwFSwpibdrUKm4kcw0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -149,6 +151,8 @@ golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc/go.mod h1:/O7V0waA8r7cgGh81R golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod 
h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/contribs/gnokeykc/go.mod b/contribs/gnokeykc/go.mod index 479daed22f6..157b5585828 100644 --- a/contribs/gnokeykc/go.mod +++ b/contribs/gnokeykc/go.mod @@ -35,6 +35,7 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rs/xid v1.6.0 // indirect + github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b // indirect github.com/stretchr/testify v1.9.0 // indirect github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 // indirect github.com/zondax/hid v0.9.2 // indirect @@ -52,6 +53,7 @@ require ( golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 // indirect golang.org/x/mod v0.20.0 // indirect golang.org/x/net v0.28.0 // indirect + golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.24.0 // indirect golang.org/x/term v0.23.0 // indirect golang.org/x/text v0.17.0 // indirect diff --git a/contribs/gnokeykc/go.sum b/contribs/gnokeykc/go.sum index cacf6788d45..7aac05b84a0 100644 --- a/contribs/gnokeykc/go.sum +++ b/contribs/gnokeykc/go.sum @@ -124,6 +124,8 @@ github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b h1:oV47z+jotrLVvhiLRNzACVe7/qZ8DcRlMlDucR/FARo= +github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b/go.mod 
h1:JprPCeMgYyLKJoAy9nxpVScm7NwFSwpibdrUKm4kcw0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -180,6 +182,8 @@ golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc/go.mod h1:/O7V0waA8r7cgGh81R golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/contribs/gnomd/go.mod b/contribs/gnomd/go.mod index 423e4414a79..57c07621324 100644 --- a/contribs/gnomd/go.mod +++ b/contribs/gnomd/go.mod @@ -21,7 +21,7 @@ require ( github.com/mattn/go-isatty v0.0.11 // indirect github.com/mattn/go-runewidth v0.0.12 // indirect github.com/rivo/uniseg v0.1.0 // indirect - golang.org/x/image v0.0.0-20191206065243-da761ea9ff43 // indirect + golang.org/x/image v0.18.0 // indirect golang.org/x/net v0.23.0 // indirect golang.org/x/sys v0.18.0 // indirect ) diff --git a/contribs/gnomd/go.sum b/contribs/gnomd/go.sum index 0ff70dd99fb..3d4666530b1 100644 --- a/contribs/gnomd/go.sum +++ b/contribs/gnomd/go.sum @@ -55,8 +55,9 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ golang.org/dl v0.0.0-20190829154251-82a15e2f2ead/go.mod 
h1:IUMfjQLJQd4UTqG1Z90tenwKoCX93Gn3MAQJMOSBsDQ= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.0.0-20191206065243-da761ea9ff43 h1:gQ6GUSD102fPgli+Yb4cR/cGaHF7tNBt+GYoRCpGC7s= golang.org/x/image v0.0.0-20191206065243-da761ea9ff43/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= +golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= diff --git a/contribs/gnomigrate/go.mod b/contribs/gnomigrate/go.mod index cd31adc4f6f..49f40eb79af 100644 --- a/contribs/gnomigrate/go.mod +++ b/contribs/gnomigrate/go.mod @@ -29,6 +29,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rs/cors v1.11.1 // indirect github.com/rs/xid v1.6.0 // indirect + github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b // indirect github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 // indirect go.etcd.io/bbolt v1.3.11 // indirect go.opentelemetry.io/otel v1.29.0 // indirect @@ -44,6 +45,7 @@ require ( golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 // indirect golang.org/x/mod v0.20.0 // indirect golang.org/x/net v0.28.0 // indirect + golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.24.0 // indirect golang.org/x/term v0.23.0 // indirect golang.org/x/text v0.17.0 // indirect diff --git a/contribs/gnomigrate/go.sum b/contribs/gnomigrate/go.sum index 7ba3aede534..7e4a683cad1 100644 --- a/contribs/gnomigrate/go.sum +++ b/contribs/gnomigrate/go.sum @@ -120,6 +120,8 @@ github.com/rs/cors 
v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b h1:oV47z+jotrLVvhiLRNzACVe7/qZ8DcRlMlDucR/FARo= +github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b/go.mod h1:JprPCeMgYyLKJoAy9nxpVScm7NwFSwpibdrUKm4kcw0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/contribs/gnomigrate/internal/txs/txs.go b/contribs/gnomigrate/internal/txs/txs.go index 4c65ca6ef0b..231428d5064 100644 --- a/contribs/gnomigrate/internal/txs/txs.go +++ b/contribs/gnomigrate/internal/txs/txs.go @@ -184,7 +184,7 @@ func processFile(ctx context.Context, io commands.IO, source, destination string continue } - if _, err = outputFile.WriteString(fmt.Sprintf("%s\n", string(marshaledData))); err != nil { + if _, err = fmt.Fprintf(outputFile, "%s\n", marshaledData); err != nil { io.ErrPrintfln("unable to save to output file, %s", err) } } diff --git a/examples/gno.land/p/demo/grc/grc20/token.gno b/examples/gno.land/p/demo/grc/grc20/token.gno index 4634bae933b..4986eaebf04 100644 --- a/examples/gno.land/p/demo/grc/grc20/token.gno +++ b/examples/gno.land/p/demo/grc/grc20/token.gno @@ -1,7 +1,6 @@ package grc20 import ( - "math/overflow" "std" "strconv" @@ -65,6 +64,14 @@ func (tok *Token) RenderHome() string { return str } +// Getter returns a TokenGetter function that returns this token. This allows +// storing indirect pointers to a token in a remote realm. 
+func (tok *Token) Getter() TokenGetter { + return func() *Token { + return tok + } +} + // SpendAllowance decreases the allowance of the specified owner and spender. func (led *PrivateLedger) SpendAllowance(owner, spender std.Address, amount uint64) error { if !owner.IsValid() { return ErrInvalidAddress } @@ -162,17 +169,24 @@ func (led *PrivateLedger) Approve(owner, spender std.Address, amount uint64) err } // Mint increases the total supply of the token and adds the specified amount to the specified address. -func (led *PrivateLedger) Mint(address std.Address, amount uint64) error { +func (led *PrivateLedger) Mint(address std.Address, amount uint64) (err error) { if !address.IsValid() { return ErrInvalidAddress } - // XXX: math/overflow is not supporting uint64. - // This checks prevents overflow but makes the totalSupply limited to a uint63. - sum, ok := overflow.Add64(int64(led.totalSupply), int64(amount)) - if !ok { - return ErrOverflow - } + defer func() { + if r := recover(); r != nil { + if r != "addition overflow" { + panic(r) + } + err = ErrOverflow + } + }() + + // Convert amount and totalSupply to signed integers to enable + // overflow checking (not occurring on unsigned) when computing the sum. + // The maximum value for totalSupply is therefore 1<<63. + sum := int64(led.totalSupply) + int64(amount) led.totalSupply = uint64(sum) currentBalance := led.balanceOf(address) diff --git a/examples/gno.land/p/moul/addrset/addrset.gno b/examples/gno.land/p/moul/addrset/addrset.gno new file mode 100644 index 00000000000..0bb8165f9fe --- /dev/null +++ b/examples/gno.land/p/moul/addrset/addrset.gno @@ -0,0 +1,100 @@ +// Package addrset provides a specialized set data structure for managing unique Gno addresses. +// +// It is built on top of an AVL tree for efficient operations and maintains addresses in sorted order. +// This package is particularly useful when you need to: +// - Track a collection of unique addresses (e.g., for whitelists, participants, etc.)
+// - Efficiently check address membership +// - Support pagination when displaying addresses +// +// Example usage: +// +// import ( +// "std" +// "gno.land/p/moul/addrset" +// ) +// +// func MyHandler() { +// // Create a new address set +// var set addrset.Set +// +// // Add some addresses +// addr1 := std.Address("g1jg8mtutu9khhfwc4nxmuhcpftf0pajdhfvsqf5") +// addr2 := std.Address("g1sss5g0rkqr88k4u648yd5d3l9t4d8vvqwszqth") +// +// set.Add(addr1) // returns true (newly added) +// set.Add(addr2) // returns true (newly added) +// set.Add(addr1) // returns false (already exists) +// +// // Check membership +// if set.Has(addr1) { +// // addr1 is in the set +// } +// +// // Get size +// size := set.Size() // returns 2 +// +// // Iterate with pagination (10 items per page, starting at offset 0) +// set.IterateByOffset(0, 10, func(addr std.Address) bool { +// // Process addr +// return false // continue iteration +// }) +// +// // Remove an address +// set.Remove(addr1) // returns true (was present) +// set.Remove(addr1) // returns false (not present) +// } +package addrset + +import ( + "std" + + "gno.land/p/demo/avl" +) + +type Set struct { + tree avl.Tree +} + +// Add inserts an address into the set. +// Returns true if the address was newly added, false if it already existed. +func (s *Set) Add(addr std.Address) bool { + return !s.tree.Set(string(addr), nil) +} + +// Remove deletes an address from the set. +// Returns true if the address was found and removed, false if it didn't exist. +func (s *Set) Remove(addr std.Address) bool { + _, removed := s.tree.Remove(string(addr)) + return removed +} + +// Has checks if an address exists in the set. +func (s *Set) Has(addr std.Address) bool { + return s.tree.Has(string(addr)) +} + +// Size returns the number of addresses in the set. +func (s *Set) Size() int { + return s.tree.Size() +} + +// IterateByOffset walks through addresses starting at the given offset. +// The callback should return true to stop iteration. 
+func (s *Set) IterateByOffset(offset int, count int, cb func(addr std.Address) bool) { + s.tree.IterateByOffset(offset, count, func(key string, _ interface{}) bool { + return cb(std.Address(key)) + }) +} + +// ReverseIterateByOffset walks through addresses in reverse order starting at the given offset. +// The callback should return true to stop iteration. +func (s *Set) ReverseIterateByOffset(offset int, count int, cb func(addr std.Address) bool) { + s.tree.ReverseIterateByOffset(offset, count, func(key string, _ interface{}) bool { + return cb(std.Address(key)) + }) +} + +// Tree returns the underlying AVL tree for advanced usage. +func (s *Set) Tree() avl.ITree { + return &s.tree +} diff --git a/examples/gno.land/p/moul/addrset/addrset_test.gno b/examples/gno.land/p/moul/addrset/addrset_test.gno new file mode 100644 index 00000000000..c3e27eab1df --- /dev/null +++ b/examples/gno.land/p/moul/addrset/addrset_test.gno @@ -0,0 +1,174 @@ +package addrset + +import ( + "std" + "testing" + + "gno.land/p/demo/uassert" +) + +func TestSet(t *testing.T) { + addr1 := std.Address("addr1") + addr2 := std.Address("addr2") + addr3 := std.Address("addr3") + + tests := []struct { + name string + actions func(s *Set) + size int + has map[std.Address]bool + addrs []std.Address // for iteration checks + }{ + { + name: "empty set", + actions: func(s *Set) {}, + size: 0, + has: map[std.Address]bool{addr1: false}, + }, + { + name: "single address", + actions: func(s *Set) { + s.Add(addr1) + }, + size: 1, + has: map[std.Address]bool{ + addr1: true, + addr2: false, + }, + addrs: []std.Address{addr1}, + }, + { + name: "multiple addresses", + actions: func(s *Set) { + s.Add(addr1) + s.Add(addr2) + s.Add(addr3) + }, + size: 3, + has: map[std.Address]bool{ + addr1: true, + addr2: true, + addr3: true, + }, + addrs: []std.Address{addr1, addr2, addr3}, + }, + { + name: "remove address", + actions: func(s *Set) { + s.Add(addr1) + s.Add(addr2) + s.Remove(addr1) + }, + size: 1, + has: 
map[std.Address]bool{ + addr1: false, + addr2: true, + }, + addrs: []std.Address{addr2}, + }, + { + name: "duplicate adds", + actions: func(s *Set) { + uassert.True(t, s.Add(addr1)) // first add returns true + uassert.False(t, s.Add(addr1)) // second add returns false + uassert.True(t, s.Remove(addr1)) // remove existing returns true + uassert.False(t, s.Remove(addr1)) // remove non-existing returns false + }, + size: 0, + has: map[std.Address]bool{ + addr1: false, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var set Set + + // Execute test actions + tt.actions(&set) + + // Check size + uassert.Equal(t, tt.size, set.Size()) + + // Check existence + for addr, expected := range tt.has { + uassert.Equal(t, expected, set.Has(addr)) + } + + // Check iteration if addresses are specified + if tt.addrs != nil { + collected := []std.Address{} + set.IterateByOffset(0, 10, func(addr std.Address) bool { + collected = append(collected, addr) + return false + }) + + // Check length + uassert.Equal(t, len(tt.addrs), len(collected)) + + // Check each address + for i, addr := range tt.addrs { + uassert.Equal(t, addr, collected[i]) + } + } + }) + } +} + +func TestSetIterationLimits(t *testing.T) { + tests := []struct { + name string + addrs []std.Address + offset int + limit int + expected int + }{ + { + name: "zero offset full list", + addrs: []std.Address{"a1", "a2", "a3"}, + offset: 0, + limit: 10, + expected: 3, + }, + { + name: "offset with limit", + addrs: []std.Address{"a1", "a2", "a3", "a4"}, + offset: 1, + limit: 2, + expected: 2, + }, + { + name: "offset beyond size", + addrs: []std.Address{"a1", "a2"}, + offset: 3, + limit: 1, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var set Set + for _, addr := range tt.addrs { + set.Add(addr) + } + + // Test forward iteration + count := 0 + set.IterateByOffset(tt.offset, tt.limit, func(addr std.Address) bool { + count++ + return false + }) + 
uassert.Equal(t, tt.expected, count) + + // Test reverse iteration + count = 0 + set.ReverseIterateByOffset(tt.offset, tt.limit, func(addr std.Address) bool { + count++ + return false + }) + uassert.Equal(t, tt.expected, count) + }) + } +} diff --git a/examples/gno.land/p/moul/addrset/gno.mod b/examples/gno.land/p/moul/addrset/gno.mod new file mode 100644 index 00000000000..45bb53b399c --- /dev/null +++ b/examples/gno.land/p/moul/addrset/gno.mod @@ -0,0 +1 @@ +module gno.land/p/moul/addrset diff --git a/examples/gno.land/p/moul/fp/fp.gno b/examples/gno.land/p/moul/fp/fp.gno new file mode 100644 index 00000000000..b2811c77d5a --- /dev/null +++ b/examples/gno.land/p/moul/fp/fp.gno @@ -0,0 +1,270 @@ +// Package fp provides functional programming utilities for Gno, enabling +// transformations, filtering, and other operations on slices of interface{}. +// +// Example of chaining operations: +// +// numbers := []interface{}{1, 2, 3, 4, 5, 6} +// +// // Define predicates, mappers and reducers +// isEven := func(v interface{}) bool { return v.(int)%2 == 0 } +// double := func(v interface{}) interface{} { return v.(int) * 2 } +// sum := func(a, b interface{}) interface{} { return a.(int) + b.(int) } +// +// // Chain operations: filter even numbers, double them, then sum +// evenNums := Filter(numbers, isEven) // [2, 4, 6] +// doubled := Map(evenNums, double) // [4, 8, 12] +// result := Reduce(doubled, sum, 0) // 24 +// +// // Alternative: group by even/odd, then get even numbers +// byMod2 := func(v interface{}) interface{} { return v.(int) % 2 } +// grouped := GroupBy(numbers, byMod2) // {0: [2,4,6], 1: [1,3,5]} +// evens := grouped[0] // [2,4,6] +package fp + +// Mapper is a function type that maps an element to another element. +type Mapper func(interface{}) interface{} + +// Predicate is a function type that evaluates a condition on an element. +type Predicate func(interface{}) bool + +// Reducer is a function type that reduces two elements to a single value. 
+type Reducer func(interface{}, interface{}) interface{} + +// Filter filters elements from the slice that satisfy the given predicate. +// +// Example: +// +// numbers := []interface{}{-1, 0, 1, 2} +// isPositive := func(v interface{}) bool { return v.(int) > 0 } +// result := Filter(numbers, isPositive) // [1, 2] +func Filter(values []interface{}, fn Predicate) []interface{} { + result := []interface{}{} + for _, v := range values { + if fn(v) { + result = append(result, v) + } + } + return result +} + +// Map applies a function to each element in the slice. +// +// Example: +// +// numbers := []interface{}{1, 2, 3} +// toString := func(v interface{}) interface{} { return fmt.Sprintf("%d", v) } +// result := Map(numbers, toString) // ["1", "2", "3"] +func Map(values []interface{}, fn Mapper) []interface{} { + result := make([]interface{}, len(values)) + for i, v := range values { + result[i] = fn(v) + } + return result +} + +// Reduce reduces a slice to a single value by applying a function. +// +// Example: +// +// numbers := []interface{}{1, 2, 3, 4} +// sum := func(a, b interface{}) interface{} { return a.(int) + b.(int) } +// result := Reduce(numbers, sum, 0) // 10 +func Reduce(values []interface{}, fn Reducer, initial interface{}) interface{} { + acc := initial + for _, v := range values { + acc = fn(acc, v) + } + return acc +} + +// FlatMap maps each element to a collection and flattens the results. +// +// Example: +// +// words := []interface{}{"hello", "world"} +// split := func(v interface{}) interface{} { +// chars := []interface{}{} +// for _, c := range v.(string) { +// chars = append(chars, string(c)) +// } +// return chars +// } +// result := FlatMap(words, split) // ["h","e","l","l","o","w","o","r","l","d"] +func FlatMap(values []interface{}, fn Mapper) []interface{} { + result := []interface{}{} + for _, v := range values { + inner := fn(v).([]interface{}) + result = append(result, inner...) 
+ } + return result +} + +// All returns true if all elements satisfy the predicate. +// +// Example: +// +// numbers := []interface{}{2, 4, 6, 8} +// isEven := func(v interface{}) bool { return v.(int)%2 == 0 } +// result := All(numbers, isEven) // true +func All(values []interface{}, fn Predicate) bool { + for _, v := range values { + if !fn(v) { + return false + } + } + return true +} + +// Any returns true if at least one element satisfies the predicate. +// +// Example: +// +// numbers := []interface{}{1, 3, 4, 7} +// isEven := func(v interface{}) bool { return v.(int)%2 == 0 } +// result := Any(numbers, isEven) // true (4 is even) +func Any(values []interface{}, fn Predicate) bool { + for _, v := range values { + if fn(v) { + return true + } + } + return false +} + +// None returns true if no elements satisfy the predicate. +// +// Example: +// +// numbers := []interface{}{1, 3, 5, 7} +// isEven := func(v interface{}) bool { return v.(int)%2 == 0 } +// result := None(numbers, isEven) // true (no even numbers) +func None(values []interface{}, fn Predicate) bool { + for _, v := range values { + if fn(v) { + return false + } + } + return true +} + +// Chunk splits a slice into chunks of the given size. +// +// Example: +// +// numbers := []interface{}{1, 2, 3, 4, 5} +// result := Chunk(numbers, 2) // [[1,2], [3,4], [5]] +func Chunk(values []interface{}, size int) [][]interface{} { + if size <= 0 { + return nil + } + var chunks [][]interface{} + for i := 0; i < len(values); i += size { + end := i + size + if end > len(values) { + end = len(values) + } + chunks = append(chunks, values[i:end]) + } + return chunks +} + +// Find returns the first element that satisfies the predicate and a boolean indicating if an element was found. 
+// +// Example: +// +// numbers := []interface{}{1, 2, 3, 4} +// isEven := func(v interface{}) bool { return v.(int)%2 == 0 } +// result, found := Find(numbers, isEven) // 2, true +func Find(values []interface{}, fn Predicate) (interface{}, bool) { + for _, v := range values { + if fn(v) { + return v, true + } + } + return nil, false +} + +// Reverse reverses the order of elements in a slice. +// +// Example: +// +// numbers := []interface{}{1, 2, 3} +// result := Reverse(numbers) // [3, 2, 1] +func Reverse(values []interface{}) []interface{} { + result := make([]interface{}, len(values)) + for i, v := range values { + result[len(values)-1-i] = v + } + return result +} + +// Zip combines two slices into a slice of pairs. If the slices have different lengths, +// extra elements from the longer slice are ignored. +// +// Example: +// +// a := []interface{}{1, 2, 3} +// b := []interface{}{"a", "b", "c"} +// result := Zip(a, b) // [[1,"a"], [2,"b"], [3,"c"]] +func Zip(a, b []interface{}) [][2]interface{} { + length := min(len(a), len(b)) + result := make([][2]interface{}, length) + for i := 0; i < length; i++ { + result[i] = [2]interface{}{a[i], b[i]} + } + return result +} + +// Unzip splits a slice of pairs into two separate slices. +// +// Example: +// +// pairs := [][2]interface{}{{1,"a"}, {2,"b"}, {3,"c"}} +// numbers, letters := Unzip(pairs) // [1,2,3], ["a","b","c"] +func Unzip(pairs [][2]interface{}) ([]interface{}, []interface{}) { + a := make([]interface{}, len(pairs)) + b := make([]interface{}, len(pairs)) + for i, pair := range pairs { + a[i] = pair[0] + b[i] = pair[1] + } + return a, b +} + +// GroupBy groups elements based on a key returned by a Mapper. 
+// +// Example: +// +// numbers := []interface{}{1, 2, 3, 4, 5, 6} +// byMod3 := func(v interface{}) interface{} { return v.(int) % 3 } +// result := GroupBy(numbers, byMod3) // {0: [3,6], 1: [1,4], 2: [2,5]} +func GroupBy(values []interface{}, fn Mapper) map[interface{}][]interface{} { + result := make(map[interface{}][]interface{}) + for _, v := range values { + key := fn(v) + result[key] = append(result[key], v) + } + return result +} + +// Flatten flattens a slice of slices into a single slice. +// +// Example: +// +// nested := [][]interface{}{{1,2}, {3,4}, {5}} +// result := Flatten(nested) // [1,2,3,4,5] +func Flatten(values [][]interface{}) []interface{} { + result := []interface{}{} + for _, v := range values { + result = append(result, v...) + } + return result +} + +// Helper functions +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/examples/gno.land/p/moul/fp/fp_test.gno b/examples/gno.land/p/moul/fp/fp_test.gno new file mode 100644 index 00000000000..00957486fe9 --- /dev/null +++ b/examples/gno.land/p/moul/fp/fp_test.gno @@ -0,0 +1,666 @@ +package fp + +import ( + "fmt" + "testing" +) + +func TestMap(t *testing.T) { + tests := []struct { + name string + input []interface{} + fn func(interface{}) interface{} + expected []interface{} + }{ + { + name: "multiply numbers by 2", + input: []interface{}{1, 2, 3}, + fn: func(v interface{}) interface{} { return v.(int) * 2 }, + expected: []interface{}{2, 4, 6}, + }, + { + name: "empty slice", + input: []interface{}{}, + fn: func(v interface{}) interface{} { return v.(int) * 2 }, + expected: []interface{}{}, + }, + { + name: "convert numbers to strings", + input: []interface{}{1, 2, 3}, + fn: func(v interface{}) interface{} { return fmt.Sprintf("%d", v.(int)) }, + expected: []interface{}{"1", "2", "3"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := Map(tt.input, tt.fn) + if !equalSlices(result, tt.expected) { + t.Errorf("Map failed, 
expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestFilter(t *testing.T) { + tests := []struct { + name string + input []interface{} + fn func(interface{}) bool + expected []interface{} + }{ + { + name: "filter even numbers", + input: []interface{}{1, 2, 3, 4}, + fn: func(v interface{}) bool { return v.(int)%2 == 0 }, + expected: []interface{}{2, 4}, + }, + { + name: "empty slice", + input: []interface{}{}, + fn: func(v interface{}) bool { return v.(int)%2 == 0 }, + expected: []interface{}{}, + }, + { + name: "no matches", + input: []interface{}{1, 3, 5}, + fn: func(v interface{}) bool { return v.(int)%2 == 0 }, + expected: []interface{}{}, + }, + { + name: "all matches", + input: []interface{}{2, 4, 6}, + fn: func(v interface{}) bool { return v.(int)%2 == 0 }, + expected: []interface{}{2, 4, 6}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := Filter(tt.input, tt.fn) + if !equalSlices(result, tt.expected) { + t.Errorf("Filter failed, expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestReduce(t *testing.T) { + tests := []struct { + name string + input []interface{} + fn func(interface{}, interface{}) interface{} + initial interface{} + expected interface{} + }{ + { + name: "sum numbers", + input: []interface{}{1, 2, 3}, + fn: func(a, b interface{}) interface{} { return a.(int) + b.(int) }, + initial: 0, + expected: 6, + }, + { + name: "empty slice", + input: []interface{}{}, + fn: func(a, b interface{}) interface{} { return a.(int) + b.(int) }, + initial: 0, + expected: 0, + }, + { + name: "concatenate strings", + input: []interface{}{"a", "b", "c"}, + fn: func(a, b interface{}) interface{} { return a.(string) + b.(string) }, + initial: "", + expected: "abc", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := Reduce(tt.input, tt.fn, tt.initial) + if result != tt.expected { + t.Errorf("Reduce failed, expected %v, got %v", tt.expected, result) + } + 
}) + } +} + +func TestFlatMap(t *testing.T) { + tests := []struct { + name string + input []interface{} + fn func(interface{}) interface{} + expected []interface{} + }{ + { + name: "split words into chars", + input: []interface{}{"go", "fn"}, + fn: func(word interface{}) interface{} { + chars := []interface{}{} + for _, c := range word.(string) { + chars = append(chars, string(c)) + } + return chars + }, + expected: []interface{}{"g", "o", "f", "n"}, + }, + { + name: "empty string handling", + input: []interface{}{"", "a", ""}, + fn: func(word interface{}) interface{} { + chars := []interface{}{} + for _, c := range word.(string) { + chars = append(chars, string(c)) + } + return chars + }, + expected: []interface{}{"a"}, + }, + { + name: "nil handling", + input: []interface{}{nil, "a", nil}, + fn: func(word interface{}) interface{} { + if word == nil { + return []interface{}{} + } + return []interface{}{word} + }, + expected: []interface{}{"a"}, + }, + { + name: "empty slice result", + input: []interface{}{"", "", ""}, + fn: func(word interface{}) interface{} { + return []interface{}{} + }, + expected: []interface{}{}, + }, + { + name: "nested array flattening", + input: []interface{}{1, 2, 3}, + fn: func(n interface{}) interface{} { + return []interface{}{n, n} + }, + expected: []interface{}{1, 1, 2, 2, 3, 3}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := FlatMap(tt.input, tt.fn) + if !equalSlices(result, tt.expected) { + t.Errorf("FlatMap failed, expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestAllAnyNone(t *testing.T) { + tests := []struct { + name string + input []interface{} + fn func(interface{}) bool + expectedAll bool + expectedAny bool + expectedNone bool + }{ + { + name: "all even numbers", + input: []interface{}{2, 4, 6, 8}, + fn: func(x interface{}) bool { return x.(int)%2 == 0 }, + expectedAll: true, + expectedAny: true, + expectedNone: false, + }, + { + name: "no even numbers", + input: 
[]interface{}{1, 3, 5, 7}, + fn: func(x interface{}) bool { return x.(int)%2 == 0 }, + expectedAll: false, + expectedAny: false, + expectedNone: true, + }, + { + name: "mixed even/odd numbers", + input: []interface{}{1, 2, 3, 4}, + fn: func(x interface{}) bool { return x.(int)%2 == 0 }, + expectedAll: false, + expectedAny: true, + expectedNone: false, + }, + { + name: "empty slice", + input: []interface{}{}, + fn: func(x interface{}) bool { return x.(int)%2 == 0 }, + expectedAll: true, // vacuously true + expectedAny: false, // vacuously false + expectedNone: true, // vacuously true + }, + { + name: "nil predicate handling", + input: []interface{}{nil, nil, nil}, + fn: func(x interface{}) bool { return x == nil }, + expectedAll: true, + expectedAny: true, + expectedNone: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resultAll := All(tt.input, tt.fn) + if resultAll != tt.expectedAll { + t.Errorf("All failed, expected %v, got %v", tt.expectedAll, resultAll) + } + + resultAny := Any(tt.input, tt.fn) + if resultAny != tt.expectedAny { + t.Errorf("Any failed, expected %v, got %v", tt.expectedAny, resultAny) + } + + resultNone := None(tt.input, tt.fn) + if resultNone != tt.expectedNone { + t.Errorf("None failed, expected %v, got %v", tt.expectedNone, resultNone) + } + }) + } +} + +func TestChunk(t *testing.T) { + tests := []struct { + name string + input []interface{} + size int + expected [][]interface{} + }{ + { + name: "normal chunks", + input: []interface{}{1, 2, 3, 4, 5}, + size: 2, + expected: [][]interface{}{{1, 2}, {3, 4}, {5}}, + }, + { + name: "empty slice", + input: []interface{}{}, + size: 2, + expected: [][]interface{}{}, + }, + { + name: "chunk size equals length", + input: []interface{}{1, 2, 3}, + size: 3, + expected: [][]interface{}{{1, 2, 3}}, + }, + { + name: "chunk size larger than length", + input: []interface{}{1, 2}, + size: 3, + expected: [][]interface{}{{1, 2}}, + }, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + result := Chunk(tt.input, tt.size) + if !equalNestedSlices(result, tt.expected) { + t.Errorf("Chunk failed, expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestFind(t *testing.T) { + tests := []struct { + name string + input []interface{} + fn func(interface{}) bool + expected interface{} + shouldFound bool + }{ + { + name: "find first number greater than 2", + input: []interface{}{1, 2, 3, 4}, + fn: func(v interface{}) bool { return v.(int) > 2 }, + expected: 3, + shouldFound: true, + }, + { + name: "empty slice", + input: []interface{}{}, + fn: func(v interface{}) bool { return v.(int) > 2 }, + expected: nil, + shouldFound: false, + }, + { + name: "no match", + input: []interface{}{1, 2}, + fn: func(v interface{}) bool { return v.(int) > 10 }, + expected: nil, + shouldFound: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, found := Find(tt.input, tt.fn) + if found != tt.shouldFound { + t.Errorf("Find failed, expected found=%v, got found=%v", tt.shouldFound, found) + } + if found && result != tt.expected { + t.Errorf("Find failed, expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestReverse(t *testing.T) { + tests := []struct { + name string + input []interface{} + expected []interface{} + }{ + { + name: "normal sequence", + input: []interface{}{1, 2, 3, 4}, + expected: []interface{}{4, 3, 2, 1}, + }, + { + name: "empty slice", + input: []interface{}{}, + expected: []interface{}{}, + }, + { + name: "single element", + input: []interface{}{1}, + expected: []interface{}{1}, + }, + { + name: "mixed types", + input: []interface{}{1, "a", true, 2.5}, + expected: []interface{}{2.5, true, "a", 1}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := Reverse(tt.input) + if !equalSlices(result, tt.expected) { + t.Errorf("Reverse failed, expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestZipUnzip(t 
*testing.T) { + tests := []struct { + name string + a []interface{} + b []interface{} + expectedZip [][2]interface{} + expectedA []interface{} + expectedB []interface{} + }{ + { + name: "normal case", + a: []interface{}{1, 2, 3}, + b: []interface{}{"a", "b", "c"}, + expectedZip: [][2]interface{}{{1, "a"}, {2, "b"}, {3, "c"}}, + expectedA: []interface{}{1, 2, 3}, + expectedB: []interface{}{"a", "b", "c"}, + }, + { + name: "empty slices", + a: []interface{}{}, + b: []interface{}{}, + expectedZip: [][2]interface{}{}, + expectedA: []interface{}{}, + expectedB: []interface{}{}, + }, + { + name: "different lengths - a shorter", + a: []interface{}{1, 2}, + b: []interface{}{"a", "b", "c"}, + expectedZip: [][2]interface{}{{1, "a"}, {2, "b"}}, + expectedA: []interface{}{1, 2}, + expectedB: []interface{}{"a", "b"}, + }, + { + name: "different lengths - b shorter", + a: []interface{}{1, 2, 3}, + b: []interface{}{"a"}, + expectedZip: [][2]interface{}{{1, "a"}}, + expectedA: []interface{}{1}, + expectedB: []interface{}{"a"}, + }, + { + name: "mixed types", + a: []interface{}{1, true, "x"}, + b: []interface{}{2.5, false, "y"}, + expectedZip: [][2]interface{}{{1, 2.5}, {true, false}, {"x", "y"}}, + expectedA: []interface{}{1, true, "x"}, + expectedB: []interface{}{2.5, false, "y"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + zipped := Zip(tt.a, tt.b) + if len(zipped) != len(tt.expectedZip) { + t.Errorf("Zip failed, expected length %v, got %v", len(tt.expectedZip), len(zipped)) + } + for i, pair := range zipped { + if pair[0] != tt.expectedZip[i][0] || pair[1] != tt.expectedZip[i][1] { + t.Errorf("Zip failed at index %d, expected %v, got %v", i, tt.expectedZip[i], pair) + } + } + + unzippedA, unzippedB := Unzip(zipped) + if !equalSlices(unzippedA, tt.expectedA) { + t.Errorf("Unzip failed for slice A, expected %v, got %v", tt.expectedA, unzippedA) + } + if !equalSlices(unzippedB, tt.expectedB) { + t.Errorf("Unzip failed for slice B, expected %v, 
got %v", tt.expectedB, unzippedB)
			}
		})
	}
}

// TestGroupBy checks that GroupBy buckets elements of a slice under the key
// produced by the supplied classifier function.
func TestGroupBy(t *testing.T) {
	cases := []struct {
		name     string
		input    []interface{}
		fn       func(interface{}) interface{}
		expected map[interface{}][]interface{}
	}{
		{
			name:  "group by even/odd",
			input: []interface{}{1, 2, 3, 4, 5, 6},
			fn:    func(v interface{}) interface{} { return v.(int) % 2 },
			expected: map[interface{}][]interface{}{
				0: {2, 4, 6},
				1: {1, 3, 5},
			},
		},
		{
			name:     "empty slice",
			input:    []interface{}{},
			fn:       func(v interface{}) interface{} { return v.(int) % 2 },
			expected: map[interface{}][]interface{}{},
		},
		{
			name:  "single group",
			input: []interface{}{2, 4, 6},
			fn:    func(v interface{}) interface{} { return v.(int) % 2 },
			expected: map[interface{}][]interface{}{
				0: {2, 4, 6},
			},
		},
		{
			name:  "group by type",
			input: []interface{}{1, "a", 2, "b", true},
			fn: func(v interface{}) interface{} {
				switch v.(type) {
				case int:
					return "int"
				case string:
					return "string"
				default:
					return "other"
				}
			},
			expected: map[interface{}][]interface{}{
				"int":    {1, 2},
				"string": {"a", "b"},
				"other":  {true},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := GroupBy(tc.input, tc.fn)
			if len(got) != len(tc.expected) {
				t.Errorf("GroupBy failed, expected %d groups, got %d", len(tc.expected), len(got))
			}
			for key, want := range tc.expected {
				if !equalSlices(got[key], want) {
					t.Errorf("GroupBy failed for key %v, expected %v, got %v", key, want, got[key])
				}
			}
		})
	}
}

// TestFlatten checks that Flatten concatenates nested slices in order.
func TestFlatten(t *testing.T) {
	cases := []struct {
		name     string
		input    [][]interface{}
		expected []interface{}
	}{
		{name: "normal nested slices", input: [][]interface{}{{1, 2}, {3, 4}, {5}}, expected: []interface{}{1, 2, 3, 4, 5}},
		{name: "empty outer slice", input: [][]interface{}{}, expected: []interface{}{}},
		{name: "empty inner slices", input: [][]interface{}{{}, {}, {}}, expected: []interface{}{}},
		{name: "mixed types", input: [][]interface{}{{1, "a"}, {true, 2.5}, {nil}}, expected: []interface{}{1, "a", true, 2.5, nil}},
		{name: "single element slices", input: [][]interface{}{{1}, {2}, {3}}, expected: []interface{}{1, 2, 3}},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := Flatten(tc.input); !equalSlices(got, tc.expected) {
				t.Errorf("Flatten failed, expected %v, got %v", tc.expected, got)
			}
		})
	}
}

// TestContains exercises the local contains helper over several element types.
func TestContains(t *testing.T) {
	cases := []struct {
		name     string
		slice    []interface{}
		item     interface{}
		expected bool
	}{
		{name: "contains integer", slice: []interface{}{1, 2, 3}, item: 2, expected: true},
		{name: "does not contain integer", slice: []interface{}{1, 2, 3}, item: 4, expected: false},
		{name: "contains string", slice: []interface{}{"a", "b", "c"}, item: "b", expected: true},
		{name: "empty slice", slice: []interface{}{}, item: 1, expected: false},
		{name: "contains nil", slice: []interface{}{1, nil, 3}, item: nil, expected: true},
		{name: "mixed types", slice: []interface{}{1, "a", true}, item: true, expected: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := contains(tc.slice, tc.item); got != tc.expected {
				t.Errorf("contains failed, expected %v, got %v", tc.expected, got)
			}
		})
	}
}

// contains reports whether item compares equal (==) to any element of slice.
// Test helper only.
func contains(slice []interface{}, item interface{}) bool {
	for _, elem := range slice {
		if elem == item {
			return true
		}
	}
	return false
}

// equalSlices reports element-wise equality (==) of two interface slices.
func equalSlices(a, b []interface{}) bool {
	if len(a) != len(b) {
		return false
	}
	for i, elem := range a {
		if elem != b[i] {
			return false
		}
	}
	return true
}

// equalNestedSlices reports element-wise equality of two nested slices,
// comparing inner slices with equalSlices.
func equalNestedSlices(a, b [][]interface{}) bool {
	if len(a) != len(b) {
		return false
	}
	for i, inner := range a {
		if !equalSlices(inner, b[i]) {
			return false
		}
	}
	return true
}
This means: +// - Cache entries created during queries will NOT persist +// - Creating cache entries during queries will actually decrease performance +// as it wastes resources trying to save data that won't be saved +// +// Best Practices: +// - Use this pattern in transaction-driven contexts rather than query/render scenarios +// - Consider controlled cache updates, e.g., by specific accounts (like oracles) +// - Ideal for cases where cache updates happen every N blocks or on specific events +// - Carefully evaluate if caching will actually improve performance in your use case +// +// Basic usage example: +// +// m := memo.New() +// +// // Cache expensive computation +// result := m.Memoize("key", func() interface{} { +// // expensive operation +// return "computed-value" +// }) +// +// // Subsequent calls with same key return cached result +// result = m.Memoize("key", func() interface{} { +// // function won't be called, cached value is returned +// return "computed-value" +// }) +// +// Example with validation: +// +// type TimestampedValue struct { +// Value string +// Timestamp time.Time +// } +// +// m := memo.New() +// +// // Cache value with timestamp +// result := m.MemoizeWithValidator( +// "key", +// func() interface{} { +// return TimestampedValue{ +// Value: "data", +// Timestamp: time.Now(), +// } +// }, +// func(cached interface{}) bool { +// // Validate that the cached value is not older than 1 hour +// if tv, ok := cached.(TimestampedValue); ok { +// return time.Since(tv.Timestamp) < time.Hour +// } +// return false +// }, +// ) +package memo + +import ( + "gno.land/p/demo/btree" + "gno.land/p/demo/ufmt" +) + +// Record implements the btree.Record interface for our cache entries +type cacheEntry struct { + key interface{} + value interface{} +} + +// Less implements btree.Record interface +func (e cacheEntry) Less(than btree.Record) bool { + // Convert the other record to cacheEntry + other := than.(cacheEntry) + // Compare string 
representations of keys for consistent ordering + return ufmt.Sprintf("%v", e.key) < ufmt.Sprintf("%v", other.key) +} + +// Memoizer is a structure to handle memoization of function results. +type Memoizer struct { + cache *btree.BTree +} + +// New creates a new Memoizer instance. +func New() *Memoizer { + return &Memoizer{ + cache: btree.New(), + } +} + +// Memoize ensures the result of the given function is cached for the specified key. +func (m *Memoizer) Memoize(key interface{}, fn func() interface{}) interface{} { + entry := cacheEntry{key: key} + if found := m.cache.Get(entry); found != nil { + return found.(cacheEntry).value + } + + value := fn() + m.cache.Insert(cacheEntry{key: key, value: value}) + return value +} + +// MemoizeWithValidator ensures the result is cached and valid according to the validator function. +func (m *Memoizer) MemoizeWithValidator(key interface{}, fn func() interface{}, isValid func(interface{}) bool) interface{} { + entry := cacheEntry{key: key} + if found := m.cache.Get(entry); found != nil { + cachedEntry := found.(cacheEntry) + if isValid(cachedEntry.value) { + return cachedEntry.value + } + } + + value := fn() + m.cache.Insert(cacheEntry{key: key, value: value}) + return value +} + +// Invalidate removes the cached value for the specified key. +func (m *Memoizer) Invalidate(key interface{}) { + m.cache.Delete(cacheEntry{key: key}) +} + +// Clear clears all cached values. +func (m *Memoizer) Clear() { + m.cache.Clear(true) +} + +// Size returns the number of items currently in the cache. 
+func (m *Memoizer) Size() int { + return m.cache.Len() +} diff --git a/examples/gno.land/p/moul/memo/memo_test.gno b/examples/gno.land/p/moul/memo/memo_test.gno new file mode 100644 index 00000000000..44dde5df640 --- /dev/null +++ b/examples/gno.land/p/moul/memo/memo_test.gno @@ -0,0 +1,449 @@ +package memo + +import ( + "std" + "testing" + "time" +) + +type timestampedValue struct { + value interface{} + timestamp time.Time +} + +// complexKey is used to test struct keys +type complexKey struct { + ID int + Name string +} + +func TestMemoize(t *testing.T) { + tests := []struct { + name string + key interface{} + value interface{} + callCount *int + }{ + { + name: "string key and value", + key: "test-key", + value: "test-value", + callCount: new(int), + }, + { + name: "int key and value", + key: 42, + value: 123, + callCount: new(int), + }, + { + name: "mixed types", + key: "number", + value: 42, + callCount: new(int), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := New() + if m.Size() != 0 { + t.Errorf("Initial size = %d, want 0", m.Size()) + } + + fn := func() interface{} { + *tt.callCount++ + return tt.value + } + + // First call should compute + result := m.Memoize(tt.key, fn) + if result != tt.value { + t.Errorf("Memoize() = %v, want %v", result, tt.value) + } + if *tt.callCount != 1 { + t.Errorf("Function called %d times, want 1", *tt.callCount) + } + if m.Size() != 1 { + t.Errorf("Size after first call = %d, want 1", m.Size()) + } + + // Second call should use cache + result = m.Memoize(tt.key, fn) + if result != tt.value { + t.Errorf("Memoize() second call = %v, want %v", result, tt.value) + } + if *tt.callCount != 1 { + t.Errorf("Function called %d times, want 1", *tt.callCount) + } + if m.Size() != 1 { + t.Errorf("Size after second call = %d, want 1", m.Size()) + } + }) + } +} + +func TestMemoizeWithValidator(t *testing.T) { + tests := []struct { + name string + key interface{} + value interface{} + validDuration 
time.Duration + waitDuration time.Duration + expectedCalls int + shouldRecompute bool + }{ + { + name: "valid cache", + key: "key1", + value: "value1", + validDuration: time.Hour, + waitDuration: time.Millisecond, + expectedCalls: 1, + shouldRecompute: false, + }, + { + name: "expired cache", + key: "key2", + value: "value2", + validDuration: time.Millisecond, + waitDuration: time.Millisecond * 2, + expectedCalls: 2, + shouldRecompute: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := New() + callCount := 0 + + fn := func() interface{} { + callCount++ + return timestampedValue{ + value: tt.value, + timestamp: time.Now(), + } + } + + isValid := func(cached interface{}) bool { + if tv, ok := cached.(timestampedValue); ok { + return time.Since(tv.timestamp) < tt.validDuration + } + return false + } + + // First call + result := m.MemoizeWithValidator(tt.key, fn, isValid) + if tv, ok := result.(timestampedValue); !ok || tv.value != tt.value { + t.Errorf("MemoizeWithValidator() = %v, want value %v", result, tt.value) + } + + // Wait + std.TestSkipHeights(10) + + // Second call + result = m.MemoizeWithValidator(tt.key, fn, isValid) + if tv, ok := result.(timestampedValue); !ok || tv.value != tt.value { + t.Errorf("MemoizeWithValidator() second call = %v, want value %v", result, tt.value) + } + + if callCount != tt.expectedCalls { + t.Errorf("Function called %d times, want %d", callCount, tt.expectedCalls) + } + }) + } +} + +func TestInvalidate(t *testing.T) { + tests := []struct { + name string + key interface{} + value interface{} + callCount *int + }{ + { + name: "invalidate existing key", + key: "test-key", + value: "test-value", + callCount: new(int), + }, + { + name: "invalidate non-existing key", + key: "missing-key", + value: "test-value", + callCount: new(int), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := New() + fn := func() interface{} { + *tt.callCount++ + return tt.value + } + + 
// First call + m.Memoize(tt.key, fn) + if m.Size() != 1 { + t.Errorf("Size after first call = %d, want 1", m.Size()) + } + + // Invalidate + m.Invalidate(tt.key) + if m.Size() != 0 { + t.Errorf("Size after invalidate = %d, want 0", m.Size()) + } + + // Call again should recompute + result := m.Memoize(tt.key, fn) + if result != tt.value { + t.Errorf("Memoize() after invalidate = %v, want %v", result, tt.value) + } + if *tt.callCount != 2 { + t.Errorf("Function called %d times, want 2", *tt.callCount) + } + if m.Size() != 1 { + t.Errorf("Size after recompute = %d, want 1", m.Size()) + } + }) + } +} + +func TestClear(t *testing.T) { + m := New() + callCount := 0 + + fn := func() interface{} { + callCount++ + return "value" + } + + // Cache some values + m.Memoize("key1", fn) + m.Memoize("key2", fn) + + if callCount != 2 { + t.Errorf("Initial calls = %d, want 2", callCount) + } + if m.Size() != 2 { + t.Errorf("Size after initial calls = %d, want 2", m.Size()) + } + + // Clear cache + m.Clear() + if m.Size() != 0 { + t.Errorf("Size after clear = %d, want 0", m.Size()) + } + + // Recompute values + m.Memoize("key1", fn) + m.Memoize("key2", fn) + + if callCount != 4 { + t.Errorf("Calls after clear = %d, want 4", callCount) + } + if m.Size() != 2 { + t.Errorf("Size after recompute = %d, want 2", m.Size()) + } +} + +func TestSize(t *testing.T) { + m := New() + + if m.Size() != 0 { + t.Errorf("Initial size = %d, want 0", m.Size()) + } + + callCount := 0 + fn := func() interface{} { + callCount++ + return "value" + } + + // Add items + m.Memoize("key1", fn) + if m.Size() != 1 { + t.Errorf("Size after first insert = %d, want 1", m.Size()) + } + + m.Memoize("key2", fn) + if m.Size() != 2 { + t.Errorf("Size after second insert = %d, want 2", m.Size()) + } + + // Duplicate key should not increase size + m.Memoize("key1", fn) + if m.Size() != 2 { + t.Errorf("Size after duplicate insert = %d, want 2", m.Size()) + } + + // Remove item + m.Invalidate("key1") + if m.Size() != 1 { + 
t.Errorf("Size after invalidate = %d, want 1", m.Size()) + } + + // Clear all + m.Clear() + if m.Size() != 0 { + t.Errorf("Size after clear = %d, want 0", m.Size()) + } +} + +func TestMemoizeWithDifferentKeyTypes(t *testing.T) { + tests := []struct { + name string + keys []interface{} // Now an array of keys + values []string // Corresponding values + callCount *int + }{ + { + name: "integer keys", + keys: []interface{}{42, 43}, + values: []string{"value-for-42", "value-for-43"}, + callCount: new(int), + }, + { + name: "float keys", + keys: []interface{}{3.14, 2.718}, + values: []string{"value-for-pi", "value-for-e"}, + callCount: new(int), + }, + { + name: "bool keys", + keys: []interface{}{true, false}, + values: []string{"value-for-true", "value-for-false"}, + callCount: new(int), + }, + /* + { + name: "struct keys", + keys: []interface{}{ + complexKey{ID: 1, Name: "test1"}, + complexKey{ID: 2, Name: "test2"}, + }, + values: []string{"value-for-struct1", "value-for-struct2"}, + callCount: new(int), + }, + { + name: "nil and empty interface keys", + keys: []interface{}{nil, interface{}(nil)}, + values: []string{"value-for-nil", "value-for-empty-interface"}, + callCount: new(int), + }, + */ + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := New() + + // Test both keys + for i, key := range tt.keys { + value := tt.values[i] + fn := func() interface{} { + *tt.callCount++ + return value + } + + // First call should compute + result := m.Memoize(key, fn) + if result != value { + t.Errorf("Memoize() for key %v = %v, want %v", key, result, value) + } + if *tt.callCount != i+1 { + t.Errorf("Function called %d times, want %d", *tt.callCount, i+1) + } + } + + // Verify size includes both entries + if m.Size() != 2 { + t.Errorf("Size after both inserts = %d, want 2", m.Size()) + } + + // Second call for each key should use cache + for i, key := range tt.keys { + initialCount := *tt.callCount + result := m.Memoize(key, func() interface{} { + 
*tt.callCount++ + return "should-not-be-called" + }) + + if result != tt.values[i] { + t.Errorf("Memoize() second call for key %v = %v, want %v", key, result, tt.values[i]) + } + if *tt.callCount != initialCount { + t.Errorf("Cache miss for key %v", key) + } + } + + // Test invalidate for each key + for i, key := range tt.keys { + m.Invalidate(key) + if m.Size() != 1-i { + t.Errorf("Size after invalidate %d = %d, want %d", i+1, m.Size(), 1-i) + } + } + }) + } +} + +func TestMultipleKeyTypes(t *testing.T) { + m := New() + callCount := 0 + + // Insert different key types simultaneously (two of each type) + keys := []interface{}{ + 42, 43, // ints + "string-key1", "string-key2", // strings + 3.14, 2.718, // floats + true, false, // bools + } + + for i, key := range keys { + value := i + m.Memoize(key, func() interface{} { + callCount++ + return value + }) + } + + // Verify size includes all entries + if m.Size() != len(keys) { + t.Errorf("Size = %d, want %d", m.Size(), len(keys)) + } + + // Verify all values are cached correctly + for i, key := range keys { + initialCount := callCount + result := m.Memoize(key, func() interface{} { + callCount++ + return -1 // Should never be returned if cache works + }) + + if result != i { + t.Errorf("Memoize(%v) = %v, want %v", key, result, i) + } + if callCount != initialCount { + t.Errorf("Cache miss for key %v", key) + } + } + + // Test invalidation of pairs + for i := 0; i < len(keys); i += 2 { + m.Invalidate(keys[i]) + m.Invalidate(keys[i+1]) + expectedSize := len(keys) - (i + 2) + if m.Size() != expectedSize { + t.Errorf("Size after invalidating pair %d = %d, want %d", i/2, m.Size(), expectedSize) + } + } + + // Clear and verify + m.Clear() + if m.Size() != 0 { + t.Errorf("Size after clear = %d, want 0", m.Size()) + } +} diff --git a/examples/gno.land/p/moul/typeutil/gno.mod b/examples/gno.land/p/moul/typeutil/gno.mod new file mode 100644 index 00000000000..4f9c432456b --- /dev/null +++ 
b/examples/gno.land/p/moul/typeutil/gno.mod @@ -0,0 +1 @@ +module gno.land/p/moul/typeutil diff --git a/examples/gno.land/p/moul/typeutil/typeutil.gno b/examples/gno.land/p/moul/typeutil/typeutil.gno new file mode 100644 index 00000000000..1fa79b94549 --- /dev/null +++ b/examples/gno.land/p/moul/typeutil/typeutil.gno @@ -0,0 +1,715 @@ +// Package typeutil provides utility functions for converting between different types +// and checking their states. It aims to provide consistent behavior across different +// types while remaining lightweight and dependency-free. +package typeutil + +import ( + "errors" + "sort" + "std" + "strconv" + "strings" + "time" +) + +// stringer is the interface that wraps the String method. +type stringer interface { + String() string +} + +// ToString converts any value to its string representation. +// It supports a wide range of Go types including: +// - Basic: string, bool +// - Numbers: int, int8-64, uint, uint8-64, float32, float64 +// - Special: time.Time, std.Address, []byte +// - Slices: []T for most basic types +// - Maps: map[string]string, map[string]interface{} +// - Interface: types implementing String() string +// +// Example usage: +// +// str := typeutil.ToString(42) // "42" +// str = typeutil.ToString([]int{1, 2}) // "[1 2]" +// str = typeutil.ToString(map[string]string{ // "map[a:1 b:2]" +// "a": "1", +// "b": "2", +// }) +func ToString(val interface{}) string { + if val == nil { + return "" + } + + // First check if value implements Stringer interface + if s, ok := val.(interface{ String() string }); ok { + return s.String() + } + + switch v := val.(type) { + // Pointer types - dereference and recurse + case *string: + if v == nil { + return "" + } + return *v + case *int: + if v == nil { + return "" + } + return strconv.Itoa(*v) + case *bool: + if v == nil { + return "" + } + return strconv.FormatBool(*v) + case *time.Time: + if v == nil { + return "" + } + return v.String() + case *std.Address: + if v == nil { + 
return "" + } + return string(*v) + + // String types + case string: + return v + case stringer: + return v.String() + + // Special types + case time.Time: + return v.String() + case std.Address: + return string(v) + case []byte: + return string(v) + case struct{}: + return "{}" + + // Integer types + case int: + return strconv.Itoa(v) + case int8: + return strconv.FormatInt(int64(v), 10) + case int16: + return strconv.FormatInt(int64(v), 10) + case int32: + return strconv.FormatInt(int64(v), 10) + case int64: + return strconv.FormatInt(v, 10) + case uint: + return strconv.FormatUint(uint64(v), 10) + case uint8: + return strconv.FormatUint(uint64(v), 10) + case uint16: + return strconv.FormatUint(uint64(v), 10) + case uint32: + return strconv.FormatUint(uint64(v), 10) + case uint64: + return strconv.FormatUint(v, 10) + + // Float types + case float32: + return strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + return strconv.FormatFloat(v, 'f', -1, 64) + + // Boolean + case bool: + if v { + return "true" + } + return "false" + + // Slice types + case []string: + return join(v) + case []int: + return join(v) + case []int32: + return join(v) + case []int64: + return join(v) + case []float32: + return join(v) + case []float64: + return join(v) + case []interface{}: + return join(v) + case []time.Time: + return joinTimes(v) + case []stringer: + return join(v) + case []std.Address: + return joinAddresses(v) + case [][]byte: + return joinBytes(v) + + // Map types with various key types + case map[interface{}]interface{}, map[string]interface{}, map[string]string, map[string]int: + var b strings.Builder + b.WriteString("map[") + first := true + + switch m := v.(type) { + case map[interface{}]interface{}: + // Convert all keys to strings for consistent ordering + keys := make([]string, 0) + keyMap := make(map[string]interface{}) + + for k := range m { + keyStr := ToString(k) + keys = append(keys, keyStr) + keyMap[keyStr] = k + } + sort.Strings(keys) + + for _, 
keyStr := range keys { + if !first { + b.WriteString(" ") + } + origKey := keyMap[keyStr] + b.WriteString(keyStr) + b.WriteString(":") + b.WriteString(ToString(m[origKey])) + first = false + } + + case map[string]interface{}: + keys := make([]string, 0) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + if !first { + b.WriteString(" ") + } + b.WriteString(k) + b.WriteString(":") + b.WriteString(ToString(m[k])) + first = false + } + + case map[string]string: + keys := make([]string, 0) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + if !first { + b.WriteString(" ") + } + b.WriteString(k) + b.WriteString(":") + b.WriteString(m[k]) + first = false + } + + case map[string]int: + keys := make([]string, 0) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, k := range keys { + if !first { + b.WriteString(" ") + } + b.WriteString(k) + b.WriteString(":") + b.WriteString(strconv.Itoa(m[k])) + first = false + } + } + b.WriteString("]") + return b.String() + + // Default + default: + return "" + } +} + +func join(slice interface{}) string { + if IsZero(slice) { + return "[]" + } + + items := ToInterfaceSlice(slice) + if items == nil { + return "[]" + } + + var b strings.Builder + b.WriteString("[") + for i, item := range items { + if i > 0 { + b.WriteString(" ") + } + b.WriteString(ToString(item)) + } + b.WriteString("]") + return b.String() +} + +func joinTimes(slice []time.Time) string { + if len(slice) == 0 { + return "[]" + } + var b strings.Builder + b.WriteString("[") + for i, t := range slice { + if i > 0 { + b.WriteString(" ") + } + b.WriteString(t.String()) + } + b.WriteString("]") + return b.String() +} + +func joinAddresses(slice []std.Address) string { + if len(slice) == 0 { + return "[]" + } + var b strings.Builder + b.WriteString("[") + for i, addr := range slice { + if i > 0 { + b.WriteString(" ") + } + 
b.WriteString(string(addr)) + } + b.WriteString("]") + return b.String() +} + +func joinBytes(slice [][]byte) string { + if len(slice) == 0 { + return "[]" + } + var b strings.Builder + b.WriteString("[") + for i, bytes := range slice { + if i > 0 { + b.WriteString(" ") + } + b.WriteString(string(bytes)) + } + b.WriteString("]") + return b.String() +} + +// ToBool converts any value to a boolean based on common programming conventions. +// For example: +// - Numbers: 0 is false, any other number is true +// - Strings: "", "0", "false", "f", "no", "n", "off" are false, others are true +// - Slices/Maps: empty is false, non-empty is true +// - nil: always false +// - bool: direct value +func ToBool(val interface{}) bool { + if IsZero(val) { + return false + } + + // Handle special string cases + if str, ok := val.(string); ok { + str = strings.ToLower(strings.TrimSpace(str)) + return str != "" && str != "0" && str != "false" && str != "f" && str != "no" && str != "n" && str != "off" + } + + return true +} + +// IsZero returns true if the value represents a "zero" or "empty" state for its type. 
+// For example: +// - Numbers: 0 +// - Strings: "" +// - Slices/Maps: empty +// - nil: true +// - bool: false +// - time.Time: IsZero() +// - std.Address: empty string +func IsZero(val interface{}) bool { + if val == nil { + return true + } + + switch v := val.(type) { + // Pointer types - nil pointer is zero, otherwise check pointed value + case *bool: + return v == nil || !*v + case *string: + return v == nil || *v == "" + case *int: + return v == nil || *v == 0 + case *time.Time: + return v == nil || v.IsZero() + case *std.Address: + return v == nil || string(*v) == "" + + // Bool + case bool: + return !v + + // String types + case string: + return v == "" + case stringer: + return v.String() == "" + + // Integer types + case int: + return v == 0 + case int8: + return v == 0 + case int16: + return v == 0 + case int32: + return v == 0 + case int64: + return v == 0 + case uint: + return v == 0 + case uint8: + return v == 0 + case uint16: + return v == 0 + case uint32: + return v == 0 + case uint64: + return v == 0 + + // Float types + case float32: + return v == 0 + case float64: + return v == 0 + + // Special types + case []byte: + return len(v) == 0 + case time.Time: + return v.IsZero() + case std.Address: + return string(v) == "" + + // Slices (check if empty) + case []string: + return len(v) == 0 + case []int: + return len(v) == 0 + case []int32: + return len(v) == 0 + case []int64: + return len(v) == 0 + case []float32: + return len(v) == 0 + case []float64: + return len(v) == 0 + case []interface{}: + return len(v) == 0 + case []time.Time: + return len(v) == 0 + case []std.Address: + return len(v) == 0 + case [][]byte: + return len(v) == 0 + case []stringer: + return len(v) == 0 + + // Maps (check if empty) + case map[string]string: + return len(v) == 0 + case map[string]interface{}: + return len(v) == 0 + + default: + return false // non-nil unknown types are considered non-zero + } +} + +// ToInterfaceSlice converts various slice types to []interface{} 
// ToInterfaceSlice converts a typed slice into []interface{} so that generic
// helpers (join, ToStringSlice's fallback, ...) can iterate over it without
// knowing the element type.
//
// Supported element types: interface{}, string, bool, int/int8/int16/int32/
// int64, uint/uint16/uint32/uint64, float32/float64 and time.Time — the same
// families ToStringSlice handles directly, so the two stay consistent.
// []uint8 is deliberately excluded: it is identical to []byte, which the rest
// of this package renders as a string. Unsupported inputs, including
// non-slice values, yield nil (unchanged behavior).
func ToInterfaceSlice(val interface{}) []interface{} {
	switch v := val.(type) {
	case []interface{}:
		return v // already the right shape; returned without copying
	case []string:
		result := make([]interface{}, len(v))
		for i, s := range v {
			result[i] = s
		}
		return result
	case []int:
		result := make([]interface{}, len(v))
		for i, n := range v {
			result[i] = n
		}
		return result
	case []int8:
		result := make([]interface{}, len(v))
		for i, n := range v {
			result[i] = n
		}
		return result
	case []int16:
		result := make([]interface{}, len(v))
		for i, n := range v {
			result[i] = n
		}
		return result
	case []int32:
		result := make([]interface{}, len(v))
		for i, n := range v {
			result[i] = n
		}
		return result
	case []int64:
		result := make([]interface{}, len(v))
		for i, n := range v {
			result[i] = n
		}
		return result
	case []uint:
		result := make([]interface{}, len(v))
		for i, n := range v {
			result[i] = n
		}
		return result
	case []uint16:
		result := make([]interface{}, len(v))
		for i, n := range v {
			result[i] = n
		}
		return result
	case []uint32:
		result := make([]interface{}, len(v))
		for i, n := range v {
			result[i] = n
		}
		return result
	case []uint64:
		result := make([]interface{}, len(v))
		for i, n := range v {
			result[i] = n
		}
		return result
	case []float32:
		result := make([]interface{}, len(v))
		for i, n := range v {
			result[i] = n
		}
		return result
	case []float64:
		result := make([]interface{}, len(v))
		for i, n := range v {
			result[i] = n
		}
		return result
	case []bool:
		result := make([]interface{}, len(v))
		for i, b := range v {
			result[i] = b
		}
		return result
	case []time.Time:
		result := make([]interface{}, len(v))
		for i, t := range v {
			result[i] = t
		}
		return result
	default:
		return nil
	}
}
map[string]map[string]string: + for k, val := range v { + if converted, err := ToMapStringInterface(val); err == nil { + result[k] = converted + } else { + return nil, errors.New("failed to convert nested map at key: " + k) + } + } + default: + return nil, errors.New("unsupported map type: " + ToString(m)) + } + + return result, nil +} + +// ToMapIntInterface converts a map with int keys and any value type to map[int]interface{} +func ToMapIntInterface(m interface{}) (map[int]interface{}, error) { + result := make(map[int]interface{}) + + switch v := m.(type) { + case map[int]interface{}: + return v, nil + case map[int]string: + for k, val := range v { + result[k] = val + } + case map[int]int: + for k, val := range v { + result[k] = val + } + case map[int]int64: + for k, val := range v { + result[k] = val + } + case map[int]float64: + for k, val := range v { + result[k] = val + } + case map[int]bool: + for k, val := range v { + result[k] = val + } + case map[int][]string: + for k, val := range v { + result[k] = ToInterfaceSlice(val) + } + case map[int][]int: + for k, val := range v { + result[k] = ToInterfaceSlice(val) + } + case map[int][]interface{}: + for k, val := range v { + result[k] = val + } + case map[int]map[string]interface{}: + for k, val := range v { + result[k] = val + } + case map[int]map[int]interface{}: + for k, val := range v { + result[k] = val + } + default: + return nil, errors.New("unsupported map type: " + ToString(m)) + } + + return result, nil +} + +// ToStringSlice converts various slice types to []string +func ToStringSlice(val interface{}) []string { + switch v := val.(type) { + case []string: + return v + case []interface{}: + result := make([]string, len(v)) + for i, item := range v { + result[i] = ToString(item) + } + return result + case []int: + result := make([]string, len(v)) + for i, n := range v { + result[i] = strconv.Itoa(n) + } + return result + case []int32: + result := make([]string, len(v)) + for i, n := range v { + 
result[i] = strconv.FormatInt(int64(n), 10) + } + return result + case []int64: + result := make([]string, len(v)) + for i, n := range v { + result[i] = strconv.FormatInt(n, 10) + } + return result + case []float32: + result := make([]string, len(v)) + for i, n := range v { + result[i] = strconv.FormatFloat(float64(n), 'f', -1, 32) + } + return result + case []float64: + result := make([]string, len(v)) + for i, n := range v { + result[i] = strconv.FormatFloat(n, 'f', -1, 64) + } + return result + case []bool: + result := make([]string, len(v)) + for i, b := range v { + result[i] = strconv.FormatBool(b) + } + return result + case []time.Time: + result := make([]string, len(v)) + for i, t := range v { + result[i] = t.String() + } + return result + case []std.Address: + result := make([]string, len(v)) + for i, addr := range v { + result[i] = string(addr) + } + return result + case [][]byte: + result := make([]string, len(v)) + for i, b := range v { + result[i] = string(b) + } + return result + case []stringer: + result := make([]string, len(v)) + for i, s := range v { + result[i] = s.String() + } + return result + case []uint: + result := make([]string, len(v)) + for i, n := range v { + result[i] = strconv.FormatUint(uint64(n), 10) + } + return result + case []uint8: + result := make([]string, len(v)) + for i, n := range v { + result[i] = strconv.FormatUint(uint64(n), 10) + } + return result + case []uint16: + result := make([]string, len(v)) + for i, n := range v { + result[i] = strconv.FormatUint(uint64(n), 10) + } + return result + case []uint32: + result := make([]string, len(v)) + for i, n := range v { + result[i] = strconv.FormatUint(uint64(n), 10) + } + return result + case []uint64: + result := make([]string, len(v)) + for i, n := range v { + result[i] = strconv.FormatUint(n, 10) + } + return result + default: + // Try to convert using reflection if it's a slice + if slice := ToInterfaceSlice(val); slice != nil { + result := make([]string, len(slice)) + for 
i, item := range slice { + result[i] = ToString(item) + } + return result + } + return nil + } +} diff --git a/examples/gno.land/p/moul/typeutil/typeutil_test.gno b/examples/gno.land/p/moul/typeutil/typeutil_test.gno new file mode 100644 index 00000000000..543ea1deec4 --- /dev/null +++ b/examples/gno.land/p/moul/typeutil/typeutil_test.gno @@ -0,0 +1,1075 @@ +package typeutil + +import ( + "std" + "strings" + "testing" + "time" +) + +type testStringer struct { + value string +} + +func (t testStringer) String() string { + return "test:" + t.value +} + +func TestToString(t *testing.T) { + // setup test data + str := "hello" + num := 42 + b := true + now := time.Now() + addr := std.Address("g1jg8mtutu9khhfwc4nxmuhcpftf0pajdhfvsqf5") + stringer := testStringer{value: "hello"} + + type testCase struct { + name string + input interface{} + expected string + } + + tests := []testCase{ + // basic types + {"string", "hello", "hello"}, + {"empty_string", "", ""}, + {"nil", nil, ""}, + + // integer types + {"int", 42, "42"}, + {"int8", int8(8), "8"}, + {"int16", int16(16), "16"}, + {"int32", int32(32), "32"}, + {"int64", int64(64), "64"}, + {"uint", uint(42), "42"}, + {"uint8", uint8(8), "8"}, + {"uint16", uint16(16), "16"}, + {"uint32", uint32(32), "32"}, + {"uint64", uint64(64), "64"}, + + // float types + {"float32", float32(3.14), "3.14"}, + {"float64", 3.14159, "3.14159"}, + + // boolean + {"bool_true", true, "true"}, + {"bool_false", false, "false"}, + + // special types + {"time", now, now.String()}, + {"address", addr, string(addr)}, + {"bytes", []byte("hello"), "hello"}, + {"stringer", stringer, "test:hello"}, + + // slices + {"empty_slice", []string{}, "[]"}, + {"string_slice", []string{"a", "b"}, "[a b]"}, + {"int_slice", []int{1, 2}, "[1 2]"}, + {"int32_slice", []int32{1, 2}, "[1 2]"}, + {"int64_slice", []int64{1, 2}, "[1 2]"}, + {"float32_slice", []float32{1.1, 2.2}, "[1.1 2.2]"}, + {"float64_slice", []float64{1.1, 2.2}, "[1.1 2.2]"}, + {"bytes_slice", 
[][]byte{[]byte("a"), []byte("b")}, "[a b]"}, + {"time_slice", []time.Time{now, now}, "[" + now.String() + " " + now.String() + "]"}, + {"address_slice", []std.Address{addr, addr}, "[" + string(addr) + " " + string(addr) + "]"}, + {"interface_slice", []interface{}{1, "a", true}, "[1 a true]"}, + + // empty slices + {"empty_string_slice", []string{}, "[]"}, + {"empty_int_slice", []int{}, "[]"}, + {"empty_int32_slice", []int32{}, "[]"}, + {"empty_int64_slice", []int64{}, "[]"}, + {"empty_float32_slice", []float32{}, "[]"}, + {"empty_float64_slice", []float64{}, "[]"}, + {"empty_bytes_slice", [][]byte{}, "[]"}, + {"empty_time_slice", []time.Time{}, "[]"}, + {"empty_address_slice", []std.Address{}, "[]"}, + {"empty_interface_slice", []interface{}{}, "[]"}, + + // maps + {"empty_string_map", map[string]string{}, "map[]"}, + {"string_map", map[string]string{"a": "1", "b": "2"}, "map[a:1 b:2]"}, + {"empty_interface_map", map[string]interface{}{}, "map[]"}, + {"interface_map", map[string]interface{}{"a": 1, "b": "2"}, "map[a:1 b:2]"}, + + // edge cases + {"empty_bytes", []byte{}, ""}, + {"nil_interface", interface{}(nil), ""}, + {"empty_struct", struct{}{}, "{}"}, + {"unknown_type", struct{ foo string }{}, ""}, + + // pointer types + {"nil_string_ptr", (*string)(nil), ""}, + {"string_ptr", &str, "hello"}, + {"nil_int_ptr", (*int)(nil), ""}, + {"int_ptr", &num, "42"}, + {"nil_bool_ptr", (*bool)(nil), ""}, + {"bool_ptr", &b, "true"}, + // {"nil_time_ptr", (*time.Time)(nil), ""}, // TODO: fix this + {"time_ptr", &now, now.String()}, + // {"nil_address_ptr", (*std.Address)(nil), ""}, // TODO: fix this + {"address_ptr", &addr, string(addr)}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ToString(tt.input) + if got != tt.expected { + t.Errorf("%s: ToString(%v) = %q, want %q", tt.name, tt.input, got, tt.expected) + } + }) + } +} + +func TestToBool(t *testing.T) { + str := "true" + num := 42 + b := true + now := time.Now() + addr := 
std.Address("g1jg8mtutu9khhfwc4nxmuhcpftf0pajdhfvsqf5") + zero := 0 + empty := "" + falseVal := false + + type testCase struct { + name string + input interface{} + expected bool + } + + tests := []testCase{ + // basic types + {"true", true, true}, + {"false", false, false}, + {"nil", nil, false}, + + // strings + {"empty_string", "", false}, + {"zero_string", "0", false}, + {"false_string", "false", false}, + {"f_string", "f", false}, + {"no_string", "no", false}, + {"n_string", "n", false}, + {"off_string", "off", false}, + {"space_string", " ", false}, + {"true_string", "true", true}, + {"yes_string", "yes", true}, + {"random_string", "hello", true}, + + // numbers + {"zero_int", 0, false}, + {"positive_int", 1, true}, + {"negative_int", -1, true}, + {"zero_float", 0.0, false}, + {"positive_float", 0.1, true}, + {"negative_float", -0.1, true}, + + // special types + {"empty_bytes", []byte{}, false}, + {"non_empty_bytes", []byte{1}, true}, + /*{"zero_time", time.Time{}, false},*/ // TODO: fix this + {"empty_address", std.Address(""), false}, + + // slices + {"empty_slice", []string{}, false}, + {"non_empty_slice", []string{"a"}, true}, + + // maps + {"empty_map", map[string]string{}, false}, + {"non_empty_map", map[string]string{"a": "b"}, true}, + + // pointer types + {"nil_bool_ptr", (*bool)(nil), false}, + {"true_ptr", &b, true}, + {"false_ptr", &falseVal, false}, + {"nil_string_ptr", (*string)(nil), false}, + {"string_ptr", &str, true}, + {"empty_string_ptr", &empty, false}, + {"nil_int_ptr", (*int)(nil), false}, + {"int_ptr", &num, true}, + {"zero_int_ptr", &zero, false}, + // {"nil_time_ptr", (*time.Time)(nil), false}, // TODO: fix this + {"time_ptr", &now, true}, + // {"nil_address_ptr", (*std.Address)(nil), false}, // TODO: fix this + {"address_ptr", &addr, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ToBool(tt.input) + if got != tt.expected { + t.Errorf("%s: ToBool(%v) = %v, want %v", tt.name, tt.input, got, 
tt.expected) + } + }) + } +} + +func TestIsZero(t *testing.T) { + str := "hello" + num := 42 + b := true + now := time.Now() + addr := std.Address("g1jg8mtutu9khhfwc4nxmuhcpftf0pajdhfvsqf5") + zero := 0 + empty := "" + falseVal := false + + type testCase struct { + name string + input interface{} + expected bool + } + + tests := []testCase{ + // basic types + {"true", true, false}, + {"false", false, true}, + {"nil", nil, true}, + + // strings + {"empty_string", "", true}, + {"non_empty_string", "hello", false}, + + // numbers + {"zero_int", 0, true}, + {"non_zero_int", 1, false}, + {"zero_float", 0.0, true}, + {"non_zero_float", 0.1, false}, + + // special types + {"empty_bytes", []byte{}, true}, + {"non_empty_bytes", []byte{1}, false}, + /*{"zero_time", time.Time{}, true},*/ // TODO: fix this + {"empty_address", std.Address(""), true}, + + // slices + {"empty_slice", []string{}, true}, + {"non_empty_slice", []string{"a"}, false}, + + // maps + {"empty_map", map[string]string{}, true}, + {"non_empty_map", map[string]string{"a": "b"}, false}, + + // pointer types + {"nil_bool_ptr", (*bool)(nil), true}, + {"false_ptr", &falseVal, true}, + {"true_ptr", &b, false}, + {"nil_string_ptr", (*string)(nil), true}, + {"empty_string_ptr", &empty, true}, + {"string_ptr", &str, false}, + {"nil_int_ptr", (*int)(nil), true}, + {"zero_int_ptr", &zero, true}, + {"int_ptr", &num, false}, + // {"nil_time_ptr", (*time.Time)(nil), true}, // TODO: fix this + {"time_ptr", &now, false}, + // {"nil_address_ptr", (*std.Address)(nil), true}, // TODO: fix this + {"address_ptr", &addr, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsZero(tt.input) + if got != tt.expected { + t.Errorf("%s: IsZero(%v) = %v, want %v", tt.name, tt.input, got, tt.expected) + } + }) + } +} + +func TestToInterfaceSlice(t *testing.T) { + now := time.Now() + addr := std.Address("g1jg8mtutu9khhfwc4nxmuhcpftf0pajdhfvsqf5") + str := testStringer{value: "hello"} + + tests := 
[]struct { + name string + input interface{} + expected []interface{} + compare func([]interface{}, []interface{}) bool + }{ + { + name: "nil", + input: nil, + expected: nil, + compare: compareNil, + }, + { + name: "empty_interface_slice", + input: []interface{}{}, + expected: []interface{}{}, + compare: compareEmpty, + }, + { + name: "interface_slice", + input: []interface{}{1, "two", true}, + expected: []interface{}{1, "two", true}, + compare: compareInterfaces, + }, + { + name: "string_slice", + input: []string{"a", "b", "c"}, + expected: []interface{}{"a", "b", "c"}, + compare: compareStrings, + }, + { + name: "int_slice", + input: []int{1, 2, 3}, + expected: []interface{}{1, 2, 3}, + compare: compareInts, + }, + { + name: "int32_slice", + input: []int32{1, 2, 3}, + expected: []interface{}{int32(1), int32(2), int32(3)}, + compare: compareInt32s, + }, + { + name: "int64_slice", + input: []int64{1, 2, 3}, + expected: []interface{}{int64(1), int64(2), int64(3)}, + compare: compareInt64s, + }, + { + name: "float32_slice", + input: []float32{1.1, 2.2, 3.3}, + expected: []interface{}{float32(1.1), float32(2.2), float32(3.3)}, + compare: compareFloat32s, + }, + { + name: "float64_slice", + input: []float64{1.1, 2.2, 3.3}, + expected: []interface{}{1.1, 2.2, 3.3}, + compare: compareFloat64s, + }, + { + name: "bool_slice", + input: []bool{true, false, true}, + expected: []interface{}{true, false, true}, + compare: compareBools, + }, + /* { + name: "time_slice", + input: []time.Time{now}, + expected: []interface{}{now}, + compare: compareTimes, + }, */ // TODO: fix this + /* { + name: "address_slice", + input: []std.Address{addr}, + expected: []interface{}{addr}, + compare: compareAddresses, + },*/ // TODO: fix this + /* { + name: "bytes_slice", + input: [][]byte{[]byte("hello"), []byte("world")}, + expected: []interface{}{[]byte("hello"), []byte("world")}, + compare: compareBytes, + },*/ // TODO: fix this + } + + for _, tt := range tests { + t.Run(tt.name, func(t 
*testing.T) { + got := ToInterfaceSlice(tt.input) + if !tt.compare(got, tt.expected) { + t.Errorf("ToInterfaceSlice() = %v, want %v", got, tt.expected) + } + }) + } +} + +func compareNil(a, b []interface{}) bool { + return a == nil && b == nil +} + +func compareEmpty(a, b []interface{}) bool { + return len(a) == 0 && len(b) == 0 +} + +func compareInterfaces(a, b []interface{}) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func compareStrings(a, b []interface{}) bool { + if len(a) != len(b) { + return false + } + for i := range a { + as, ok1 := a[i].(string) + bs, ok2 := b[i].(string) + if !ok1 || !ok2 || as != bs { + return false + } + } + return true +} + +func compareInts(a, b []interface{}) bool { + if len(a) != len(b) { + return false + } + for i := range a { + ai, ok1 := a[i].(int) + bi, ok2 := b[i].(int) + if !ok1 || !ok2 || ai != bi { + return false + } + } + return true +} + +func compareInt32s(a, b []interface{}) bool { + if len(a) != len(b) { + return false + } + for i := range a { + ai, ok1 := a[i].(int32) + bi, ok2 := b[i].(int32) + if !ok1 || !ok2 || ai != bi { + return false + } + } + return true +} + +func compareInt64s(a, b []interface{}) bool { + if len(a) != len(b) { + return false + } + for i := range a { + ai, ok1 := a[i].(int64) + bi, ok2 := b[i].(int64) + if !ok1 || !ok2 || ai != bi { + return false + } + } + return true +} + +func compareFloat32s(a, b []interface{}) bool { + if len(a) != len(b) { + return false + } + for i := range a { + ai, ok1 := a[i].(float32) + bi, ok2 := b[i].(float32) + if !ok1 || !ok2 || ai != bi { + return false + } + } + return true +} + +func compareFloat64s(a, b []interface{}) bool { + if len(a) != len(b) { + return false + } + for i := range a { + ai, ok1 := a[i].(float64) + bi, ok2 := b[i].(float64) + if !ok1 || !ok2 || ai != bi { + return false + } + } + return true +} + +func compareBools(a, b []interface{}) bool { + if 
len(a) != len(b) { + return false + } + for i := range a { + ab, ok1 := a[i].(bool) + bb, ok2 := b[i].(bool) + if !ok1 || !ok2 || ab != bb { + return false + } + } + return true +} + +func compareTimes(a, b []interface{}) bool { + if len(a) != len(b) { + return false + } + for i := range a { + at, ok1 := a[i].(time.Time) + bt, ok2 := b[i].(time.Time) + if !ok1 || !ok2 || !at.Equal(bt) { + return false + } + } + return true +} + +func compareAddresses(a, b []interface{}) bool { + if len(a) != len(b) { + return false + } + for i := range a { + aa, ok1 := a[i].(std.Address) + ba, ok2 := b[i].(std.Address) + if !ok1 || !ok2 || aa != ba { + return false + } + } + return true +} + +func compareBytes(a, b []interface{}) bool { + if len(a) != len(b) { + return false + } + for i := range a { + ab, ok1 := a[i].([]byte) + bb, ok2 := b[i].([]byte) + if !ok1 || !ok2 || string(ab) != string(bb) { + return false + } + } + return true +} + +// compareStringInterfaceMaps compares two map[string]interface{} for equality +func compareStringInterfaceMaps(a, b map[string]interface{}) bool { + if len(a) != len(b) { + return false + } + for k, v1 := range a { + v2, ok := b[k] + if !ok { + return false + } + // Compare values based on their type + switch val1 := v1.(type) { + case string: + val2, ok := v2.(string) + if !ok || val1 != val2 { + return false + } + case int: + val2, ok := v2.(int) + if !ok || val1 != val2 { + return false + } + case float64: + val2, ok := v2.(float64) + if !ok || val1 != val2 { + return false + } + case bool: + val2, ok := v2.(bool) + if !ok || val1 != val2 { + return false + } + case []interface{}: + val2, ok := v2.([]interface{}) + if !ok || len(val1) != len(val2) { + return false + } + for i := range val1 { + if val1[i] != val2[i] { + return false + } + } + case map[string]interface{}: + val2, ok := v2.(map[string]interface{}) + if !ok || !compareStringInterfaceMaps(val1, val2) { + return false + } + default: + return false + } + } + return true +} + +func 
TestToMapStringInterface(t *testing.T) { + tests := []struct { + name string + input interface{} + expected map[string]interface{} + wantErr bool + }{ + { + name: "map[string]interface{}", + input: map[string]interface{}{ + "key1": "value1", + "key2": 42, + }, + expected: map[string]interface{}{ + "key1": "value1", + "key2": 42, + }, + wantErr: false, + }, + { + name: "map[string]string", + input: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + expected: map[string]interface{}{ + "key1": "value1", + "key2": "value2", + }, + wantErr: false, + }, + { + name: "map[string]int", + input: map[string]int{ + "key1": 1, + "key2": 2, + }, + expected: map[string]interface{}{ + "key1": 1, + "key2": 2, + }, + wantErr: false, + }, + { + name: "map[string]float64", + input: map[string]float64{ + "key1": 1.1, + "key2": 2.2, + }, + expected: map[string]interface{}{ + "key1": 1.1, + "key2": 2.2, + }, + wantErr: false, + }, + { + name: "map[string]bool", + input: map[string]bool{ + "key1": true, + "key2": false, + }, + expected: map[string]interface{}{ + "key1": true, + "key2": false, + }, + wantErr: false, + }, + { + name: "map[string][]string", + input: map[string][]string{ + "key1": {"a", "b"}, + "key2": {"c", "d"}, + }, + expected: map[string]interface{}{ + "key1": []interface{}{"a", "b"}, + "key2": []interface{}{"c", "d"}, + }, + wantErr: false, + }, + { + name: "nested map[string]map[string]string", + input: map[string]map[string]string{ + "key1": {"nested1": "value1"}, + "key2": {"nested2": "value2"}, + }, + expected: map[string]interface{}{ + "key1": map[string]interface{}{"nested1": "value1"}, + "key2": map[string]interface{}{"nested2": "value2"}, + }, + wantErr: false, + }, + { + name: "unsupported type", + input: 42, // not a map + expected: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ToMapStringInterface(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ToMapStringInterface() 
error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if !compareStringInterfaceMaps(got, tt.expected) { + t.Errorf("ToMapStringInterface() = %v, expected %v", got, tt.expected) + } + } + }) + } +} + +// Test error messages +func TestToMapStringInterfaceErrors(t *testing.T) { + _, err := ToMapStringInterface(42) + if err == nil || !strings.Contains(err.Error(), "unsupported map type") { + t.Errorf("Expected error containing 'unsupported map type', got %v", err) + } +} + +// compareIntInterfaceMaps compares two map[int]interface{} for equality +func compareIntInterfaceMaps(a, b map[int]interface{}) bool { + if len(a) != len(b) { + return false + } + for k, v1 := range a { + v2, ok := b[k] + if !ok { + return false + } + // Compare values based on their type + switch val1 := v1.(type) { + case string: + val2, ok := v2.(string) + if !ok || val1 != val2 { + return false + } + case int: + val2, ok := v2.(int) + if !ok || val1 != val2 { + return false + } + case float64: + val2, ok := v2.(float64) + if !ok || val1 != val2 { + return false + } + case bool: + val2, ok := v2.(bool) + if !ok || val1 != val2 { + return false + } + case []interface{}: + val2, ok := v2.([]interface{}) + if !ok || len(val1) != len(val2) { + return false + } + for i := range val1 { + if val1[i] != val2[i] { + return false + } + } + case map[string]interface{}: + val2, ok := v2.(map[string]interface{}) + if !ok || !compareStringInterfaceMaps(val1, val2) { + return false + } + default: + return false + } + } + return true +} + +func TestToMapIntInterface(t *testing.T) { + tests := []struct { + name string + input interface{} + expected map[int]interface{} + wantErr bool + }{ + { + name: "map[int]interface{}", + input: map[int]interface{}{ + 1: "value1", + 2: 42, + }, + expected: map[int]interface{}{ + 1: "value1", + 2: 42, + }, + wantErr: false, + }, + { + name: "map[int]string", + input: map[int]string{ + 1: "value1", + 2: "value2", + }, + expected: map[int]interface{}{ + 1: 
"value1", + 2: "value2", + }, + wantErr: false, + }, + { + name: "map[int]int", + input: map[int]int{ + 1: 10, + 2: 20, + }, + expected: map[int]interface{}{ + 1: 10, + 2: 20, + }, + wantErr: false, + }, + { + name: "map[int]float64", + input: map[int]float64{ + 1: 1.1, + 2: 2.2, + }, + expected: map[int]interface{}{ + 1: 1.1, + 2: 2.2, + }, + wantErr: false, + }, + { + name: "map[int]bool", + input: map[int]bool{ + 1: true, + 2: false, + }, + expected: map[int]interface{}{ + 1: true, + 2: false, + }, + wantErr: false, + }, + { + name: "map[int][]string", + input: map[int][]string{ + 1: {"a", "b"}, + 2: {"c", "d"}, + }, + expected: map[int]interface{}{ + 1: []interface{}{"a", "b"}, + 2: []interface{}{"c", "d"}, + }, + wantErr: false, + }, + { + name: "map[int]map[string]interface{}", + input: map[int]map[string]interface{}{ + 1: {"nested1": "value1"}, + 2: {"nested2": "value2"}, + }, + expected: map[int]interface{}{ + 1: map[string]interface{}{"nested1": "value1"}, + 2: map[string]interface{}{"nested2": "value2"}, + }, + wantErr: false, + }, + { + name: "unsupported type", + input: 42, // not a map + expected: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ToMapIntInterface(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ToMapIntInterface() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if !compareIntInterfaceMaps(got, tt.expected) { + t.Errorf("ToMapIntInterface() = %v, expected %v", got, tt.expected) + } + } + }) + } +} + +func TestToStringSlice(t *testing.T) { + tests := []struct { + name string + input interface{} + expected []string + }{ + { + name: "nil input", + input: nil, + expected: nil, + }, + { + name: "empty slice", + input: []string{}, + expected: []string{}, + }, + { + name: "string slice", + input: []string{"a", "b", "c"}, + expected: []string{"a", "b", "c"}, + }, + { + name: "int slice", + input: []int{1, 2, 3}, + expected: []string{"1", "2", 
"3"}, + }, + { + name: "int32 slice", + input: []int32{1, 2, 3}, + expected: []string{"1", "2", "3"}, + }, + { + name: "int64 slice", + input: []int64{1, 2, 3}, + expected: []string{"1", "2", "3"}, + }, + { + name: "uint slice", + input: []uint{1, 2, 3}, + expected: []string{"1", "2", "3"}, + }, + { + name: "uint8 slice", + input: []uint8{1, 2, 3}, + expected: []string{"1", "2", "3"}, + }, + { + name: "uint16 slice", + input: []uint16{1, 2, 3}, + expected: []string{"1", "2", "3"}, + }, + { + name: "uint32 slice", + input: []uint32{1, 2, 3}, + expected: []string{"1", "2", "3"}, + }, + { + name: "uint64 slice", + input: []uint64{1, 2, 3}, + expected: []string{"1", "2", "3"}, + }, + { + name: "float32 slice", + input: []float32{1.1, 2.2, 3.3}, + expected: []string{"1.1", "2.2", "3.3"}, + }, + { + name: "float64 slice", + input: []float64{1.1, 2.2, 3.3}, + expected: []string{"1.1", "2.2", "3.3"}, + }, + { + name: "bool slice", + input: []bool{true, false, true}, + expected: []string{"true", "false", "true"}, + }, + { + name: "[]byte slice", + input: [][]byte{[]byte("hello"), []byte("world")}, + expected: []string{"hello", "world"}, + }, + { + name: "interface slice", + input: []interface{}{1, "hello", true}, + expected: []string{"1", "hello", "true"}, + }, + { + name: "time slice", + input: []time.Time{time.Time{}, time.Time{}}, + expected: []string{"0001-01-01 00:00:00 +0000 UTC", "0001-01-01 00:00:00 +0000 UTC"}, + }, + { + name: "address slice", + input: []std.Address{"addr1", "addr2"}, + expected: []string{"addr1", "addr2"}, + }, + { + name: "non-slice input", + input: 42, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ToStringSlice(tt.input) + if !slicesEqual(result, tt.expected) { + t.Errorf("ToStringSlice(%v) = %v, want %v", tt.input, result, tt.expected) + } + }) + } +} + +// Helper function to compare string slices +func slicesEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + 
for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func TestToStringAdvanced(t *testing.T) { + tests := []struct { + name string + input interface{} + expected string + }{ + { + name: "slice with mixed basic types", + input: []interface{}{ + 42, + "hello", + true, + 3.14, + }, + expected: "[42 hello true 3.14]", + }, + { + name: "map with basic types", + input: map[string]interface{}{ + "int": 42, + "str": "hello", + "bool": true, + "float": 3.14, + }, + expected: "map[bool:true float:3.14 int:42 str:hello]", + }, + { + name: "mixed types map", + input: map[interface{}]interface{}{ + 42: "number", + "string": 123, + true: []int{1, 2, 3}, + struct{}{}: "empty", + }, + expected: "map[42:number string:123 true:[1 2 3] {}:empty]", + }, + { + name: "nested maps", + input: map[string]interface{}{ + "a": map[string]int{ + "x": 1, + "y": 2, + }, + "b": []interface{}{1, "two", true}, + }, + expected: "map[a:map[x:1 y:2] b:[1 two true]]", + }, + { + name: "empty struct", + input: struct{}{}, + expected: "{}", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ToString(tt.input) + if result != tt.expected { + t.Errorf("\nToString(%v) =\n%v\nwant:\n%v", tt.input, result, tt.expected) + } + }) + } +} diff --git a/examples/gno.land/p/moul/ulist/gno.mod b/examples/gno.land/p/moul/ulist/gno.mod new file mode 100644 index 00000000000..077f8c556f3 --- /dev/null +++ b/examples/gno.land/p/moul/ulist/gno.mod @@ -0,0 +1 @@ +module gno.land/p/moul/ulist diff --git a/examples/gno.land/p/moul/ulist/ulist.gno b/examples/gno.land/p/moul/ulist/ulist.gno new file mode 100644 index 00000000000..507a02a4e45 --- /dev/null +++ b/examples/gno.land/p/moul/ulist/ulist.gno @@ -0,0 +1,437 @@ +// Package ulist provides an append-only list implementation using a binary tree structure, +// optimized for scenarios requiring sequential inserts with auto-incrementing indices. 
+// +// The implementation uses a binary tree where new elements are added by following a path +// determined by the binary representation of the index. This provides automatic balancing +// for append operations without requiring any balancing logic. +// +// Unlike the AVL tree-based list implementation (p/demo/avl/list), ulist is specifically +// designed for append-only operations and does not require rebalancing. This makes it more +// efficient for sequential inserts but less flexible for general-purpose list operations. +// +// Key differences from AVL list: +// * Append-only design (no arbitrary inserts) +// * No tree rebalancing needed +// * Simpler implementation +// * More memory efficient for sequential operations +// * Less flexible than AVL (no arbitrary inserts/reordering) +// +// Key characteristics: +// * O(log n) append and access operations +// * Perfect balance for power-of-2 sizes +// * No balancing needed +// * Memory efficient +// * Natural support for range queries +// * Support for soft deletion of elements +// * Forward and reverse iteration capabilities +// * Offset-based iteration with count control +package ulist + +// TODO: Use this ulist in moul/collection for the primary index. +// TODO: Benchmarks. + +import ( + "errors" +) + +// List represents an append-only binary tree list +type List struct { + root *treeNode + totalSize int + activeSize int +} + +// Entry represents a key-value pair in the list, where Index is the position +// and Value is the stored data +type Entry struct { + Index int + Value interface{} +} + +// treeNode represents a node in the binary tree +type treeNode struct { + data interface{} + left *treeNode + right *treeNode +} + +// Error variables +var ( + ErrOutOfBounds = errors.New("index out of bounds") + ErrDeleted = errors.New("element already deleted") +) + +// New creates a new empty List instance +func New() *List { + return &List{} +} + +// Append adds one or more values to the end of the list. 
+// Values are added sequentially, and the list grows automatically. +func (l *List) Append(values ...interface{}) { + for _, value := range values { + index := l.totalSize + node := l.findNode(index, true) + node.data = value + l.totalSize++ + l.activeSize++ + } +} + +// Get retrieves the value at the specified index. +// Returns nil if the index is out of bounds or if the element was deleted. +func (l *List) Get(index int) interface{} { + node := l.findNode(index, false) + if node == nil { + return nil + } + return node.data +} + +// Delete marks the elements at the specified indices as deleted. +// Returns ErrOutOfBounds if any index is invalid or ErrDeleted if +// the element was already deleted. +func (l *List) Delete(indices ...int) error { + if len(indices) == 0 { + return nil + } + if l == nil || l.totalSize == 0 { + return ErrOutOfBounds + } + + for _, index := range indices { + if index < 0 || index >= l.totalSize { + return ErrOutOfBounds + } + + node := l.findNode(index, false) + if node == nil || node.data == nil { + return ErrDeleted + } + node.data = nil + l.activeSize-- + } + + return nil +} + +// Set updates or restores a value at the specified index if within bounds +// Returns ErrOutOfBounds if the index is invalid +func (l *List) Set(index int, value interface{}) error { + if l == nil || index < 0 || index >= l.totalSize { + return ErrOutOfBounds + } + + node := l.findNode(index, false) + if node == nil { + return ErrOutOfBounds + } + + // If this is restoring a deleted element + if value != nil && node.data == nil { + l.activeSize++ + } + + // If this is deleting an element + if value == nil && node.data != nil { + l.activeSize-- + } + + node.data = value + return nil +} + +// Size returns the number of active (non-deleted) elements in the list +func (l *List) Size() int { + if l == nil { + return 0 + } + return l.activeSize +} + +// TotalSize returns the total number of elements ever added to the list, +// including deleted elements +func (l 
*List) TotalSize() int { + if l == nil { + return 0 + } + return l.totalSize +} + +// IterCbFn is a callback function type used in iteration methods. +// Return true to stop iteration, false to continue. +type IterCbFn func(index int, value interface{}) bool + +// Iterator performs iteration between start and end indices, calling cb for each entry. +// If start > end, iteration is performed in reverse order. +// Returns true if iteration was stopped early by the callback returning true. +// Skips deleted elements. +func (l *List) Iterator(start, end int, cb IterCbFn) bool { + // For empty list or invalid range + if l == nil || l.totalSize == 0 { + return false + } + if start < 0 && end < 0 { + return false + } + if start >= l.totalSize && end >= l.totalSize { + return false + } + + // Normalize indices + if start < 0 { + start = 0 + } + if end < 0 { + end = 0 + } + if end >= l.totalSize { + end = l.totalSize - 1 + } + if start >= l.totalSize { + start = l.totalSize - 1 + } + + // Handle reverse iteration + if start > end { + for i := start; i >= end; i-- { + val := l.Get(i) + if val != nil { + if cb(i, val) { + return true + } + } + } + return false + } + + // Handle forward iteration + for i := start; i <= end; i++ { + val := l.Get(i) + if val != nil { + if cb(i, val) { + return true + } + } + } + return false +} + +// IteratorByOffset performs iteration starting from offset for count elements. +// If count is positive, iterates forward; if negative, iterates backward. +// The iteration stops after abs(count) elements or when reaching list bounds. +// Skips deleted elements. 
+func (l *List) IteratorByOffset(offset int, count int, cb IterCbFn) bool { + if count == 0 || l == nil || l.totalSize == 0 { + return false + } + + // Normalize offset + if offset < 0 { + offset = 0 + } + if offset >= l.totalSize { + offset = l.totalSize - 1 + } + + // Determine end based on count direction + var end int + if count > 0 { + end = l.totalSize - 1 + } else { + end = 0 + } + + wrapperReturned := false + + // Wrap the callback to limit iterations + remaining := abs(count) + wrapper := func(index int, value interface{}) bool { + if remaining <= 0 { + wrapperReturned = true + return true + } + remaining-- + return cb(index, value) + } + ret := l.Iterator(offset, end, wrapper) + if wrapperReturned { + return false + } + return ret +} + +// abs returns the absolute value of x +func abs(x int) int { + if x < 0 { + return -x + } + return x +} + +// findNode locates or creates a node at the given index in the binary tree. +// The tree is structured such that the path to a node is determined by the binary +// representation of the index. For example, a tree with 15 elements would look like: +// +// 0 +// / \ +// 1 2 +// / \ / \ +// 3 4 5 6 +// / \ / \ / \ / \ +// 7 8 9 10 11 12 13 14 +// +// To find index 13 (binary 1101): +// 1. Start at root (0) +// 2. Calculate bits needed (4 bits for index 13) +// 3. Skip the highest bit position and start from bits-2 +// 4. 
Read bits from left to right: +// - 1 -> go right to 2 +// - 1 -> go right to 6 +// - 0 -> go left to 13 +// +// Special cases: +// - Index 0 always returns the root node +// - For create=true, missing nodes are created along the path +// - For create=false, returns nil if any node is missing +func (l *List) findNode(index int, create bool) *treeNode { + // For read operations, check bounds strictly + if !create && (l == nil || index < 0 || index >= l.totalSize) { + return nil + } + + // For create operations, allow index == totalSize for append + if create && (l == nil || index < 0 || index > l.totalSize) { + return nil + } + + // Initialize root if needed + if l.root == nil { + if !create { + return nil + } + l.root = &treeNode{} + return l.root + } + + node := l.root + + // Special case for root node + if index == 0 { + return node + } + + // Calculate the number of bits needed (inline highestBit logic) + bits := 0 + n := index + 1 + for n > 0 { + n >>= 1 + bits++ + } + + // Start from the second highest bit + for level := bits - 2; level >= 0; level-- { + bit := (index & (1 << uint(level))) != 0 + + if bit { + if node.right == nil { + if !create { + return nil + } + node.right = &treeNode{} + } + node = node.right + } else { + if node.left == nil { + if !create { + return nil + } + node.left = &treeNode{} + } + node = node.left + } + } + + return node +} + +// MustDelete deletes elements at the specified indices. +// Panics if any index is invalid or if any element was already deleted. +func (l *List) MustDelete(indices ...int) { + if err := l.Delete(indices...); err != nil { + panic(err) + } +} + +// MustGet retrieves the value at the specified index. +// Panics if the index is out of bounds or if the element was deleted. 
+func (l *List) MustGet(index int) interface{} { + if l == nil || index < 0 || index >= l.totalSize { + panic(ErrOutOfBounds) + } + value := l.Get(index) + if value == nil { + panic(ErrDeleted) + } + return value +} + +// MustSet updates or restores a value at the specified index. +// Panics if the index is out of bounds. +func (l *List) MustSet(index int, value interface{}) { + if err := l.Set(index, value); err != nil { + panic(err) + } +} + +// GetRange returns a slice of Entry containing elements between start and end indices. +// If start > end, elements are returned in reverse order. +// Deleted elements are skipped. +func (l *List) GetRange(start, end int) []Entry { + var entries []Entry + l.Iterator(start, end, func(index int, value interface{}) bool { + entries = append(entries, Entry{Index: index, Value: value}) + return false + }) + return entries +} + +// GetByOffset returns a slice of Entry starting from offset for count elements. +// If count is positive, returns elements forward; if negative, returns elements backward. +// The operation stops after abs(count) elements or when reaching list bounds. +// Deleted elements are skipped. +func (l *List) GetByOffset(offset int, count int) []Entry { + var entries []Entry + l.IteratorByOffset(offset, count, func(index int, value interface{}) bool { + entries = append(entries, Entry{Index: index, Value: value}) + return false + }) + return entries +} + +// IList defines the interface for an ulist.List compatible structure. 
+type IList interface { + // Basic operations + Append(values ...interface{}) + Get(index int) interface{} + Delete(indices ...int) error + Size() int + TotalSize() int + Set(index int, value interface{}) error + + // Must variants that panic instead of returning errors + MustDelete(indices ...int) + MustGet(index int) interface{} + MustSet(index int, value interface{}) + + // Range operations + GetRange(start, end int) []Entry + GetByOffset(offset int, count int) []Entry + + // Iterator operations + Iterator(start, end int, cb IterCbFn) bool + IteratorByOffset(offset int, count int, cb IterCbFn) bool +} + +// Verify that List implements IList +var _ IList = (*List)(nil) diff --git a/examples/gno.land/p/moul/ulist/ulist_test.gno b/examples/gno.land/p/moul/ulist/ulist_test.gno new file mode 100644 index 00000000000..f098731a7db --- /dev/null +++ b/examples/gno.land/p/moul/ulist/ulist_test.gno @@ -0,0 +1,1422 @@ +package ulist + +import ( + "testing" + + "gno.land/p/demo/uassert" + "gno.land/p/demo/ufmt" + "gno.land/p/moul/typeutil" +) + +func TestNew(t *testing.T) { + l := New() + uassert.Equal(t, 0, l.Size()) + uassert.Equal(t, 0, l.TotalSize()) +} + +func TestListAppendAndGet(t *testing.T) { + tests := []struct { + name string + setup func() *List + index int + expected interface{} + }{ + { + name: "empty list", + setup: func() *List { + return New() + }, + index: 0, + expected: nil, + }, + { + name: "single append and get", + setup: func() *List { + l := New() + l.Append(42) + return l + }, + index: 0, + expected: 42, + }, + { + name: "multiple appends and get first", + setup: func() *List { + l := New() + l.Append(1) + l.Append(2) + l.Append(3) + return l + }, + index: 0, + expected: 1, + }, + { + name: "multiple appends and get last", + setup: func() *List { + l := New() + l.Append(1) + l.Append(2) + l.Append(3) + return l + }, + index: 2, + expected: 3, + }, + { + name: "get with invalid index", + setup: func() *List { + l := New() + l.Append(1) + return l + 
}, + index: 1, + expected: nil, + }, + { + name: "31 items get first", + setup: func() *List { + l := New() + for i := 0; i < 31; i++ { + l.Append(i) + } + return l + }, + index: 0, + expected: 0, + }, + { + name: "31 items get last", + setup: func() *List { + l := New() + for i := 0; i < 31; i++ { + l.Append(i) + } + return l + }, + index: 30, + expected: 30, + }, + { + name: "31 items get middle", + setup: func() *List { + l := New() + for i := 0; i < 31; i++ { + l.Append(i) + } + return l + }, + index: 15, + expected: 15, + }, + { + name: "values around power of 2 boundary", + setup: func() *List { + l := New() + for i := 0; i < 18; i++ { + l.Append(i) + } + return l + }, + index: 15, + expected: 15, + }, + { + name: "values at power of 2", + setup: func() *List { + l := New() + for i := 0; i < 18; i++ { + l.Append(i) + } + return l + }, + index: 16, + expected: 16, + }, + { + name: "values after power of 2", + setup: func() *List { + l := New() + for i := 0; i < 18; i++ { + l.Append(i) + } + return l + }, + index: 17, + expected: 17, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := tt.setup() + got := l.Get(tt.index) + if got != tt.expected { + t.Errorf("List.Get() = %v, want %v", got, tt.expected) + } + }) + } +} + +// generateSequence creates a slice of integers from 0 to n-1 +func generateSequence(n int) []interface{} { + result := make([]interface{}, n) + for i := 0; i < n; i++ { + result[i] = i + } + return result +} + +func TestListDelete(t *testing.T) { + tests := []struct { + name string + setup func() *List + deleteIndices []int + expectedErr error + expectedSize int + }{ + { + name: "delete single element", + setup: func() *List { + l := New() + l.Append(1, 2, 3) + return l + }, + deleteIndices: []int{1}, + expectedErr: nil, + expectedSize: 2, + }, + { + name: "delete multiple elements", + setup: func() *List { + l := New() + l.Append(1, 2, 3, 4, 5) + return l + }, + deleteIndices: []int{0, 2, 4}, + expectedErr: 
nil, + expectedSize: 2, + }, + { + name: "delete with negative index", + setup: func() *List { + l := New() + l.Append(1) + return l + }, + deleteIndices: []int{-1}, + expectedErr: ErrOutOfBounds, + expectedSize: 1, + }, + { + name: "delete beyond size", + setup: func() *List { + l := New() + l.Append(1) + return l + }, + deleteIndices: []int{1}, + expectedErr: ErrOutOfBounds, + expectedSize: 1, + }, + { + name: "delete already deleted element", + setup: func() *List { + l := New() + l.Append(1) + l.Delete(0) + return l + }, + deleteIndices: []int{0}, + expectedErr: ErrDeleted, + expectedSize: 0, + }, + { + name: "delete multiple elements in reverse", + setup: func() *List { + l := New() + l.Append(1, 2, 3, 4, 5) + return l + }, + deleteIndices: []int{4, 2, 0}, + expectedErr: nil, + expectedSize: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := tt.setup() + initialSize := l.Size() + err := l.Delete(tt.deleteIndices...) + if err != nil && tt.expectedErr != nil { + uassert.Equal(t, tt.expectedErr.Error(), err.Error()) + } else { + uassert.Equal(t, tt.expectedErr, err) + } + uassert.Equal(t, tt.expectedSize, l.Size(), + ufmt.Sprintf("Expected size %d after deleting %d elements from size %d, got %d", + tt.expectedSize, len(tt.deleteIndices), initialSize, l.Size())) + }) + } +} + +func TestListSizeAndTotalSize(t *testing.T) { + t.Run("empty list", func(t *testing.T) { + list := New() + uassert.Equal(t, 0, list.Size()) + uassert.Equal(t, 0, list.TotalSize()) + }) + + t.Run("list with elements", func(t *testing.T) { + list := New() + list.Append(1) + list.Append(2) + list.Append(3) + uassert.Equal(t, 3, list.Size()) + uassert.Equal(t, 3, list.TotalSize()) + }) + + t.Run("list with deleted elements", func(t *testing.T) { + list := New() + list.Append(1) + list.Append(2) + list.Append(3) + list.Delete(1) + uassert.Equal(t, 2, list.Size()) + uassert.Equal(t, 3, list.TotalSize()) + }) +} + +func TestIterator(t *testing.T) { + tests := 
[]struct { + name string + values []interface{} + start int + end int + expected []Entry + wantStop bool + stopAfter int // stop after N elements, -1 for no stop + }{ + { + name: "empty list", + values: []interface{}{}, + start: 0, + end: 10, + expected: []Entry{}, + stopAfter: -1, + }, + { + name: "nil list", + values: nil, + start: 0, + end: 0, + expected: []Entry{}, + stopAfter: -1, + }, + { + name: "single element forward", + values: []interface{}{42}, + start: 0, + end: 0, + expected: []Entry{ + {Index: 0, Value: 42}, + }, + stopAfter: -1, + }, + { + name: "multiple elements forward", + values: []interface{}{1, 2, 3, 4, 5}, + start: 0, + end: 4, + expected: []Entry{ + {Index: 0, Value: 1}, + {Index: 1, Value: 2}, + {Index: 2, Value: 3}, + {Index: 3, Value: 4}, + {Index: 4, Value: 5}, + }, + stopAfter: -1, + }, + { + name: "multiple elements reverse", + values: []interface{}{1, 2, 3, 4, 5}, + start: 4, + end: 0, + expected: []Entry{ + {Index: 4, Value: 5}, + {Index: 3, Value: 4}, + {Index: 2, Value: 3}, + {Index: 1, Value: 2}, + {Index: 0, Value: 1}, + }, + stopAfter: -1, + }, + { + name: "partial range forward", + values: []interface{}{1, 2, 3, 4, 5}, + start: 1, + end: 3, + expected: []Entry{ + {Index: 1, Value: 2}, + {Index: 2, Value: 3}, + {Index: 3, Value: 4}, + }, + stopAfter: -1, + }, + { + name: "partial range reverse", + values: []interface{}{1, 2, 3, 4, 5}, + start: 3, + end: 1, + expected: []Entry{ + {Index: 3, Value: 4}, + {Index: 2, Value: 3}, + {Index: 1, Value: 2}, + }, + stopAfter: -1, + }, + { + name: "stop iteration early", + values: []interface{}{1, 2, 3, 4, 5}, + start: 0, + end: 4, + wantStop: true, + stopAfter: 2, + expected: []Entry{ + {Index: 0, Value: 1}, + {Index: 1, Value: 2}, + }, + }, + { + name: "negative start", + values: []interface{}{1, 2, 3}, + start: -1, + end: 2, + expected: []Entry{ + {Index: 0, Value: 1}, + {Index: 1, Value: 2}, + {Index: 2, Value: 3}, + }, + stopAfter: -1, + }, + { + name: "negative end", + values: 
[]interface{}{1, 2, 3}, + start: 0, + end: -2, + expected: []Entry{ + {Index: 0, Value: 1}, + }, + stopAfter: -1, + }, + { + name: "start beyond size", + values: []interface{}{1, 2, 3}, + start: 5, + end: 6, + expected: []Entry{}, + stopAfter: -1, + }, + { + name: "end beyond size", + values: []interface{}{1, 2, 3}, + start: 0, + end: 5, + expected: []Entry{ + {Index: 0, Value: 1}, + {Index: 1, Value: 2}, + {Index: 2, Value: 3}, + }, + stopAfter: -1, + }, + { + name: "with deleted elements", + values: []interface{}{1, 2, nil, 4, 5}, + start: 0, + end: 4, + expected: []Entry{ + {Index: 0, Value: 1}, + {Index: 1, Value: 2}, + {Index: 3, Value: 4}, + {Index: 4, Value: 5}, + }, + stopAfter: -1, + }, + { + name: "with deleted elements reverse", + values: []interface{}{1, nil, 3, nil, 5}, + start: 4, + end: 0, + expected: []Entry{ + {Index: 4, Value: 5}, + {Index: 2, Value: 3}, + {Index: 0, Value: 1}, + }, + stopAfter: -1, + }, + { + name: "start equals end", + values: []interface{}{1, 2, 3}, + start: 1, + end: 1, + expected: []Entry{{Index: 1, Value: 2}}, + stopAfter: -1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + list := New() + list.Append(tt.values...) 
+ + var result []Entry + stopped := list.Iterator(tt.start, tt.end, func(index int, value interface{}) bool { + result = append(result, Entry{Index: index, Value: value}) + return tt.stopAfter >= 0 && len(result) >= tt.stopAfter + }) + + uassert.Equal(t, len(result), len(tt.expected), "comparing length") + + for i := range result { + uassert.Equal(t, result[i].Index, tt.expected[i].Index, "comparing index") + uassert.Equal(t, typeutil.ToString(result[i].Value), typeutil.ToString(tt.expected[i].Value), "comparing value") + } + + uassert.Equal(t, stopped, tt.wantStop, "comparing stopped") + }) + } +} + +func TestLargeListAppendGetAndDelete(t *testing.T) { + l := New() + size := 100 + + // Append values from 0 to 99 + for i := 0; i < size; i++ { + l.Append(i) + val := l.Get(i) + uassert.Equal(t, i, val) + } + + // Verify size + uassert.Equal(t, size, l.Size()) + uassert.Equal(t, size, l.TotalSize()) + + // Get and verify each value + for i := 0; i < size; i++ { + val := l.Get(i) + uassert.Equal(t, i, val) + } + + // Get and verify each value + for i := 0; i < size; i++ { + err := l.Delete(i) + uassert.Equal(t, nil, err) + } + + // Verify size + uassert.Equal(t, 0, l.Size()) + uassert.Equal(t, size, l.TotalSize()) + + // Get and verify each value + for i := 0; i < size; i++ { + val := l.Get(i) + uassert.Equal(t, nil, val) + } +} + +func TestEdgeCases(t *testing.T) { + tests := []struct { + name string + test func(t *testing.T) + }{ + { + name: "nil list operations", + test: func(t *testing.T) { + var l *List + uassert.Equal(t, 0, l.Size()) + uassert.Equal(t, 0, l.TotalSize()) + uassert.Equal(t, nil, l.Get(0)) + err := l.Delete(0) + uassert.Equal(t, ErrOutOfBounds.Error(), err.Error()) + }, + }, + { + name: "delete empty indices slice", + test: func(t *testing.T) { + l := New() + l.Append(1) + err := l.Delete() + uassert.Equal(t, nil, err) + uassert.Equal(t, 1, l.Size()) + }, + }, + { + name: "append nil values", + test: func(t *testing.T) { + l := New() + l.Append(nil, 
nil) + uassert.Equal(t, 2, l.Size()) + uassert.Equal(t, nil, l.Get(0)) + uassert.Equal(t, nil, l.Get(1)) + }, + }, + { + name: "delete same index multiple times", + test: func(t *testing.T) { + l := New() + l.Append(1, 2, 3) + err := l.Delete(1) + uassert.Equal(t, nil, err) + err = l.Delete(1) + uassert.Equal(t, ErrDeleted.Error(), err.Error()) + }, + }, + { + name: "iterator with all deleted elements", + test: func(t *testing.T) { + l := New() + l.Append(1, 2, 3) + l.Delete(0, 1, 2) + var count int + l.Iterator(0, 2, func(index int, value interface{}) bool { + count++ + return false + }) + uassert.Equal(t, 0, count) + }, + }, + { + name: "append after delete", + test: func(t *testing.T) { + l := New() + l.Append(1, 2) + l.Delete(1) + l.Append(3) + uassert.Equal(t, 2, l.Size()) + uassert.Equal(t, 3, l.TotalSize()) + uassert.Equal(t, 1, l.Get(0)) + uassert.Equal(t, nil, l.Get(1)) + uassert.Equal(t, 3, l.Get(2)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.test(t) + }) + } +} + +func TestIteratorByOffset(t *testing.T) { + tests := []struct { + name string + values []interface{} + offset int + count int + expected []Entry + wantStop bool + }{ + { + name: "empty list", + values: []interface{}{}, + offset: 0, + count: 5, + expected: []Entry{}, + wantStop: false, + }, + { + name: "positive count forward iteration", + values: []interface{}{1, 2, 3, 4, 5}, + offset: 1, + count: 2, + expected: []Entry{ + {Index: 1, Value: 2}, + {Index: 2, Value: 3}, + }, + wantStop: false, + }, + { + name: "negative count backward iteration", + values: []interface{}{1, 2, 3, 4, 5}, + offset: 3, + count: -2, + expected: []Entry{ + {Index: 3, Value: 4}, + {Index: 2, Value: 3}, + }, + wantStop: false, + }, + { + name: "count exceeds available elements forward", + values: []interface{}{1, 2, 3}, + offset: 1, + count: 5, + expected: []Entry{ + {Index: 1, Value: 2}, + {Index: 2, Value: 3}, + }, + wantStop: false, + }, + { + name: "count exceeds available 
elements backward", + values: []interface{}{1, 2, 3}, + offset: 1, + count: -5, + expected: []Entry{ + {Index: 1, Value: 2}, + {Index: 0, Value: 1}, + }, + wantStop: false, + }, + { + name: "zero count", + values: []interface{}{1, 2, 3}, + offset: 0, + count: 0, + expected: []Entry{}, + wantStop: false, + }, + { + name: "negative offset", + values: []interface{}{1, 2, 3}, + offset: -1, + count: 2, + expected: []Entry{ + {Index: 0, Value: 1}, + {Index: 1, Value: 2}, + }, + wantStop: false, + }, + { + name: "offset beyond size", + values: []interface{}{1, 2, 3}, + offset: 5, + count: -2, + expected: []Entry{ + {Index: 2, Value: 3}, + {Index: 1, Value: 2}, + }, + wantStop: false, + }, + { + name: "with deleted elements", + values: []interface{}{1, nil, 3, nil, 5}, + offset: 0, + count: 3, + expected: []Entry{ + {Index: 0, Value: 1}, + {Index: 2, Value: 3}, + {Index: 4, Value: 5}, + }, + wantStop: false, + }, + { + name: "early stop in forward iteration", + values: []interface{}{1, 2, 3, 4, 5}, + offset: 0, + count: 5, + expected: []Entry{ + {Index: 0, Value: 1}, + {Index: 1, Value: 2}, + }, + wantStop: true, // The callback will return true after 2 elements + }, + { + name: "early stop in backward iteration", + values: []interface{}{1, 2, 3, 4, 5}, + offset: 4, + count: -5, + expected: []Entry{ + {Index: 4, Value: 5}, + {Index: 3, Value: 4}, + }, + wantStop: true, // The callback will return true after 2 elements + }, + { + name: "nil list", + values: nil, + offset: 0, + count: 5, + expected: []Entry{}, + wantStop: false, + }, + { + name: "single element forward", + values: []interface{}{1}, + offset: 0, + count: 5, + expected: []Entry{ + {Index: 0, Value: 1}, + }, + wantStop: false, + }, + { + name: "single element backward", + values: []interface{}{1}, + offset: 0, + count: -5, + expected: []Entry{ + {Index: 0, Value: 1}, + }, + wantStop: false, + }, + { + name: "all deleted elements", + values: []interface{}{nil, nil, nil}, + offset: 0, + count: 3, + expected: 
[]Entry{}, + wantStop: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + list := New() + list.Append(tt.values...) + + var result []Entry + var cb IterCbFn + if tt.wantStop { + cb = func(index int, value interface{}) bool { + result = append(result, Entry{Index: index, Value: value}) + return len(result) >= 2 // Stop after 2 elements for early stop tests + } + } else { + cb = func(index int, value interface{}) bool { + result = append(result, Entry{Index: index, Value: value}) + return false + } + } + + stopped := list.IteratorByOffset(tt.offset, tt.count, cb) + + uassert.Equal(t, len(tt.expected), len(result), "comparing length") + for i := range result { + uassert.Equal(t, tt.expected[i].Index, result[i].Index, "comparing index") + uassert.Equal(t, typeutil.ToString(tt.expected[i].Value), typeutil.ToString(result[i].Value), "comparing value") + } + uassert.Equal(t, tt.wantStop, stopped, "comparing stopped") + }) + } +} + +func TestMustDelete(t *testing.T) { + tests := []struct { + name string + setup func() *List + indices []int + shouldPanic bool + panicMsg string + }{ + { + name: "successful delete", + setup: func() *List { + l := New() + l.Append(1, 2, 3) + return l + }, + indices: []int{1}, + shouldPanic: false, + }, + { + name: "out of bounds", + setup: func() *List { + l := New() + l.Append(1) + return l + }, + indices: []int{1}, + shouldPanic: true, + panicMsg: ErrOutOfBounds.Error(), + }, + { + name: "already deleted", + setup: func() *List { + l := New() + l.Append(1) + l.Delete(0) + return l + }, + indices: []int{0}, + shouldPanic: true, + panicMsg: ErrDeleted.Error(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := tt.setup() + if tt.shouldPanic { + defer func() { + r := recover() + if r == nil { + t.Error("Expected panic but got none") + } + err, ok := r.(error) + if !ok { + t.Errorf("Expected error but got %v", r) + } + uassert.Equal(t, tt.panicMsg, err.Error()) + }() + } + 
l.MustDelete(tt.indices...) + if tt.shouldPanic { + t.Error("Expected panic") + } + }) + } +} + +func TestMustGet(t *testing.T) { + tests := []struct { + name string + setup func() *List + index int + expected interface{} + shouldPanic bool + panicMsg string + }{ + { + name: "successful get", + setup: func() *List { + l := New() + l.Append(42) + return l + }, + index: 0, + expected: 42, + shouldPanic: false, + }, + { + name: "out of bounds negative", + setup: func() *List { + l := New() + l.Append(1) + return l + }, + index: -1, + shouldPanic: true, + panicMsg: ErrOutOfBounds.Error(), + }, + { + name: "out of bounds positive", + setup: func() *List { + l := New() + l.Append(1) + return l + }, + index: 1, + shouldPanic: true, + panicMsg: ErrOutOfBounds.Error(), + }, + { + name: "deleted element", + setup: func() *List { + l := New() + l.Append(1) + l.Delete(0) + return l + }, + index: 0, + shouldPanic: true, + panicMsg: ErrDeleted.Error(), + }, + { + name: "nil list", + setup: func() *List { + return nil + }, + index: 0, + shouldPanic: true, + panicMsg: ErrOutOfBounds.Error(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := tt.setup() + if tt.shouldPanic { + defer func() { + r := recover() + if r == nil { + t.Error("Expected panic but got none") + } + err, ok := r.(error) + if !ok { + t.Errorf("Expected error but got %v", r) + } + uassert.Equal(t, tt.panicMsg, err.Error()) + }() + } + result := l.MustGet(tt.index) + if tt.shouldPanic { + t.Error("Expected panic") + } + uassert.Equal(t, typeutil.ToString(tt.expected), typeutil.ToString(result)) + }) + } +} + +func TestGetRange(t *testing.T) { + tests := []struct { + name string + values []interface{} + start int + end int + expected []Entry + }{ + { + name: "empty list", + values: []interface{}{}, + start: 0, + end: 10, + expected: []Entry{}, + }, + { + name: "single element", + values: []interface{}{42}, + start: 0, + end: 0, + expected: []Entry{ + {Index: 0, Value: 42}, + }, + 
}, + { + name: "multiple elements forward", + values: []interface{}{1, 2, 3, 4, 5}, + start: 1, + end: 3, + expected: []Entry{ + {Index: 1, Value: 2}, + {Index: 2, Value: 3}, + {Index: 3, Value: 4}, + }, + }, + { + name: "multiple elements reverse", + values: []interface{}{1, 2, 3, 4, 5}, + start: 3, + end: 1, + expected: []Entry{ + {Index: 3, Value: 4}, + {Index: 2, Value: 3}, + {Index: 1, Value: 2}, + }, + }, + { + name: "with deleted elements", + values: []interface{}{1, nil, 3, nil, 5}, + start: 0, + end: 4, + expected: []Entry{ + {Index: 0, Value: 1}, + {Index: 2, Value: 3}, + {Index: 4, Value: 5}, + }, + }, + { + name: "nil list", + values: nil, + start: 0, + end: 5, + expected: []Entry{}, + }, + { + name: "negative indices", + values: []interface{}{1, 2, 3}, + start: -1, + end: -2, + expected: []Entry{}, + }, + { + name: "indices beyond size", + values: []interface{}{1, 2, 3}, + start: 1, + end: 5, + expected: []Entry{ + {Index: 1, Value: 2}, + {Index: 2, Value: 3}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + list := New() + list.Append(tt.values...) 
+ + result := list.GetRange(tt.start, tt.end) + + uassert.Equal(t, len(tt.expected), len(result), "comparing length") + for i := range result { + uassert.Equal(t, tt.expected[i].Index, result[i].Index, "comparing index") + uassert.Equal(t, typeutil.ToString(tt.expected[i].Value), typeutil.ToString(result[i].Value), "comparing value") + } + }) + } +} + +func TestGetByOffset(t *testing.T) { + tests := []struct { + name string + values []interface{} + offset int + count int + expected []Entry + }{ + { + name: "empty list", + values: []interface{}{}, + offset: 0, + count: 5, + expected: []Entry{}, + }, + { + name: "positive count forward", + values: []interface{}{1, 2, 3, 4, 5}, + offset: 1, + count: 2, + expected: []Entry{ + {Index: 1, Value: 2}, + {Index: 2, Value: 3}, + }, + }, + { + name: "negative count backward", + values: []interface{}{1, 2, 3, 4, 5}, + offset: 3, + count: -2, + expected: []Entry{ + {Index: 3, Value: 4}, + {Index: 2, Value: 3}, + }, + }, + { + name: "count exceeds available elements", + values: []interface{}{1, 2, 3}, + offset: 1, + count: 5, + expected: []Entry{ + {Index: 1, Value: 2}, + {Index: 2, Value: 3}, + }, + }, + { + name: "zero count", + values: []interface{}{1, 2, 3}, + offset: 0, + count: 0, + expected: []Entry{}, + }, + { + name: "with deleted elements", + values: []interface{}{1, nil, 3, nil, 5}, + offset: 0, + count: 3, + expected: []Entry{ + {Index: 0, Value: 1}, + {Index: 2, Value: 3}, + {Index: 4, Value: 5}, + }, + }, + { + name: "negative offset", + values: []interface{}{1, 2, 3}, + offset: -1, + count: 2, + expected: []Entry{ + {Index: 0, Value: 1}, + {Index: 1, Value: 2}, + }, + }, + { + name: "offset beyond size", + values: []interface{}{1, 2, 3}, + offset: 5, + count: -2, + expected: []Entry{ + {Index: 2, Value: 3}, + {Index: 1, Value: 2}, + }, + }, + { + name: "nil list", + values: nil, + offset: 0, + count: 5, + expected: []Entry{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + list := 
New() + list.Append(tt.values...) + + result := list.GetByOffset(tt.offset, tt.count) + + uassert.Equal(t, len(tt.expected), len(result), "comparing length") + for i := range result { + uassert.Equal(t, tt.expected[i].Index, result[i].Index, "comparing index") + uassert.Equal(t, typeutil.ToString(tt.expected[i].Value), typeutil.ToString(result[i].Value), "comparing value") + } + }) + } +} + +func TestMustSet(t *testing.T) { + tests := []struct { + name string + setup func() *List + index int + value interface{} + shouldPanic bool + panicMsg string + }{ + { + name: "successful set", + setup: func() *List { + l := New() + l.Append(42) + return l + }, + index: 0, + value: 99, + shouldPanic: false, + }, + { + name: "restore deleted element", + setup: func() *List { + l := New() + l.Append(42) + l.Delete(0) + return l + }, + index: 0, + value: 99, + shouldPanic: false, + }, + { + name: "out of bounds negative", + setup: func() *List { + l := New() + l.Append(1) + return l + }, + index: -1, + value: 99, + shouldPanic: true, + panicMsg: ErrOutOfBounds.Error(), + }, + { + name: "out of bounds positive", + setup: func() *List { + l := New() + l.Append(1) + return l + }, + index: 1, + value: 99, + shouldPanic: true, + panicMsg: ErrOutOfBounds.Error(), + }, + { + name: "nil list", + setup: func() *List { + return nil + }, + index: 0, + value: 99, + shouldPanic: true, + panicMsg: ErrOutOfBounds.Error(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := tt.setup() + if tt.shouldPanic { + defer func() { + r := recover() + if r == nil { + t.Error("Expected panic but got none") + } + err, ok := r.(error) + if !ok { + t.Errorf("Expected error but got %v", r) + } + uassert.Equal(t, tt.panicMsg, err.Error()) + }() + } + l.MustSet(tt.index, tt.value) + if tt.shouldPanic { + t.Error("Expected panic") + } + // Verify the value was set correctly for non-panic cases + if !tt.shouldPanic { + result := l.Get(tt.index) + uassert.Equal(t, 
typeutil.ToString(tt.value), typeutil.ToString(result)) + } + }) + } +} + +func TestSet(t *testing.T) { + tests := []struct { + name string + setup func() *List + index int + value interface{} + expectedErr error + verify func(t *testing.T, l *List) + }{ + { + name: "set value in empty list", + setup: func() *List { + return New() + }, + index: 0, + value: 42, + expectedErr: ErrOutOfBounds, + verify: func(t *testing.T, l *List) { + uassert.Equal(t, 0, l.Size()) + }, + }, + { + name: "set value at valid index", + setup: func() *List { + l := New() + l.Append(1) + return l + }, + index: 0, + value: 42, + verify: func(t *testing.T, l *List) { + uassert.Equal(t, 42, l.Get(0)) + uassert.Equal(t, 1, l.Size()) + uassert.Equal(t, 1, l.TotalSize()) + }, + }, + { + name: "set value at negative index", + setup: func() *List { + l := New() + l.Append(1) + return l + }, + index: -1, + value: 42, + expectedErr: ErrOutOfBounds, + verify: func(t *testing.T, l *List) { + uassert.Equal(t, 1, l.Get(0)) + }, + }, + { + name: "set value beyond size", + setup: func() *List { + l := New() + l.Append(1) + return l + }, + index: 1, + value: 42, + expectedErr: ErrOutOfBounds, + verify: func(t *testing.T, l *List) { + uassert.Equal(t, 1, l.Get(0)) + uassert.Equal(t, 1, l.Size()) + }, + }, + { + name: "set nil value", + setup: func() *List { + l := New() + l.Append(1) + return l + }, + index: 0, + value: nil, + verify: func(t *testing.T, l *List) { + uassert.Equal(t, nil, l.Get(0)) + uassert.Equal(t, 0, l.Size()) + }, + }, + { + name: "set value at deleted index", + setup: func() *List { + l := New() + l.Append(1, 2, 3) + l.Delete(1) + return l + }, + index: 1, + value: 42, + verify: func(t *testing.T, l *List) { + uassert.Equal(t, 42, l.Get(1)) + uassert.Equal(t, 3, l.Size()) + uassert.Equal(t, 3, l.TotalSize()) + }, + }, + { + name: "set value in nil list", + setup: func() *List { + return nil + }, + index: 0, + value: 42, + expectedErr: ErrOutOfBounds, + verify: func(t *testing.T, l *List) 
{ + uassert.Equal(t, 0, l.Size()) + }, + }, + { + name: "set multiple values at same index", + setup: func() *List { + l := New() + l.Append(1) + return l + }, + index: 0, + value: 42, + verify: func(t *testing.T, l *List) { + uassert.Equal(t, 42, l.Get(0)) + err := l.Set(0, 99) + uassert.Equal(t, nil, err) + uassert.Equal(t, 99, l.Get(0)) + uassert.Equal(t, 1, l.Size()) + }, + }, + { + name: "set value at last index", + setup: func() *List { + l := New() + l.Append(1, 2, 3) + return l + }, + index: 2, + value: 42, + verify: func(t *testing.T, l *List) { + uassert.Equal(t, 42, l.Get(2)) + uassert.Equal(t, 3, l.Size()) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := tt.setup() + err := l.Set(tt.index, tt.value) + + if tt.expectedErr != nil { + uassert.Equal(t, tt.expectedErr.Error(), err.Error()) + } else { + uassert.Equal(t, nil, err) + } + + tt.verify(t, l) + }) + } +} diff --git a/examples/gno.land/p/moul/xmath/generate.go b/examples/gno.land/p/moul/xmath/generate.go new file mode 100644 index 00000000000..ad70adb06bd --- /dev/null +++ b/examples/gno.land/p/moul/xmath/generate.go @@ -0,0 +1,3 @@ +package xmath + +//go:generate go run generator.go diff --git a/examples/gno.land/p/moul/xmath/generator.go b/examples/gno.land/p/moul/xmath/generator.go new file mode 100644 index 00000000000..afe5a4341fa --- /dev/null +++ b/examples/gno.land/p/moul/xmath/generator.go @@ -0,0 +1,184 @@ +//go:build ignore + +package main + +import ( + "bytes" + "fmt" + "go/format" + "log" + "os" + "strings" + "text/template" +) + +type Type struct { + Name string + ZeroValue string + Signed bool + Float bool +} + +var types = []Type{ + {"Int8", "0", true, false}, + {"Int16", "0", true, false}, + {"Int32", "0", true, false}, + {"Int64", "0", true, false}, + {"Int", "0", true, false}, + {"Uint8", "0", false, false}, + {"Uint16", "0", false, false}, + {"Uint32", "0", false, false}, + {"Uint64", "0", false, false}, + {"Uint", "0", false, false}, + 
{"Float32", "0.0", true, true}, + {"Float64", "0.0", true, true}, +} + +const sourceTpl = `// Code generated by generator.go; DO NOT EDIT. +package xmath + +{{ range .Types }} +// {{.Name}} helpers +func Max{{.Name}}(a, b {{.Name | lower}}) {{.Name | lower}} { + if a > b { + return a + } + return b +} + +func Min{{.Name}}(a, b {{.Name | lower}}) {{.Name | lower}} { + if a < b { + return a + } + return b +} + +func Clamp{{.Name}}(value, min, max {{.Name | lower}}) {{.Name | lower}} { + if value < min { + return min + } + if value > max { + return max + } + return value +} +{{if .Signed}} +func Abs{{.Name}}(x {{.Name | lower}}) {{.Name | lower}} { + if x < 0 { + return -x + } + return x +} + +func Sign{{.Name}}(x {{.Name | lower}}) {{.Name | lower}} { + if x < 0 { + return -1 + } + if x > 0 { + return 1 + } + return 0 +} +{{end}} +{{end}} +` + +const testTpl = `package xmath + +import "testing" + +{{range .Types}} +func Test{{.Name}}Helpers(t *testing.T) { + // Test Max{{.Name}} + if Max{{.Name}}(1, 2) != 2 { + t.Error("Max{{.Name}}(1, 2) should be 2") + } + {{if .Signed}}if Max{{.Name}}(-1, -2) != -1 { + t.Error("Max{{.Name}}(-1, -2) should be -1") + }{{end}} + + // Test Min{{.Name}} + if Min{{.Name}}(1, 2) != 1 { + t.Error("Min{{.Name}}(1, 2) should be 1") + } + {{if .Signed}}if Min{{.Name}}(-1, -2) != -2 { + t.Error("Min{{.Name}}(-1, -2) should be -2") + }{{end}} + + // Test Clamp{{.Name}} + if Clamp{{.Name}}(5, 1, 3) != 3 { + t.Error("Clamp{{.Name}}(5, 1, 3) should be 3") + } + if Clamp{{.Name}}(0, 1, 3) != 1 { + t.Error("Clamp{{.Name}}(0, 1, 3) should be 1") + } + if Clamp{{.Name}}(2, 1, 3) != 2 { + t.Error("Clamp{{.Name}}(2, 1, 3) should be 2") + } + {{if .Signed}} + // Test Abs{{.Name}} + if Abs{{.Name}}(-5) != 5 { + t.Error("Abs{{.Name}}(-5) should be 5") + } + if Abs{{.Name}}(5) != 5 { + t.Error("Abs{{.Name}}(5) should be 5") + } + + // Test Sign{{.Name}} + if Sign{{.Name}}(-5) != -1 { + t.Error("Sign{{.Name}}(-5) should be -1") + } + if Sign{{.Name}}(5) != 
1 { + t.Error("Sign{{.Name}}(5) should be 1") + } + if Sign{{.Name}}({{.ZeroValue}}) != 0 { + t.Error("Sign{{.Name}}({{.ZeroValue}}) should be 0") + } + {{end}} +} +{{end}} +` + +func main() { + funcMap := template.FuncMap{ + "lower": strings.ToLower, + } + + // Generate source file + sourceTmpl := template.Must(template.New("source").Funcs(funcMap).Parse(sourceTpl)) + var sourceOut bytes.Buffer + if err := sourceTmpl.Execute(&sourceOut, struct{ Types []Type }{types}); err != nil { + log.Fatal(err) + } + + // Format the generated code + formattedSource, err := format.Source(sourceOut.Bytes()) + if err != nil { + log.Fatal(err) + } + + // Write source file + if err := os.WriteFile("xmath.gen.gno", formattedSource, 0644); err != nil { + log.Fatal(err) + } + + // Generate test file + testTmpl := template.Must(template.New("test").Parse(testTpl)) + var testOut bytes.Buffer + if err := testTmpl.Execute(&testOut, struct{ Types []Type }{types}); err != nil { + log.Fatal(err) + } + + // Format the generated test code + formattedTest, err := format.Source(testOut.Bytes()) + if err != nil { + log.Fatal(err) + } + + // Write test file + if err := os.WriteFile("xmath.gen_test.gno", formattedTest, 0644); err != nil { + log.Fatal(err) + } + + fmt.Println("Generated xmath.gen.gno and xmath.gen_test.gno") +} diff --git a/examples/gno.land/p/moul/xmath/gno.mod b/examples/gno.land/p/moul/xmath/gno.mod new file mode 100644 index 00000000000..63b782c88f2 --- /dev/null +++ b/examples/gno.land/p/moul/xmath/gno.mod @@ -0,0 +1 @@ +module gno.land/p/moul/xmath diff --git a/examples/gno.land/p/moul/xmath/xmath.gen.gno b/examples/gno.land/p/moul/xmath/xmath.gen.gno new file mode 100644 index 00000000000..266c77e1e84 --- /dev/null +++ b/examples/gno.land/p/moul/xmath/xmath.gen.gno @@ -0,0 +1,421 @@ +// Code generated by generator.go; DO NOT EDIT. 
+package xmath + +// Int8 helpers +func MaxInt8(a, b int8) int8 { + if a > b { + return a + } + return b +} + +func MinInt8(a, b int8) int8 { + if a < b { + return a + } + return b +} + +func ClampInt8(value, min, max int8) int8 { + if value < min { + return min + } + if value > max { + return max + } + return value +} + +func AbsInt8(x int8) int8 { + if x < 0 { + return -x + } + return x +} + +func SignInt8(x int8) int8 { + if x < 0 { + return -1 + } + if x > 0 { + return 1 + } + return 0 +} + +// Int16 helpers +func MaxInt16(a, b int16) int16 { + if a > b { + return a + } + return b +} + +func MinInt16(a, b int16) int16 { + if a < b { + return a + } + return b +} + +func ClampInt16(value, min, max int16) int16 { + if value < min { + return min + } + if value > max { + return max + } + return value +} + +func AbsInt16(x int16) int16 { + if x < 0 { + return -x + } + return x +} + +func SignInt16(x int16) int16 { + if x < 0 { + return -1 + } + if x > 0 { + return 1 + } + return 0 +} + +// Int32 helpers +func MaxInt32(a, b int32) int32 { + if a > b { + return a + } + return b +} + +func MinInt32(a, b int32) int32 { + if a < b { + return a + } + return b +} + +func ClampInt32(value, min, max int32) int32 { + if value < min { + return min + } + if value > max { + return max + } + return value +} + +func AbsInt32(x int32) int32 { + if x < 0 { + return -x + } + return x +} + +func SignInt32(x int32) int32 { + if x < 0 { + return -1 + } + if x > 0 { + return 1 + } + return 0 +} + +// Int64 helpers +func MaxInt64(a, b int64) int64 { + if a > b { + return a + } + return b +} + +func MinInt64(a, b int64) int64 { + if a < b { + return a + } + return b +} + +func ClampInt64(value, min, max int64) int64 { + if value < min { + return min + } + if value > max { + return max + } + return value +} + +func AbsInt64(x int64) int64 { + if x < 0 { + return -x + } + return x +} + +func SignInt64(x int64) int64 { + if x < 0 { + return -1 + } + if x > 0 { + return 1 + } + return 0 +} + 
+// Int helpers +func MaxInt(a, b int) int { + if a > b { + return a + } + return b +} + +func MinInt(a, b int) int { + if a < b { + return a + } + return b +} + +func ClampInt(value, min, max int) int { + if value < min { + return min + } + if value > max { + return max + } + return value +} + +func AbsInt(x int) int { + if x < 0 { + return -x + } + return x +} + +func SignInt(x int) int { + if x < 0 { + return -1 + } + if x > 0 { + return 1 + } + return 0 +} + +// Uint8 helpers +func MaxUint8(a, b uint8) uint8 { + if a > b { + return a + } + return b +} + +func MinUint8(a, b uint8) uint8 { + if a < b { + return a + } + return b +} + +func ClampUint8(value, min, max uint8) uint8 { + if value < min { + return min + } + if value > max { + return max + } + return value +} + +// Uint16 helpers +func MaxUint16(a, b uint16) uint16 { + if a > b { + return a + } + return b +} + +func MinUint16(a, b uint16) uint16 { + if a < b { + return a + } + return b +} + +func ClampUint16(value, min, max uint16) uint16 { + if value < min { + return min + } + if value > max { + return max + } + return value +} + +// Uint32 helpers +func MaxUint32(a, b uint32) uint32 { + if a > b { + return a + } + return b +} + +func MinUint32(a, b uint32) uint32 { + if a < b { + return a + } + return b +} + +func ClampUint32(value, min, max uint32) uint32 { + if value < min { + return min + } + if value > max { + return max + } + return value +} + +// Uint64 helpers +func MaxUint64(a, b uint64) uint64 { + if a > b { + return a + } + return b +} + +func MinUint64(a, b uint64) uint64 { + if a < b { + return a + } + return b +} + +func ClampUint64(value, min, max uint64) uint64 { + if value < min { + return min + } + if value > max { + return max + } + return value +} + +// Uint helpers +func MaxUint(a, b uint) uint { + if a > b { + return a + } + return b +} + +func MinUint(a, b uint) uint { + if a < b { + return a + } + return b +} + +func ClampUint(value, min, max uint) uint { + if value < min { + 
return min + } + if value > max { + return max + } + return value +} + +// Float32 helpers +func MaxFloat32(a, b float32) float32 { + if a > b { + return a + } + return b +} + +func MinFloat32(a, b float32) float32 { + if a < b { + return a + } + return b +} + +func ClampFloat32(value, min, max float32) float32 { + if value < min { + return min + } + if value > max { + return max + } + return value +} + +func AbsFloat32(x float32) float32 { + if x < 0 { + return -x + } + return x +} + +func SignFloat32(x float32) float32 { + if x < 0 { + return -1 + } + if x > 0 { + return 1 + } + return 0 +} + +// Float64 helpers +func MaxFloat64(a, b float64) float64 { + if a > b { + return a + } + return b +} + +func MinFloat64(a, b float64) float64 { + if a < b { + return a + } + return b +} + +func ClampFloat64(value, min, max float64) float64 { + if value < min { + return min + } + if value > max { + return max + } + return value +} + +func AbsFloat64(x float64) float64 { + if x < 0 { + return -x + } + return x +} + +func SignFloat64(x float64) float64 { + if x < 0 { + return -1 + } + if x > 0 { + return 1 + } + return 0 +} diff --git a/examples/gno.land/p/moul/xmath/xmath.gen_test.gno b/examples/gno.land/p/moul/xmath/xmath.gen_test.gno new file mode 100644 index 00000000000..16c80fc983d --- /dev/null +++ b/examples/gno.land/p/moul/xmath/xmath.gen_test.gno @@ -0,0 +1,466 @@ +package xmath + +import "testing" + +func TestInt8Helpers(t *testing.T) { + // Test MaxInt8 + if MaxInt8(1, 2) != 2 { + t.Error("MaxInt8(1, 2) should be 2") + } + if MaxInt8(-1, -2) != -1 { + t.Error("MaxInt8(-1, -2) should be -1") + } + + // Test MinInt8 + if MinInt8(1, 2) != 1 { + t.Error("MinInt8(1, 2) should be 1") + } + if MinInt8(-1, -2) != -2 { + t.Error("MinInt8(-1, -2) should be -2") + } + + // Test ClampInt8 + if ClampInt8(5, 1, 3) != 3 { + t.Error("ClampInt8(5, 1, 3) should be 3") + } + if ClampInt8(0, 1, 3) != 1 { + t.Error("ClampInt8(0, 1, 3) should be 1") + } + if ClampInt8(2, 1, 3) != 2 { + 
t.Error("ClampInt8(2, 1, 3) should be 2") + } + + // Test AbsInt8 + if AbsInt8(-5) != 5 { + t.Error("AbsInt8(-5) should be 5") + } + if AbsInt8(5) != 5 { + t.Error("AbsInt8(5) should be 5") + } + + // Test SignInt8 + if SignInt8(-5) != -1 { + t.Error("SignInt8(-5) should be -1") + } + if SignInt8(5) != 1 { + t.Error("SignInt8(5) should be 1") + } + if SignInt8(0) != 0 { + t.Error("SignInt8(0) should be 0") + } + +} + +func TestInt16Helpers(t *testing.T) { + // Test MaxInt16 + if MaxInt16(1, 2) != 2 { + t.Error("MaxInt16(1, 2) should be 2") + } + if MaxInt16(-1, -2) != -1 { + t.Error("MaxInt16(-1, -2) should be -1") + } + + // Test MinInt16 + if MinInt16(1, 2) != 1 { + t.Error("MinInt16(1, 2) should be 1") + } + if MinInt16(-1, -2) != -2 { + t.Error("MinInt16(-1, -2) should be -2") + } + + // Test ClampInt16 + if ClampInt16(5, 1, 3) != 3 { + t.Error("ClampInt16(5, 1, 3) should be 3") + } + if ClampInt16(0, 1, 3) != 1 { + t.Error("ClampInt16(0, 1, 3) should be 1") + } + if ClampInt16(2, 1, 3) != 2 { + t.Error("ClampInt16(2, 1, 3) should be 2") + } + + // Test AbsInt16 + if AbsInt16(-5) != 5 { + t.Error("AbsInt16(-5) should be 5") + } + if AbsInt16(5) != 5 { + t.Error("AbsInt16(5) should be 5") + } + + // Test SignInt16 + if SignInt16(-5) != -1 { + t.Error("SignInt16(-5) should be -1") + } + if SignInt16(5) != 1 { + t.Error("SignInt16(5) should be 1") + } + if SignInt16(0) != 0 { + t.Error("SignInt16(0) should be 0") + } + +} + +func TestInt32Helpers(t *testing.T) { + // Test MaxInt32 + if MaxInt32(1, 2) != 2 { + t.Error("MaxInt32(1, 2) should be 2") + } + if MaxInt32(-1, -2) != -1 { + t.Error("MaxInt32(-1, -2) should be -1") + } + + // Test MinInt32 + if MinInt32(1, 2) != 1 { + t.Error("MinInt32(1, 2) should be 1") + } + if MinInt32(-1, -2) != -2 { + t.Error("MinInt32(-1, -2) should be -2") + } + + // Test ClampInt32 + if ClampInt32(5, 1, 3) != 3 { + t.Error("ClampInt32(5, 1, 3) should be 3") + } + if ClampInt32(0, 1, 3) != 1 { + t.Error("ClampInt32(0, 1, 3) should 
be 1") + } + if ClampInt32(2, 1, 3) != 2 { + t.Error("ClampInt32(2, 1, 3) should be 2") + } + + // Test AbsInt32 + if AbsInt32(-5) != 5 { + t.Error("AbsInt32(-5) should be 5") + } + if AbsInt32(5) != 5 { + t.Error("AbsInt32(5) should be 5") + } + + // Test SignInt32 + if SignInt32(-5) != -1 { + t.Error("SignInt32(-5) should be -1") + } + if SignInt32(5) != 1 { + t.Error("SignInt32(5) should be 1") + } + if SignInt32(0) != 0 { + t.Error("SignInt32(0) should be 0") + } + +} + +func TestInt64Helpers(t *testing.T) { + // Test MaxInt64 + if MaxInt64(1, 2) != 2 { + t.Error("MaxInt64(1, 2) should be 2") + } + if MaxInt64(-1, -2) != -1 { + t.Error("MaxInt64(-1, -2) should be -1") + } + + // Test MinInt64 + if MinInt64(1, 2) != 1 { + t.Error("MinInt64(1, 2) should be 1") + } + if MinInt64(-1, -2) != -2 { + t.Error("MinInt64(-1, -2) should be -2") + } + + // Test ClampInt64 + if ClampInt64(5, 1, 3) != 3 { + t.Error("ClampInt64(5, 1, 3) should be 3") + } + if ClampInt64(0, 1, 3) != 1 { + t.Error("ClampInt64(0, 1, 3) should be 1") + } + if ClampInt64(2, 1, 3) != 2 { + t.Error("ClampInt64(2, 1, 3) should be 2") + } + + // Test AbsInt64 + if AbsInt64(-5) != 5 { + t.Error("AbsInt64(-5) should be 5") + } + if AbsInt64(5) != 5 { + t.Error("AbsInt64(5) should be 5") + } + + // Test SignInt64 + if SignInt64(-5) != -1 { + t.Error("SignInt64(-5) should be -1") + } + if SignInt64(5) != 1 { + t.Error("SignInt64(5) should be 1") + } + if SignInt64(0) != 0 { + t.Error("SignInt64(0) should be 0") + } + +} + +func TestIntHelpers(t *testing.T) { + // Test MaxInt + if MaxInt(1, 2) != 2 { + t.Error("MaxInt(1, 2) should be 2") + } + if MaxInt(-1, -2) != -1 { + t.Error("MaxInt(-1, -2) should be -1") + } + + // Test MinInt + if MinInt(1, 2) != 1 { + t.Error("MinInt(1, 2) should be 1") + } + if MinInt(-1, -2) != -2 { + t.Error("MinInt(-1, -2) should be -2") + } + + // Test ClampInt + if ClampInt(5, 1, 3) != 3 { + t.Error("ClampInt(5, 1, 3) should be 3") + } + if ClampInt(0, 1, 3) != 1 { + 
t.Error("ClampInt(0, 1, 3) should be 1") + } + if ClampInt(2, 1, 3) != 2 { + t.Error("ClampInt(2, 1, 3) should be 2") + } + + // Test AbsInt + if AbsInt(-5) != 5 { + t.Error("AbsInt(-5) should be 5") + } + if AbsInt(5) != 5 { + t.Error("AbsInt(5) should be 5") + } + + // Test SignInt + if SignInt(-5) != -1 { + t.Error("SignInt(-5) should be -1") + } + if SignInt(5) != 1 { + t.Error("SignInt(5) should be 1") + } + if SignInt(0) != 0 { + t.Error("SignInt(0) should be 0") + } + +} + +func TestUint8Helpers(t *testing.T) { + // Test MaxUint8 + if MaxUint8(1, 2) != 2 { + t.Error("MaxUint8(1, 2) should be 2") + } + + // Test MinUint8 + if MinUint8(1, 2) != 1 { + t.Error("MinUint8(1, 2) should be 1") + } + + // Test ClampUint8 + if ClampUint8(5, 1, 3) != 3 { + t.Error("ClampUint8(5, 1, 3) should be 3") + } + if ClampUint8(0, 1, 3) != 1 { + t.Error("ClampUint8(0, 1, 3) should be 1") + } + if ClampUint8(2, 1, 3) != 2 { + t.Error("ClampUint8(2, 1, 3) should be 2") + } + +} + +func TestUint16Helpers(t *testing.T) { + // Test MaxUint16 + if MaxUint16(1, 2) != 2 { + t.Error("MaxUint16(1, 2) should be 2") + } + + // Test MinUint16 + if MinUint16(1, 2) != 1 { + t.Error("MinUint16(1, 2) should be 1") + } + + // Test ClampUint16 + if ClampUint16(5, 1, 3) != 3 { + t.Error("ClampUint16(5, 1, 3) should be 3") + } + if ClampUint16(0, 1, 3) != 1 { + t.Error("ClampUint16(0, 1, 3) should be 1") + } + if ClampUint16(2, 1, 3) != 2 { + t.Error("ClampUint16(2, 1, 3) should be 2") + } + +} + +func TestUint32Helpers(t *testing.T) { + // Test MaxUint32 + if MaxUint32(1, 2) != 2 { + t.Error("MaxUint32(1, 2) should be 2") + } + + // Test MinUint32 + if MinUint32(1, 2) != 1 { + t.Error("MinUint32(1, 2) should be 1") + } + + // Test ClampUint32 + if ClampUint32(5, 1, 3) != 3 { + t.Error("ClampUint32(5, 1, 3) should be 3") + } + if ClampUint32(0, 1, 3) != 1 { + t.Error("ClampUint32(0, 1, 3) should be 1") + } + if ClampUint32(2, 1, 3) != 2 { + t.Error("ClampUint32(2, 1, 3) should be 2") + } + +} + 
+func TestUint64Helpers(t *testing.T) { + // Test MaxUint64 + if MaxUint64(1, 2) != 2 { + t.Error("MaxUint64(1, 2) should be 2") + } + + // Test MinUint64 + if MinUint64(1, 2) != 1 { + t.Error("MinUint64(1, 2) should be 1") + } + + // Test ClampUint64 + if ClampUint64(5, 1, 3) != 3 { + t.Error("ClampUint64(5, 1, 3) should be 3") + } + if ClampUint64(0, 1, 3) != 1 { + t.Error("ClampUint64(0, 1, 3) should be 1") + } + if ClampUint64(2, 1, 3) != 2 { + t.Error("ClampUint64(2, 1, 3) should be 2") + } + +} + +func TestUintHelpers(t *testing.T) { + // Test MaxUint + if MaxUint(1, 2) != 2 { + t.Error("MaxUint(1, 2) should be 2") + } + + // Test MinUint + if MinUint(1, 2) != 1 { + t.Error("MinUint(1, 2) should be 1") + } + + // Test ClampUint + if ClampUint(5, 1, 3) != 3 { + t.Error("ClampUint(5, 1, 3) should be 3") + } + if ClampUint(0, 1, 3) != 1 { + t.Error("ClampUint(0, 1, 3) should be 1") + } + if ClampUint(2, 1, 3) != 2 { + t.Error("ClampUint(2, 1, 3) should be 2") + } + +} + +func TestFloat32Helpers(t *testing.T) { + // Test MaxFloat32 + if MaxFloat32(1, 2) != 2 { + t.Error("MaxFloat32(1, 2) should be 2") + } + if MaxFloat32(-1, -2) != -1 { + t.Error("MaxFloat32(-1, -2) should be -1") + } + + // Test MinFloat32 + if MinFloat32(1, 2) != 1 { + t.Error("MinFloat32(1, 2) should be 1") + } + if MinFloat32(-1, -2) != -2 { + t.Error("MinFloat32(-1, -2) should be -2") + } + + // Test ClampFloat32 + if ClampFloat32(5, 1, 3) != 3 { + t.Error("ClampFloat32(5, 1, 3) should be 3") + } + if ClampFloat32(0, 1, 3) != 1 { + t.Error("ClampFloat32(0, 1, 3) should be 1") + } + if ClampFloat32(2, 1, 3) != 2 { + t.Error("ClampFloat32(2, 1, 3) should be 2") + } + + // Test AbsFloat32 + if AbsFloat32(-5) != 5 { + t.Error("AbsFloat32(-5) should be 5") + } + if AbsFloat32(5) != 5 { + t.Error("AbsFloat32(5) should be 5") + } + + // Test SignFloat32 + if SignFloat32(-5) != -1 { + t.Error("SignFloat32(-5) should be -1") + } + if SignFloat32(5) != 1 { + t.Error("SignFloat32(5) should be 1") + } + 
if SignFloat32(0.0) != 0 { + t.Error("SignFloat32(0.0) should be 0") + } + +} + +func TestFloat64Helpers(t *testing.T) { + // Test MaxFloat64 + if MaxFloat64(1, 2) != 2 { + t.Error("MaxFloat64(1, 2) should be 2") + } + if MaxFloat64(-1, -2) != -1 { + t.Error("MaxFloat64(-1, -2) should be -1") + } + + // Test MinFloat64 + if MinFloat64(1, 2) != 1 { + t.Error("MinFloat64(1, 2) should be 1") + } + if MinFloat64(-1, -2) != -2 { + t.Error("MinFloat64(-1, -2) should be -2") + } + + // Test ClampFloat64 + if ClampFloat64(5, 1, 3) != 3 { + t.Error("ClampFloat64(5, 1, 3) should be 3") + } + if ClampFloat64(0, 1, 3) != 1 { + t.Error("ClampFloat64(0, 1, 3) should be 1") + } + if ClampFloat64(2, 1, 3) != 2 { + t.Error("ClampFloat64(2, 1, 3) should be 2") + } + + // Test AbsFloat64 + if AbsFloat64(-5) != 5 { + t.Error("AbsFloat64(-5) should be 5") + } + if AbsFloat64(5) != 5 { + t.Error("AbsFloat64(5) should be 5") + } + + // Test SignFloat64 + if SignFloat64(-5) != -1 { + t.Error("SignFloat64(-5) should be -1") + } + if SignFloat64(5) != 1 { + t.Error("SignFloat64(5) should be 1") + } + if SignFloat64(0.0) != 0 { + t.Error("SignFloat64(0.0) should be 0") + } + +} diff --git a/examples/gno.land/r/demo/bar20/bar20.gno b/examples/gno.land/r/demo/bar20/bar20.gno index 25636fcda78..52f1baa7408 100644 --- a/examples/gno.land/r/demo/bar20/bar20.gno +++ b/examples/gno.land/r/demo/bar20/bar20.gno @@ -18,8 +18,7 @@ var ( ) func init() { - getter := func() *grc20.Token { return Token } - grc20reg.Register(getter, "") + grc20reg.Register(Token.Getter(), "") } func Faucet() string { diff --git a/examples/gno.land/r/demo/btree_dao/btree_dao.gno b/examples/gno.land/r/demo/btree_dao/btree_dao.gno new file mode 100644 index 00000000000..c90742eb29b --- /dev/null +++ b/examples/gno.land/r/demo/btree_dao/btree_dao.gno @@ -0,0 +1,209 @@ +package btree_dao + +import ( + "errors" + "std" + "strings" + "time" + + "gno.land/p/demo/btree" + "gno.land/p/demo/grc/grc721" + "gno.land/p/demo/ufmt" + 
"gno.land/p/moul/md" +) + +// RegistrationDetails holds the details of a user's registration in the BTree DAO. +// It stores the user's address, registration time, their B-Tree if they planted one, +// and their NFT ID. +type RegistrationDetails struct { + Address std.Address + RegTime time.Time + UserBTree *btree.BTree + NFTID string +} + +// Less implements the btree.Record interface for RegistrationDetails. +// It compares two RegistrationDetails based on their registration time. +// Returns true if the current registration time is before the other registration time. +func (rd *RegistrationDetails) Less(than btree.Record) bool { + other := than.(*RegistrationDetails) + return rd.RegTime.Before(other.RegTime) +} + +var ( + dao = grc721.NewBasicNFT("BTree DAO", "BTDAO") + tokenID = 0 + members = btree.New() +) + +// PlantTree allows a user to plant their B-Tree in the DAO forest. +// It mints an NFT to the user and registers their tree in the DAO. +// Returns an error if the tree is already planted, empty, or if NFT minting fails. +func PlantTree(userBTree *btree.BTree) error { + return plantImpl(userBTree, "") +} + +// PlantSeed allows a user to register as a seed in the DAO with a message. +// It mints an NFT to the user and registers them as a seed member. +// Returns an error if the message is empty or if NFT minting fails. +func PlantSeed(message string) error { + return plantImpl(nil, message) +} + +// plantImpl is the internal implementation that handles both tree planting and seed registration. +// For tree planting (userBTree != nil), it verifies the tree isn't already planted and isn't empty. +// For seed planting (userBTree == nil), it verifies the seed message isn't empty. +// In both cases, it mints an NFT to the user and adds their registration details to the members tree. +// Returns an error if any validation fails or if NFT minting fails. 
+func plantImpl(userBTree *btree.BTree, seedMessage string) error { + // Get the caller's address + userAddress := std.GetOrigCaller() + + var nftID string + var regDetails *RegistrationDetails + + if userBTree != nil { + // Handle tree planting + var treeExists bool + members.Ascend(func(record btree.Record) bool { + regDetails := record.(*RegistrationDetails) + if regDetails.UserBTree == userBTree { + treeExists = true + return false + } + return true + }) + if treeExists { + return errors.New("tree is already planted in the forest") + } + + if userBTree.Len() == 0 { + return errors.New("cannot plant an empty tree") + } + + nftID = ufmt.Sprintf("%d", tokenID) + regDetails = &RegistrationDetails{ + Address: userAddress, + RegTime: time.Now(), + UserBTree: userBTree, + NFTID: nftID, + } + } else { + // Handle seed planting + if seedMessage == "" { + return errors.New("seed message cannot be empty") + } + nftID = "seed_" + ufmt.Sprintf("%d", tokenID) + regDetails = &RegistrationDetails{ + Address: userAddress, + RegTime: time.Now(), + UserBTree: nil, + NFTID: nftID, + } + } + + // Mint an NFT to the user + err := dao.Mint(userAddress, grc721.TokenID(nftID)) + if err != nil { + return err + } + + members.Insert(regDetails) + tokenID++ + return nil +} + +// Render generates a Markdown representation of the DAO members. +// It displays: +// - Total number of NFTs minted +// - Total number of members +// - Size of the biggest planted tree +// - The first 3 members (OGs) +// - The latest 10 members +// Each member entry includes their address and owned NFTs (🌳 for trees, 🌱 for seeds). +// The path parameter is currently unused. +// Returns a formatted Markdown string. 
+func Render(path string) string { + var latestMembers []string + var ogMembers []string + + // Get total size and first member + totalSize := members.Len() + biggestTree := 0 + if maxMember := members.Max(); maxMember != nil { + if userBTree := maxMember.(*RegistrationDetails).UserBTree; userBTree != nil { + biggestTree = userBTree.Len() + } + } + + // Collect the latest 10 members + members.Descend(func(record btree.Record) bool { + if len(latestMembers) < 10 { + regDetails := record.(*RegistrationDetails) + addr := regDetails.Address + nftList := "" + balance, err := dao.BalanceOf(addr) + if err == nil && balance > 0 { + nftList = " (NFTs: " + for i := uint64(0); i < balance; i++ { + if i > 0 { + nftList += ", " + } + if regDetails.UserBTree == nil { + nftList += "🌱#" + regDetails.NFTID + } else { + nftList += "🌳#" + regDetails.NFTID + } + } + nftList += ")" + } + latestMembers = append(latestMembers, string(addr)+nftList) + return true + } + return false + }) + + // Collect the first 3 members (OGs) + members.Ascend(func(record btree.Record) bool { + if len(ogMembers) < 3 { + regDetails := record.(*RegistrationDetails) + addr := regDetails.Address + nftList := "" + balance, err := dao.BalanceOf(addr) + if err == nil && balance > 0 { + nftList = " (NFTs: " + for i := uint64(0); i < balance; i++ { + if i > 0 { + nftList += ", " + } + if regDetails.UserBTree == nil { + nftList += "🌱#" + regDetails.NFTID + } else { + nftList += "🌳#" + regDetails.NFTID + } + } + nftList += ")" + } + ogMembers = append(ogMembers, string(addr)+nftList) + return true + } + return false + }) + + var sb strings.Builder + + sb.WriteString(md.H1("B-Tree DAO Members")) + sb.WriteString(md.H2("Total NFTs Minted")) + sb.WriteString(ufmt.Sprintf("Total NFTs minted: %d\n\n", dao.TokenCount())) + sb.WriteString(md.H2("Member Stats")) + sb.WriteString(ufmt.Sprintf("Total members: %d\n", totalSize)) + if biggestTree > 0 { + sb.WriteString(ufmt.Sprintf("Biggest tree size: %d\n", biggestTree)) + } + 
sb.WriteString(md.H2("OG Members")) + sb.WriteString(md.BulletList(ogMembers)) + sb.WriteString(md.H2("Latest Members")) + sb.WriteString(md.BulletList(latestMembers)) + + return sb.String() +} diff --git a/examples/gno.land/r/demo/btree_dao/btree_dao_test.gno b/examples/gno.land/r/demo/btree_dao/btree_dao_test.gno new file mode 100644 index 00000000000..0514f52f7b4 --- /dev/null +++ b/examples/gno.land/r/demo/btree_dao/btree_dao_test.gno @@ -0,0 +1,97 @@ +package btree_dao + +import ( + "std" + "strings" + "testing" + "time" + + "gno.land/p/demo/btree" + "gno.land/p/demo/uassert" + "gno.land/p/demo/urequire" +) + +func setupTest() { + std.TestSetOrigCaller(std.Address("g1ej0qca5ptsw9kfr64ey8jvfy9eacga6mpj2z0y")) + members = btree.New() +} + +type TestElement struct { + value int +} + +func (te *TestElement) Less(than btree.Record) bool { + return te.value < than.(*TestElement).value +} + +func TestPlantTree(t *testing.T) { + setupTest() + + tree := btree.New() + elements := []int{30, 10, 50, 20, 40} + for _, val := range elements { + tree.Insert(&TestElement{value: val}) + } + + err := PlantTree(tree) + urequire.NoError(t, err) + + found := false + members.Ascend(func(record btree.Record) bool { + regDetails := record.(*RegistrationDetails) + if regDetails.UserBTree == tree { + found = true + return false + } + return true + }) + uassert.True(t, found) + + err = PlantTree(tree) + uassert.Error(t, err) + + emptyTree := btree.New() + err = PlantTree(emptyTree) + uassert.Error(t, err) +} + +func TestPlantSeed(t *testing.T) { + setupTest() + + err := PlantSeed("Hello DAO!") + urequire.NoError(t, err) + + found := false + members.Ascend(func(record btree.Record) bool { + regDetails := record.(*RegistrationDetails) + if regDetails.UserBTree == nil { + found = true + uassert.NotEmpty(t, regDetails.NFTID) + uassert.True(t, strings.Contains(regDetails.NFTID, "seed_")) + return false + } + return true + }) + uassert.True(t, found) + + err = PlantSeed("") + uassert.Error(t, 
err) +} + +func TestRegistrationDetailsOrdering(t *testing.T) { + setupTest() + + rd1 := &RegistrationDetails{ + Address: std.Address("test1"), + RegTime: time.Now(), + NFTID: "0", + } + rd2 := &RegistrationDetails{ + Address: std.Address("test2"), + RegTime: time.Now().Add(time.Hour), + NFTID: "1", + } + + uassert.True(t, rd1.Less(rd2)) + uassert.False(t, rd2.Less(rd1)) +} diff --git a/examples/gno.land/r/demo/btree_dao/gno.mod b/examples/gno.land/r/demo/btree_dao/gno.mod new file mode 100644 index 00000000000..01b99acc300 --- /dev/null +++ b/examples/gno.land/r/demo/btree_dao/gno.mod @@ -0,0 +1 @@ +module gno.land/r/demo/btree_dao diff --git a/examples/gno.land/r/demo/foo20/foo20.gno b/examples/gno.land/r/demo/foo20/foo20.gno index 5c7d7f12b99..6522fbdc90e 100644 --- a/examples/gno.land/r/demo/foo20/foo20.gno +++ b/examples/gno.land/r/demo/foo20/foo20.gno @@ -22,8 +22,7 @@ var ( func init() { privateLedger.Mint(Ownable.Owner(), 1_000_000*10_000) // @privateLedgeristrator (1M) - getter := func() *grc20.Token { return Token } - grc20reg.Register(getter, "") + grc20reg.Register(Token.Getter(), "") } func TotalSupply() uint64 { diff --git a/examples/gno.land/r/demo/grc20factory/grc20factory.gno b/examples/gno.land/r/demo/grc20factory/grc20factory.gno index 58874409d7f..aa91084ab32 100644 --- a/examples/gno.land/r/demo/grc20factory/grc20factory.gno +++ b/examples/gno.land/r/demo/grc20factory/grc20factory.gno @@ -43,8 +43,7 @@ func NewWithAdmin(name, symbol string, decimals uint, initialMint, faucet uint64 faucet: faucet, } instances.Set(symbol, &inst) - getter := func() *grc20.Token { return token } - grc20reg.Register(getter, symbol) + grc20reg.Register(token.Getter(), symbol) } func (inst instance) Token() *grc20.Token { diff --git a/examples/gno.land/r/demo/tests/test20/gno.mod b/examples/gno.land/r/demo/tests/test20/gno.mod new file mode 100644 index 00000000000..7a71668d2df --- /dev/null +++ b/examples/gno.land/r/demo/tests/test20/gno.mod @@ -0,0 +1 @@ +module 
gno.land/r/demo/tests/test20 diff --git a/examples/gno.land/r/demo/tests/test20/test20.gno b/examples/gno.land/r/demo/tests/test20/test20.gno new file mode 100644 index 00000000000..9c4df58d1c4 --- /dev/null +++ b/examples/gno.land/r/demo/tests/test20/test20.gno @@ -0,0 +1,20 @@ +// Package test20 implements a deliberately insecure ERC20 token for testing purposes. +// The Test20 token allows anyone to mint any amount of tokens to any address, making +// it unsuitable for production use. The primary goal of this package is to facilitate +// testing and experimentation without any security measures or restrictions. +// +// WARNING: This token is highly insecure and should not be used in any +// production environment. It is intended solely for testing and +// educational purposes. +package test20 + +import ( + "gno.land/p/demo/grc/grc20" + "gno.land/r/demo/grc20reg" +) + +var Token, PrivateLedger = grc20.NewToken("Test20", "TST", 4) + +func init() { + grc20reg.Register(Token.Getter(), "") +} diff --git a/examples/gno.land/r/demo/wugnot/wugnot.gno b/examples/gno.land/r/demo/wugnot/wugnot.gno index 09538b860ca..b72f5161e7d 100644 --- a/examples/gno.land/r/demo/wugnot/wugnot.gno +++ b/examples/gno.land/r/demo/wugnot/wugnot.gno @@ -19,8 +19,7 @@ const ( ) func init() { - getter := func() *grc20.Token { return Token } - grc20reg.Register(getter, "") + grc20reg.Register(Token.Getter(), "") } func Deposit() { diff --git a/examples/gno.land/r/nemanya/config/config.gno b/examples/gno.land/r/nemanya/config/config.gno new file mode 100644 index 00000000000..795e48c94c1 --- /dev/null +++ b/examples/gno.land/r/nemanya/config/config.gno @@ -0,0 +1,63 @@ +package config + +import ( + "errors" + "std" +) + +var ( + main std.Address + backup std.Address + + ErrInvalidAddr = errors.New("Invalid address") + ErrUnauthorized = errors.New("Unauthorized") +) + +func init() { + main = "g1x9qyf6f34v2g52k4q5smn5tctmj3hl2kj7l2ql" +} + +func Address() std.Address { + return main +} + +func 
Backup() std.Address { + return backup +} + +func SetAddress(a std.Address) error { + if !a.IsValid() { + return ErrInvalidAddr + } + + if err := checkAuthorized(); err != nil { + return err + } + + main = a + return nil +} + +func SetBackup(a std.Address) error { + if !a.IsValid() { + return ErrInvalidAddr + } + + if err := checkAuthorized(); err != nil { + return err + } + + backup = a + return nil +} + +func checkAuthorized() error { + caller := std.PrevRealm().Addr() + isAuthorized := caller == main || caller == backup + + if !isAuthorized { + return ErrUnauthorized + } + + return nil +} diff --git a/examples/gno.land/r/nemanya/config/gno.mod b/examples/gno.land/r/nemanya/config/gno.mod new file mode 100644 index 00000000000..4388b5bd525 --- /dev/null +++ b/examples/gno.land/r/nemanya/config/gno.mod @@ -0,0 +1 @@ +module gno.land/r/nemanya/config diff --git a/examples/gno.land/r/nemanya/home/gno.mod b/examples/gno.land/r/nemanya/home/gno.mod new file mode 100644 index 00000000000..d0220197489 --- /dev/null +++ b/examples/gno.land/r/nemanya/home/gno.mod @@ -0,0 +1 @@ +module gno.land/r/nemanya/home diff --git a/examples/gno.land/r/nemanya/home/home.gno b/examples/gno.land/r/nemanya/home/home.gno new file mode 100644 index 00000000000..08e24baecfd --- /dev/null +++ b/examples/gno.land/r/nemanya/home/home.gno @@ -0,0 +1,280 @@ +package home + +import ( + "std" + "strings" + + "gno.land/p/demo/ufmt" + "gno.land/r/nemanya/config" +) + +type SocialLink struct { + URL string + Text string +} + +type Sponsor struct { + Address std.Address + Amount std.Coins +} + +type Project struct { + Name string + Description string + URL string + ImageURL string + Sponsors map[std.Address]Sponsor +} + +var ( + textArt string + aboutMe string + sponsorInfo string + socialLinks map[string]SocialLink + gnoProjects map[string]Project + otherProjects map[string]Project + totalDonations std.Coins +) + +func init() { + textArt = renderTextArt() + aboutMe = "I am a student of IT at Faculty 
of Sciences in Novi Sad, Serbia. My background is mainly in web and low-level programming, but since Web3 Bootcamp at Petnica this year I've been actively learning about blockchain and adjacent technologies. I am excited about contributing to the gno.land ecosystem and learning from the community.\n\n" + sponsorInfo = "You can sponsor a project by sending GNOT to this address. Your sponsorship will be displayed on the project page. Thank you for supporting the development of gno.land!\n\n" + + socialLinks = map[string]SocialLink{ + "GitHub": {URL: "https://github.com/Nemanya8", Text: "Explore my repositories and open-source contributions."}, + "LinkedIn": {URL: "https://www.linkedin.com/in/nemanjamatic/", Text: "Connect with me professionally."}, + "Email Me": {URL: "mailto:matic.nemanya@gmail.com", Text: "Reach out for collaboration or inquiries."}, + } + + gnoProjects = make(map[string]Project) + otherProjects = make(map[string]Project) + + gnoProjects["Liberty Bridge"] = Project{ + Name: "Liberty Bridge", + Description: "Liberty Bridge was my first Web3 project, developed as part of the Web3 Bootcamp at Petnica. This project served as a centralized bridge between Ethereum and gno.land, enabling seamless asset transfers and fostering interoperability between the two ecosystems.\n\n The primary objective of Liberty Bridge was to address the challenges of connecting decentralized networks by implementing a user-friendly solution that simplified the process for users. 
The project incorporated mechanisms to securely transfer assets between the Ethereum and gno.land blockchains, ensuring efficiency and reliability while maintaining a centralized framework for governance and operations.\n\n Through this project, I gained hands-on knowledge of blockchain interoperability, Web3 protocols, and the intricacies of building solutions that bridge different blockchain ecosystems.\n\n", + URL: "https://gno.land", + ImageURL: "https://github.com/Milosevic02/LibertyBridge/raw/main/lb_banner.png", + Sponsors: make(map[std.Address]Sponsor), + } + + otherProjects["Incognito"] = Project{ + Name: "Incognito", + Description: "Incognito is a Web3 platform built for Ethereum-based chains, designed to connect advertisers with users in a privacy-first and mutually beneficial way. Its modular architecture makes it easily expandable to other blockchains. Developed during the ETH Sofia Hackathon, it was recognized as a winning project for its innovation and impact.\n\n The platform allows advertisers to send personalized ads while sharing a portion of the marketing budget with users. It uses machine learning to match users based on wallet activity, ensuring precise targeting. User emails are stored securely on-chain and never shared, prioritizing privacy and transparency.\n\n With all campaign data stored on-chain, Incognito ensures decentralization and accountability. 
By rewarding users and empowering advertisers, it sets a new standard for fair and transparent blockchain-based advertising.", + URL: "https://github.com/Milosevic02/Incognito-ETHSofia", + ImageURL: "", + Sponsors: make(map[std.Address]Sponsor), + } +} + +func Render(path string) string { + var sb strings.Builder + sb.WriteString("# Hi, I'm\n") + sb.WriteString(textArt) + sb.WriteString("---\n") + sb.WriteString("## About me\n") + sb.WriteString(aboutMe) + sb.WriteString(sponsorInfo) + sb.WriteString(ufmt.Sprintf("# Total Sponsor Donations: %s\n", totalDonations.String())) + sb.WriteString("---\n") + sb.WriteString(renderProjects(gnoProjects, "Gno Projects")) + sb.WriteString("---\n") + sb.WriteString(renderProjects(otherProjects, "Other Projects")) + sb.WriteString("---\n") + sb.WriteString(renderSocialLinks()) + + return sb.String() +} + +func renderTextArt() string { + var sb strings.Builder + sb.WriteString("```\n") + sb.WriteString(" ___ ___ ___ ___ ___ ___ ___ \n") + sb.WriteString(" /\\__\\ /\\ \\ /\\__\\ /\\ \\ /\\__\\ |\\__\\ /\\ \\ \n") + sb.WriteString(" /::| | /::\\ \\ /::| | /::\\ \\ /::| | |:| | /::\\ \\ \n") + sb.WriteString(" /:|:| | /:/\\:\\ \\ /:|:| | /:/\\:\\ \\ /:|:| | |:| | /:/\\:\\ \\ \n") + sb.WriteString(" /:/|:| |__ /::\\~\\:\\ \\ /:/|:|__|__ /::\\~\\:\\ \\ /:/|:| |__ |:|__|__ /::\\~\\:\\ \\ \n") + sb.WriteString(" /:/ |:| /\\__\\ /:/\\:\\ \\:\\__\\ /:/ |::::\\__\\ /:/\\:\\ \\:\\__\\ /:/ |:| /\\__\\ /::::\\__\\ /:/\\:\\ \\:\\__\\\n") + sb.WriteString(" \\/__|:|/:/ / \\:\\~\\:\\ \\/__/ \\/__/~~/:/ / \\/__\\:\\/:/ / \\/__|:|/:/ / /:/~~/~ \\/__\\:\\/:/ / \n") + sb.WriteString(" |:/:/ / \\:\\ \\:\\__\\ /:/ / \\::/ / |:/:/ / /:/ / \\::/ / \n") + sb.WriteString(" |::/ / \\:\\ \\/__/ /:/ / /:/ / |::/ / \\/__/ /:/ / \n") + sb.WriteString(" /:/ / \\:\\__\\ /:/ / /:/ / /:/ / /:/ / \n") + sb.WriteString(" \\/__/ \\/__/ \\/__/ \\/__/ \\/__/ \\/__/ \n") + sb.WriteString("\n```\n") + return sb.String() +} + +func renderSocialLinks() string { + var sb 
strings.Builder + sb.WriteString("## Links\n\n") + sb.WriteString("You can find me here:\n\n") + sb.WriteString(ufmt.Sprintf("- [GitHub](%s) - %s\n", socialLinks["GitHub"].URL, socialLinks["GitHub"].Text)) + sb.WriteString(ufmt.Sprintf("- [LinkedIn](%s) - %s\n", socialLinks["LinkedIn"].URL, socialLinks["LinkedIn"].Text)) + sb.WriteString(ufmt.Sprintf("- [Email Me](%s) - %s\n", socialLinks["Email Me"].URL, socialLinks["Email Me"].Text)) + sb.WriteString("\n") + return sb.String() +} + +func renderProjects(projectsMap map[string]Project, title string) string { + var sb strings.Builder + sb.WriteString(ufmt.Sprintf("## %s\n\n", title)) + for _, project := range projectsMap { + if project.ImageURL != "" { + sb.WriteString(ufmt.Sprintf("![%s](%s)\n\n", project.Name, project.ImageURL)) + } + sb.WriteString(ufmt.Sprintf("### [%s](%s)\n\n", project.Name, project.URL)) + sb.WriteString(project.Description + "\n\n") + + if len(project.Sponsors) > 0 { + sb.WriteString(ufmt.Sprintf("#### %s Sponsors\n", project.Name)) + for _, sponsor := range project.Sponsors { + sb.WriteString(ufmt.Sprintf("- %s: %s\n", sponsor.Address.String(), sponsor.Amount.String())) + } + sb.WriteString("\n") + } + } + return sb.String() +} + +func UpdateLink(name, newURL string) { + if !isAuthorized(std.PrevRealm().Addr()) { + panic(config.ErrUnauthorized) + } + + if _, exists := socialLinks[name]; !exists { + panic("Link with the given name does not exist") + } + + socialLinks[name] = SocialLink{ + URL: newURL, + Text: socialLinks[name].Text, + } +} + +func UpdateAboutMe(text string) { + if !isAuthorized(std.PrevRealm().Addr()) { + panic(config.ErrUnauthorized) + } + + aboutMe = text +} + +func AddGnoProject(name, description, url, imageURL string) { + if !isAuthorized(std.PrevRealm().Addr()) { + panic(config.ErrUnauthorized) + } + project := Project{ + Name: name, + Description: description, + URL: url, + ImageURL: imageURL, + Sponsors: make(map[std.Address]Sponsor), + } + gnoProjects[name] = project 
+} + +func DeleteGnoProject(projectName string) { + if !isAuthorized(std.PrevRealm().Addr()) { + panic(config.ErrUnauthorized) + } + + if _, exists := gnoProjects[projectName]; !exists { + panic("Project not found") + } + + delete(gnoProjects, projectName) +} + +func AddOtherProject(name, description, url, imageURL string) { + if !isAuthorized(std.PrevRealm().Addr()) { + panic(config.ErrUnauthorized) + } + project := Project{ + Name: name, + Description: description, + URL: url, + ImageURL: imageURL, + Sponsors: make(map[std.Address]Sponsor), + } + otherProjects[name] = project +} + +func RemoveOtherProject(projectName string) { + if !isAuthorized(std.PrevRealm().Addr()) { + panic(config.ErrUnauthorized) + } + + if _, exists := otherProjects[projectName]; !exists { + panic("Project not found") + } + + delete(otherProjects, projectName) +} + +func isAuthorized(addr std.Address) bool { + return addr == config.Address() || addr == config.Backup() +} + +func SponsorGnoProject(projectName string) { + address := std.GetOrigCaller() + amount := std.GetOrigSend() + + if amount.AmountOf("ugnot") == 0 { + panic("Donation must include GNOT") + } + + project, exists := gnoProjects[projectName] + if !exists { + panic("Gno project not found") + } + + project.Sponsors[address] = Sponsor{ + Address: address, + Amount: project.Sponsors[address].Amount.Add(amount), + } + + totalDonations = totalDonations.Add(amount) + + gnoProjects[projectName] = project +} + +func SponsorOtherProject(projectName string) { + address := std.GetOrigCaller() + amount := std.GetOrigSend() + + if amount.AmountOf("ugnot") == 0 { + panic("Donation must include GNOT") + } + + project, exists := otherProjects[projectName] + if !exists { + panic("Other project not found") + } + + project.Sponsors[address] = Sponsor{ + Address: address, + Amount: project.Sponsors[address].Amount.Add(amount), + } + + totalDonations = totalDonations.Add(amount) + + otherProjects[projectName] = project +} + +func Withdraw() 
string { + if !isAuthorized(std.PrevRealm().Addr()) { + panic(config.ErrUnauthorized) + } + + banker := std.GetBanker(std.BankerTypeRealmSend) + realmAddress := std.GetOrigPkgAddr() + coins := banker.GetCoins(realmAddress) + + if len(coins) == 0 { + return "No coins available to withdraw" + } + + banker.SendCoins(realmAddress, config.Address(), coins) + + return "Successfully withdrew all coins to config address" +} diff --git a/examples/gno.land/r/stefann/home/home.gno b/examples/gno.land/r/stefann/home/home.gno index 9586f377311..f54721ce37c 100644 --- a/examples/gno.land/r/stefann/home/home.gno +++ b/examples/gno.land/r/stefann/home/home.gno @@ -8,6 +8,8 @@ import ( "gno.land/p/demo/avl" "gno.land/p/demo/ownable" "gno.land/p/demo/ufmt" + "gno.land/r/demo/users" + "gno.land/r/leon/hof" "gno.land/r/stefann/registry" ) @@ -23,7 +25,6 @@ type Sponsor struct { } type Profile struct { - pfp string aboutMe []string } @@ -49,15 +50,15 @@ var ( func init() { owner = ownable.NewWithAddress(registry.MainAddr()) + hof.Register() profile = Profile{ - pfp: "https://i.ibb.co/Bc5YNCx/DSC-0095a.jpg", aboutMe: []string{ - `### About Me`, - `Hey there! I’m Stefan, a student of Computer Science. I’m all about exploring and adventure — whether it’s diving into the latest tech or discovering a new city, I’m always up for the challenge!`, + `## About Me`, + `### Hey there! I’m Stefan, a student of Computer Science. 
I’m all about exploring and adventure — whether it’s diving into the latest tech or discovering a new city, I’m always up for the challenge!`, - `### Contributions`, - `I'm just getting started, but you can follow my journey through gno.land right [here](https://github.com/gnolang/hackerspace/issues/94) 🔗`, + `## Contributions`, + `### I'm just getting started, but you can follow my journey through gno.land right [here](https://github.com/gnolang/hackerspace/issues/94) 🔗`, }, } @@ -83,7 +84,7 @@ func init() { } sponsorship = Sponsorship{ - maxSponsors: 5, + maxSponsors: 3, sponsors: avl.NewTree(), DonationsCount: 0, sponsorsCount: 0, @@ -106,11 +107,6 @@ func UpdateJarLink(newLink string) { travel.jarLink = newLink } -func UpdatePFP(url string) { - owner.AssertCallerIsOwner() - profile.pfp = url -} - func UpdateAboutMe(aboutMeStr string) { owner.AssertCallerIsOwner() profile.aboutMe = strings.Split(aboutMeStr, "|") @@ -203,46 +199,27 @@ func Render(path string) string { } func renderAboutMe() string { - out := "
" - - out += "
\n\n" + out := "" - out += ufmt.Sprintf("
\n\n", travel.cities[travel.currentCityIndex%len(travel.cities)].URL) - - out += ufmt.Sprintf("my profile pic\n\n", profile.pfp) - - out += "
\n\n" + out += ufmt.Sprintf("![Current Location](%s)\n\n", travel.cities[travel.currentCityIndex%len(travel.cities)].URL) for _, rows := range profile.aboutMe { - out += "
\n\n" out += rows + "\n\n" - out += "
\n\n" } - out += "
\n\n" - return out } func renderTips() string { - out := `
` + "\n\n" + out := "# Help Me Travel The World\n\n" - out += `
` + "\n" + out += ufmt.Sprintf("## I am currently in %s, tip the jar to send me somewhere else!\n\n", travel.cities[travel.currentCityIndex].Name) + out += "### **Click** the jar, **tip** in GNOT coins, and **watch** my background change as I head to a new adventure!\n\n" - out += `

Help Me Travel The World

` + "\n\n" - - out += renderTipsJar() + "\n" - - out += ufmt.Sprintf(`I am currently in %s,
tip the jar to send me somewhere else!
`, travel.cities[travel.currentCityIndex].Name) - - out += `
Click the jar, tip in GNOT coins, and watch my background change as I head to a new adventure!

` + "\n\n" + out += renderTipsJar() + "\n\n" out += renderSponsors() - out += `
` + "\n\n" - - out += `
` + "\n" - return out } @@ -253,11 +230,27 @@ func formatAddress(address string) string { return address[:4] + "..." + address[len(address)-4:] } +func getDisplayName(addr std.Address) string { + if user := users.GetUserByAddress(addr); user != nil { + return user.Name + } + return formatAddress(addr.String()) +} + +func formatAmount(amount std.Coins) string { + ugnot := amount.AmountOf("ugnot") + if ugnot >= 1000000 { + gnot := float64(ugnot) / 1000000 + return ufmt.Sprintf("`%v`*GNOT*", gnot) + } + return ufmt.Sprintf("`%d`*ugnot*", ugnot) +} + func renderSponsors() string { - out := `

Sponsor Leaderboard

` + "\n" + out := "## Sponsor Leaderboard\n\n" if sponsorship.sponsorsCount == 0 { - return out + `

No sponsors yet. Be the first to tip the jar!

` + "\n" + return out + "No sponsors yet. Be the first to tip the jar!\n" } topSponsors := GetTopSponsors() @@ -266,38 +259,30 @@ func renderSponsors() string { numSponsors = sponsorship.maxSponsors } - out += `
    ` + "\n" - for i := 0; i < numSponsors; i++ { sponsor := topSponsors[i] - isLastItem := (i == numSponsors-1) - - padding := "10px 5px" - border := "border-bottom: 1px solid #ddd;" - - if isLastItem { - padding = "8px 5px" - border = "" + position := "" + switch i { + case 0: + position = "🥇" + case 1: + position = "🥈" + case 2: + position = "🥉" + default: + position = ufmt.Sprintf("%d.", i+1) } - out += ufmt.Sprintf( - `
  • - %d. %s - %s -
  • `, - padding, border, i+1, formatAddress(sponsor.Address.String()), sponsor.Amount.String(), + out += ufmt.Sprintf("%s **%s** - %s\n\n", + position, + getDisplayName(sponsor.Address), + formatAmount(sponsor.Amount), ) } - return out + return out + "\n" } func renderTipsJar() string { - out := ufmt.Sprintf(``, travel.jarLink) + "\n" - - out += `Tips Jar` + "\n" - - out += `` + "\n" - - return out + return ufmt.Sprintf("[![Tips Jar](https://i.ibb.co/4TH9zbw/tips-jar.png)](%s)", travel.jarLink) } diff --git a/examples/gno.land/r/stefann/home/home_test.gno b/examples/gno.land/r/stefann/home/home_test.gno index ca146b9eb13..b8ea88670a6 100644 --- a/examples/gno.land/r/stefann/home/home_test.gno +++ b/examples/gno.land/r/stefann/home/home_test.gno @@ -9,19 +9,6 @@ import ( "gno.land/p/demo/testutils" ) -func TestUpdatePFP(t *testing.T) { - var owner = std.Address("g1sd5ezmxt4rwpy52u6wl3l3y085n8x0p6nllxm8") - std.TestSetOrigCaller(owner) - - profile.pfp = "" - - UpdatePFP("https://example.com/pic.png") - - if profile.pfp != "https://example.com/pic.png" { - t.Fatalf("expected pfp to be https://example.com/pic.png, got %s", profile.pfp) - } -} - func TestUpdateAboutMe(t *testing.T) { var owner = std.Address("g1sd5ezmxt4rwpy52u6wl3l3y085n8x0p6nllxm8") std.TestSetOrigCaller(owner) diff --git a/gno.land/cmd/gnoland/config_get_test.go b/gno.land/cmd/gnoland/config_get_test.go index f2ddc5ca6d0..84cf0ba3d37 100644 --- a/gno.land/cmd/gnoland/config_get_test.go +++ b/gno.land/cmd/gnoland/config_get_test.go @@ -289,14 +289,6 @@ func TestConfig_Get_Base(t *testing.T) { }, true, }, - { - "filter peers flag fetched", - "filter_peers", - func(loadedCfg *config.Config, value []byte) { - assert.Equal(t, loadedCfg.FilterPeers, unmarshalJSONCommon[bool](t, value)) - }, - false, - }, } verifyGetTestTableCommon(t, testTable) @@ -616,19 +608,11 @@ func TestConfig_Get_P2P(t *testing.T) { }, true, }, - { - "upnp toggle", - "p2p.upnp", - func(loadedCfg *config.Config, value []byte) { - 
assert.Equal(t, loadedCfg.P2P.UPNP, unmarshalJSONCommon[bool](t, value)) - }, - false, - }, { "max inbound peers", "p2p.max_num_inbound_peers", func(loadedCfg *config.Config, value []byte) { - assert.Equal(t, loadedCfg.P2P.MaxNumInboundPeers, unmarshalJSONCommon[int](t, value)) + assert.Equal(t, loadedCfg.P2P.MaxNumInboundPeers, unmarshalJSONCommon[uint64](t, value)) }, false, }, @@ -636,7 +620,7 @@ func TestConfig_Get_P2P(t *testing.T) { "max outbound peers", "p2p.max_num_outbound_peers", func(loadedCfg *config.Config, value []byte) { - assert.Equal(t, loadedCfg.P2P.MaxNumOutboundPeers, unmarshalJSONCommon[int](t, value)) + assert.Equal(t, loadedCfg.P2P.MaxNumOutboundPeers, unmarshalJSONCommon[uint64](t, value)) }, false, }, @@ -676,15 +660,7 @@ func TestConfig_Get_P2P(t *testing.T) { "pex reactor toggle", "p2p.pex", func(loadedCfg *config.Config, value []byte) { - assert.Equal(t, loadedCfg.P2P.PexReactor, unmarshalJSONCommon[bool](t, value)) - }, - false, - }, - { - "seed mode", - "p2p.seed_mode", - func(loadedCfg *config.Config, value []byte) { - assert.Equal(t, loadedCfg.P2P.SeedMode, unmarshalJSONCommon[bool](t, value)) + assert.Equal(t, loadedCfg.P2P.PeerExchange, unmarshalJSONCommon[bool](t, value)) }, false, }, @@ -704,30 +680,6 @@ func TestConfig_Get_P2P(t *testing.T) { }, true, }, - { - "allow duplicate IP", - "p2p.allow_duplicate_ip", - func(loadedCfg *config.Config, value []byte) { - assert.Equal(t, loadedCfg.P2P.AllowDuplicateIP, unmarshalJSONCommon[bool](t, value)) - }, - false, - }, - { - "handshake timeout", - "p2p.handshake_timeout", - func(loadedCfg *config.Config, value []byte) { - assert.Equal(t, loadedCfg.P2P.HandshakeTimeout, unmarshalJSONCommon[time.Duration](t, value)) - }, - false, - }, - { - "dial timeout", - "p2p.dial_timeout", - func(loadedCfg *config.Config, value []byte) { - assert.Equal(t, loadedCfg.P2P.DialTimeout, unmarshalJSONCommon[time.Duration](t, value)) - }, - false, - }, } verifyGetTestTableCommon(t, testTable) diff --git 
a/gno.land/cmd/gnoland/config_set_test.go b/gno.land/cmd/gnoland/config_set_test.go index cb831f0e502..39880313043 100644 --- a/gno.land/cmd/gnoland/config_set_test.go +++ b/gno.land/cmd/gnoland/config_set_test.go @@ -244,19 +244,6 @@ func TestConfig_Set_Base(t *testing.T) { assert.Equal(t, value, loadedCfg.ProfListenAddress) }, }, - { - "filter peers flag updated", - []string{ - "filter_peers", - "true", - }, - func(loadedCfg *config.Config, value string) { - boolVal, err := strconv.ParseBool(value) - require.NoError(t, err) - - assert.Equal(t, boolVal, loadedCfg.FilterPeers) - }, - }, } verifySetTestTableCommon(t, testTable) @@ -505,19 +492,6 @@ func TestConfig_Set_P2P(t *testing.T) { assert.Equal(t, value, loadedCfg.P2P.PersistentPeers) }, }, - { - "upnp toggle updated", - []string{ - "p2p.upnp", - "false", - }, - func(loadedCfg *config.Config, value string) { - boolVal, err := strconv.ParseBool(value) - require.NoError(t, err) - - assert.Equal(t, boolVal, loadedCfg.P2P.UPNP) - }, - }, { "max inbound peers updated", []string{ @@ -588,20 +562,7 @@ func TestConfig_Set_P2P(t *testing.T) { boolVal, err := strconv.ParseBool(value) require.NoError(t, err) - assert.Equal(t, boolVal, loadedCfg.P2P.PexReactor) - }, - }, - { - "seed mode updated", - []string{ - "p2p.seed_mode", - "false", - }, - func(loadedCfg *config.Config, value string) { - boolVal, err := strconv.ParseBool(value) - require.NoError(t, err) - - assert.Equal(t, boolVal, loadedCfg.P2P.SeedMode) + assert.Equal(t, boolVal, loadedCfg.P2P.PeerExchange) }, }, { @@ -614,39 +575,6 @@ func TestConfig_Set_P2P(t *testing.T) { assert.Equal(t, value, loadedCfg.P2P.PrivatePeerIDs) }, }, - { - "allow duplicate IPs updated", - []string{ - "p2p.allow_duplicate_ip", - "false", - }, - func(loadedCfg *config.Config, value string) { - boolVal, err := strconv.ParseBool(value) - require.NoError(t, err) - - assert.Equal(t, boolVal, loadedCfg.P2P.AllowDuplicateIP) - }, - }, - { - "handshake timeout updated", - []string{ - 
"p2p.handshake_timeout", - "1s", - }, - func(loadedCfg *config.Config, value string) { - assert.Equal(t, value, loadedCfg.P2P.HandshakeTimeout.String()) - }, - }, - { - "dial timeout updated", - []string{ - "p2p.dial_timeout", - "1s", - }, - func(loadedCfg *config.Config, value string) { - assert.Equal(t, value, loadedCfg.P2P.DialTimeout.String()) - }, - }, } verifySetTestTableCommon(t, testTable) diff --git a/gno.land/cmd/gnoland/secrets_common.go b/gno.land/cmd/gnoland/secrets_common.go index d40e90f6b48..500336e3489 100644 --- a/gno.land/cmd/gnoland/secrets_common.go +++ b/gno.land/cmd/gnoland/secrets_common.go @@ -8,7 +8,7 @@ import ( "github.com/gnolang/gno/tm2/pkg/amino" "github.com/gnolang/gno/tm2/pkg/bft/privval" "github.com/gnolang/gno/tm2/pkg/crypto" - "github.com/gnolang/gno/tm2/pkg/p2p" + "github.com/gnolang/gno/tm2/pkg/p2p/types" ) var ( @@ -54,7 +54,7 @@ func isValidDirectory(dirPath string) bool { } type secretData interface { - privval.FilePVKey | privval.FilePVLastSignState | p2p.NodeKey + privval.FilePVKey | privval.FilePVLastSignState | types.NodeKey } // readSecretData reads the secret data from the given path @@ -145,7 +145,7 @@ func validateValidatorStateSignature( } // validateNodeKey validates the node's p2p key -func validateNodeKey(key *p2p.NodeKey) error { +func validateNodeKey(key *types.NodeKey) error { if key.PrivKey == nil { return errInvalidNodeKey } diff --git a/gno.land/cmd/gnoland/secrets_common_test.go b/gno.land/cmd/gnoland/secrets_common_test.go index 34592c3bd8f..38c4772c705 100644 --- a/gno.land/cmd/gnoland/secrets_common_test.go +++ b/gno.land/cmd/gnoland/secrets_common_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/gnolang/gno/tm2/pkg/crypto" - "github.com/gnolang/gno/tm2/pkg/p2p" + "github.com/gnolang/gno/tm2/pkg/p2p/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -26,7 +26,7 @@ func TestCommon_SaveReadData(t *testing.T) { t.Run("invalid data read path", func(t *testing.T) { 
t.Parallel() - readData, err := readSecretData[p2p.NodeKey]("") + readData, err := readSecretData[types.NodeKey]("") assert.Nil(t, readData) assert.ErrorContains( @@ -44,7 +44,7 @@ func TestCommon_SaveReadData(t *testing.T) { require.NoError(t, saveSecretData("totally valid key", path)) - readData, err := readSecretData[p2p.NodeKey](path) + readData, err := readSecretData[types.NodeKey](path) require.Nil(t, readData) assert.ErrorContains(t, err, "unable to unmarshal data") @@ -59,7 +59,7 @@ func TestCommon_SaveReadData(t *testing.T) { require.NoError(t, saveSecretData(key, path)) - readKey, err := readSecretData[p2p.NodeKey](path) + readKey, err := readSecretData[types.NodeKey](path) require.NoError(t, err) assert.Equal(t, key, readKey) diff --git a/gno.land/cmd/gnoland/secrets_get.go b/gno.land/cmd/gnoland/secrets_get.go index 8d111516816..0a0a714f6ee 100644 --- a/gno.land/cmd/gnoland/secrets_get.go +++ b/gno.land/cmd/gnoland/secrets_get.go @@ -12,7 +12,7 @@ import ( "github.com/gnolang/gno/tm2/pkg/bft/privval" "github.com/gnolang/gno/tm2/pkg/commands" osm "github.com/gnolang/gno/tm2/pkg/os" - "github.com/gnolang/gno/tm2/pkg/p2p" + "github.com/gnolang/gno/tm2/pkg/p2p/types" ) var errInvalidSecretsGetArgs = errors.New("invalid number of secrets get arguments provided") @@ -169,7 +169,7 @@ func readValidatorState(path string) (*validatorStateInfo, error) { // readNodeID reads the node p2p info from the given path func readNodeID(path string) (*nodeIDInfo, error) { - nodeKey, err := readSecretData[p2p.NodeKey](path) + nodeKey, err := readSecretData[types.NodeKey](path) if err != nil { return nil, fmt.Errorf("unable to read node key, %w", err) } @@ -199,7 +199,7 @@ func readNodeID(path string) (*nodeIDInfo, error) { // constructP2PAddress constructs the P2P address other nodes can use // to connect directly -func constructP2PAddress(nodeID p2p.ID, listenAddress string) string { +func constructP2PAddress(nodeID types.ID, listenAddress string) string { var ( address 
string parts = strings.SplitN(listenAddress, "://", 2) diff --git a/gno.land/cmd/gnoland/secrets_get_test.go b/gno.land/cmd/gnoland/secrets_get_test.go index 66e6e3509fc..3dfe0c727dd 100644 --- a/gno.land/cmd/gnoland/secrets_get_test.go +++ b/gno.land/cmd/gnoland/secrets_get_test.go @@ -13,7 +13,7 @@ import ( "github.com/gnolang/gno/tm2/pkg/bft/config" "github.com/gnolang/gno/tm2/pkg/bft/privval" "github.com/gnolang/gno/tm2/pkg/commands" - "github.com/gnolang/gno/tm2/pkg/p2p" + "github.com/gnolang/gno/tm2/pkg/p2p/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -66,7 +66,7 @@ func TestSecrets_Get_All(t *testing.T) { // Get the node key nodeKeyPath := filepath.Join(tempDir, defaultNodeKeyName) - nodeKey, err := readSecretData[p2p.NodeKey](nodeKeyPath) + nodeKey, err := readSecretData[types.NodeKey](nodeKeyPath) require.NoError(t, err) // Get the validator private key diff --git a/gno.land/cmd/gnoland/secrets_init.go b/gno.land/cmd/gnoland/secrets_init.go index 58dd0783f66..9a7ddd106c3 100644 --- a/gno.land/cmd/gnoland/secrets_init.go +++ b/gno.land/cmd/gnoland/secrets_init.go @@ -12,7 +12,7 @@ import ( "github.com/gnolang/gno/tm2/pkg/commands" "github.com/gnolang/gno/tm2/pkg/crypto/ed25519" osm "github.com/gnolang/gno/tm2/pkg/os" - "github.com/gnolang/gno/tm2/pkg/p2p" + "github.com/gnolang/gno/tm2/pkg/p2p/types" ) var errOverwriteNotEnabled = errors.New("overwrite not enabled") @@ -200,10 +200,6 @@ func generateLastSignValidatorState() *privval.FilePVLastSignState { } // generateNodeKey generates the p2p node key -func generateNodeKey() *p2p.NodeKey { - privKey := ed25519.GenPrivKey() - - return &p2p.NodeKey{ - PrivKey: privKey, - } +func generateNodeKey() *types.NodeKey { + return types.GenerateNodeKey() } diff --git a/gno.land/cmd/gnoland/secrets_init_test.go b/gno.land/cmd/gnoland/secrets_init_test.go index 20e061447f5..7be3650fb4b 100644 --- a/gno.land/cmd/gnoland/secrets_init_test.go +++ 
b/gno.land/cmd/gnoland/secrets_init_test.go @@ -7,7 +7,7 @@ import ( "github.com/gnolang/gno/tm2/pkg/bft/privval" "github.com/gnolang/gno/tm2/pkg/commands" - "github.com/gnolang/gno/tm2/pkg/p2p" + "github.com/gnolang/gno/tm2/pkg/p2p/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -37,7 +37,7 @@ func verifyValidatorState(t *testing.T, path string) { func verifyNodeKey(t *testing.T, path string) { t.Helper() - nodeKey, err := readSecretData[p2p.NodeKey](path) + nodeKey, err := readSecretData[types.NodeKey](path) require.NoError(t, err) assert.NoError(t, validateNodeKey(nodeKey)) diff --git a/gno.land/cmd/gnoland/secrets_verify.go b/gno.land/cmd/gnoland/secrets_verify.go index 32e563c1c6f..15fef6649ec 100644 --- a/gno.land/cmd/gnoland/secrets_verify.go +++ b/gno.land/cmd/gnoland/secrets_verify.go @@ -8,7 +8,7 @@ import ( "github.com/gnolang/gno/tm2/pkg/bft/privval" "github.com/gnolang/gno/tm2/pkg/commands" - "github.com/gnolang/gno/tm2/pkg/p2p" + "github.com/gnolang/gno/tm2/pkg/p2p/types" ) type secretsVerifyCfg struct { @@ -146,7 +146,7 @@ func readAndVerifyValidatorState(path string, io commands.IO) (*privval.FilePVLa // readAndVerifyNodeKey reads the node p2p key from the given path and verifies it func readAndVerifyNodeKey(path string, io commands.IO) error { - nodeKey, err := readSecretData[p2p.NodeKey](path) + nodeKey, err := readSecretData[types.NodeKey](path) if err != nil { return fmt.Errorf("unable to read node p2p key, %w", err) } diff --git a/gno.land/cmd/gnoland/secrets_verify_test.go b/gno.land/cmd/gnoland/secrets_verify_test.go index 513d7c8b503..67630aaaa4a 100644 --- a/gno.land/cmd/gnoland/secrets_verify_test.go +++ b/gno.land/cmd/gnoland/secrets_verify_test.go @@ -8,7 +8,7 @@ import ( "github.com/gnolang/gno/tm2/pkg/bft/privval" "github.com/gnolang/gno/tm2/pkg/commands" - "github.com/gnolang/gno/tm2/pkg/p2p" + "github.com/gnolang/gno/tm2/pkg/p2p/types" "github.com/stretchr/testify/assert" 
"github.com/stretchr/testify/require" ) @@ -347,7 +347,7 @@ func TestSecrets_Verify_Single(t *testing.T) { dirPath := t.TempDir() path := filepath.Join(dirPath, defaultNodeKeyName) - invalidNodeKey := &p2p.NodeKey{ + invalidNodeKey := &types.NodeKey{ PrivKey: nil, // invalid } diff --git a/gno.land/cmd/gnoland/start.go b/gno.land/cmd/gnoland/start.go index a420e652810..cb5d54a513a 100644 --- a/gno.land/cmd/gnoland/start.go +++ b/gno.land/cmd/gnoland/start.go @@ -26,6 +26,7 @@ import ( "github.com/gnolang/gno/tm2/pkg/crypto" "github.com/gnolang/gno/tm2/pkg/events" osm "github.com/gnolang/gno/tm2/pkg/os" + "github.com/gnolang/gno/tm2/pkg/std" "github.com/gnolang/gno/tm2/pkg/telemetry" "go.uber.org/zap" @@ -234,9 +235,10 @@ func execStart(ctx context.Context, c *startCfg, io commands.IO) error { // Create a top-level shared event switch evsw := events.NewEventSwitch() + minGasPrices := cfg.Application.MinGasPrices // Create application and node - cfg.LocalApp, err = gnoland.NewApp(nodeDir, c.skipFailingGenesisTxs, evsw, logger) + cfg.LocalApp, err = gnoland.NewApp(nodeDir, c.skipFailingGenesisTxs, evsw, logger, minGasPrices) if err != nil { return fmt.Errorf("unable to create the Gnoland app, %w", err) } @@ -372,10 +374,10 @@ func generateGenesisFile(genesisFile string, pk crypto.PubKey, c *startCfg) erro gen.ConsensusParams = abci.ConsensusParams{ Block: &abci.BlockParams{ // TODO: update limits. 
- MaxTxBytes: 1_000_000, // 1MB, - MaxDataBytes: 2_000_000, // 2MB, - MaxGas: 100_000_000, // 100M gas - TimeIotaMS: 100, // 100ms + MaxTxBytes: 1_000_000, // 1MB, + MaxDataBytes: 2_000_000, // 2MB, + MaxGas: 3_000_000_000, // 3B gas + TimeIotaMS: 100, // 100ms }, } diff --git a/gno.land/pkg/gnoclient/integration_test.go b/gno.land/pkg/gnoclient/integration_test.go index 4b70fb60c49..bfcaaec999e 100644 --- a/gno.land/pkg/gnoclient/integration_test.go +++ b/gno.land/pkg/gnoclient/integration_test.go @@ -1,10 +1,12 @@ package gnoclient import ( + "path/filepath" "testing" "github.com/gnolang/gno/gnovm/pkg/gnolang" + "github.com/gnolang/gno/gno.land/pkg/gnoland" "github.com/gnolang/gno/gno.land/pkg/gnoland/ugnot" "github.com/gnolang/gno/gno.land/pkg/integration" "github.com/gnolang/gno/gno.land/pkg/sdk/vm" @@ -21,8 +23,14 @@ import ( ) func TestCallSingle_Integration(t *testing.T) { - // Set up in-memory node - config, _ := integration.TestingNodeConfig(t, gnoenv.RootDir()) + // Setup packages + rootdir := gnoenv.RootDir() + config := integration.TestingMinimalNodeConfig(gnoenv.RootDir()) + meta := loadpkgs(t, rootdir, "gno.land/r/demo/deep/very/deep") + state := config.Genesis.AppState.(gnoland.GnoGenesisState) + state.Txs = append(state.Txs, meta...) + config.Genesis.AppState = state + node, remoteAddr := integration.TestingInMemoryNode(t, log.NewNoopLogger(), config) defer node.Stop() @@ -74,8 +82,14 @@ func TestCallSingle_Integration(t *testing.T) { } func TestCallMultiple_Integration(t *testing.T) { - // Set up in-memory node - config, _ := integration.TestingNodeConfig(t, gnoenv.RootDir()) + // Setup packages + rootdir := gnoenv.RootDir() + config := integration.TestingMinimalNodeConfig(gnoenv.RootDir()) + meta := loadpkgs(t, rootdir, "gno.land/r/demo/deep/very/deep") + state := config.Genesis.AppState.(gnoland.GnoGenesisState) + state.Txs = append(state.Txs, meta...) 
+ config.Genesis.AppState = state + node, remoteAddr := integration.TestingInMemoryNode(t, log.NewNoopLogger(), config) defer node.Stop() @@ -137,7 +151,7 @@ func TestCallMultiple_Integration(t *testing.T) { func TestSendSingle_Integration(t *testing.T) { // Set up in-memory node - config, _ := integration.TestingNodeConfig(t, gnoenv.RootDir()) + config := integration.TestingMinimalNodeConfig(gnoenv.RootDir()) node, remoteAddr := integration.TestingInMemoryNode(t, log.NewNoopLogger(), config) defer node.Stop() @@ -201,7 +215,7 @@ func TestSendSingle_Integration(t *testing.T) { func TestSendMultiple_Integration(t *testing.T) { // Set up in-memory node - config, _ := integration.TestingNodeConfig(t, gnoenv.RootDir()) + config := integration.TestingMinimalNodeConfig(gnoenv.RootDir()) node, remoteAddr := integration.TestingInMemoryNode(t, log.NewNoopLogger(), config) defer node.Stop() @@ -273,8 +287,15 @@ func TestSendMultiple_Integration(t *testing.T) { // Run tests func TestRunSingle_Integration(t *testing.T) { + // Setup packages + rootdir := gnoenv.RootDir() + config := integration.TestingMinimalNodeConfig(gnoenv.RootDir()) + meta := loadpkgs(t, rootdir, "gno.land/p/demo/ufmt", "gno.land/r/demo/tests") + state := config.Genesis.AppState.(gnoland.GnoGenesisState) + state.Txs = append(state.Txs, meta...) 
+ config.Genesis.AppState = state + // Set up in-memory node - config, _ := integration.TestingNodeConfig(t, gnoenv.RootDir()) node, remoteAddr := integration.TestingInMemoryNode(t, log.NewNoopLogger(), config) defer node.Stop() @@ -342,7 +363,17 @@ func main() { // Run tests func TestRunMultiple_Integration(t *testing.T) { // Set up in-memory node - config, _ := integration.TestingNodeConfig(t, gnoenv.RootDir()) + rootdir := gnoenv.RootDir() + config := integration.TestingMinimalNodeConfig(rootdir) + meta := loadpkgs(t, rootdir, + "gno.land/p/demo/ufmt", + "gno.land/r/demo/tests", + "gno.land/r/demo/deep/very/deep", + ) + state := config.Genesis.AppState.(gnoland.GnoGenesisState) + state.Txs = append(state.Txs, meta...) + config.Genesis.AppState = state + node, remoteAddr := integration.TestingInMemoryNode(t, log.NewNoopLogger(), config) defer node.Stop() @@ -434,7 +465,7 @@ func main() { func TestAddPackageSingle_Integration(t *testing.T) { // Set up in-memory node - config, _ := integration.TestingNodeConfig(t, gnoenv.RootDir()) + config := integration.TestingMinimalNodeConfig(gnoenv.RootDir()) node, remoteAddr := integration.TestingInMemoryNode(t, log.NewNoopLogger(), config) defer node.Stop() @@ -519,7 +550,7 @@ func Echo(str string) string { func TestAddPackageMultiple_Integration(t *testing.T) { // Set up in-memory node - config, _ := integration.TestingNodeConfig(t, gnoenv.RootDir()) + config := integration.TestingMinimalNodeConfig(gnoenv.RootDir()) node, remoteAddr := integration.TestingInMemoryNode(t, log.NewNoopLogger(), config) defer node.Stop() @@ -670,3 +701,24 @@ func newInMemorySigner(t *testing.T, chainid string) *SignerFromKeybase { ChainID: chainid, // Chain ID for transaction signing } } + +func loadpkgs(t *testing.T, rootdir string, paths ...string) []gnoland.TxWithMetadata { + t.Helper() + + loader := integration.NewPkgsLoader() + examplesDir := filepath.Join(rootdir, "examples") + for _, path := range paths { + path = filepath.Clean(path) + 
path = filepath.Join(examplesDir, path) + err := loader.LoadPackage(examplesDir, path, "") + require.NoErrorf(t, err, "`loadpkg` unable to load package(s) from %q: %s", path, err) + } + privKey, err := integration.GeneratePrivKeyFromMnemonic(integration.DefaultAccount_Seed, "", 0, 0) + require.NoError(t, err) + + defaultFee := std.NewFee(50000, std.MustParseCoin(ugnot.ValueString(1000000))) + + meta, err := loader.LoadPackages(privKey, defaultFee, nil) + require.NoError(t, err) + return meta +} diff --git a/gno.land/pkg/gnoland/app.go b/gno.land/pkg/gnoland/app.go index 9e8f2163441..80c58e9e982 100644 --- a/gno.land/pkg/gnoland/app.go +++ b/gno.land/pkg/gnoland/app.go @@ -34,11 +34,13 @@ import ( // AppOptions contains the options to create the gno.land ABCI application. type AppOptions struct { - DB dbm.DB // required - Logger *slog.Logger // required - EventSwitch events.EventSwitch // required - VMOutput io.Writer // optional - InitChainerConfig // options related to InitChainer + DB dbm.DB // required + Logger *slog.Logger // required + EventSwitch events.EventSwitch // required + VMOutput io.Writer // optional + SkipGenesisVerification bool // default to verify genesis transactions + InitChainerConfig // options related to InitChainer + MinGasPrices string // optional } // TestAppOptions provides a "ready" default [AppOptions] for use with @@ -53,6 +55,7 @@ func TestAppOptions(db dbm.DB) *AppOptions { StdlibDir: filepath.Join(gnoenv.RootDir(), "gnovm", "stdlibs"), CacheStdlibLoad: true, }, + SkipGenesisVerification: true, } } @@ -79,9 +82,13 @@ func NewAppWithOptions(cfg *AppOptions) (abci.Application, error) { mainKey := store.NewStoreKey("main") baseKey := store.NewStoreKey("base") + // set sdk app options + var appOpts []func(*sdk.BaseApp) + if cfg.MinGasPrices != "" { + appOpts = append(appOpts, sdk.SetMinGasPrices(cfg.MinGasPrices)) + } // Create BaseApp. 
- // TODO: Add a consensus based min gas prices for the node, by default it does not check - baseApp := sdk.NewBaseApp("gnoland", cfg.Logger, cfg.DB, baseKey, mainKey) + baseApp := sdk.NewBaseApp("gnoland", cfg.Logger, cfg.DB, baseKey, mainKey, appOpts...) baseApp.SetAppVersion("dev") // Set mounts for BaseApp's MultiStore. @@ -105,7 +112,7 @@ func NewAppWithOptions(cfg *AppOptions) (abci.Application, error) { // Set AnteHandler authOptions := auth.AnteOptions{ - VerifyGenesisSignatures: false, // for development + VerifyGenesisSignatures: !cfg.SkipGenesisVerification, } authAnteHandler := auth.NewAnteHandler( acctKpr, bankKpr, auth.DefaultSigVerificationGasConsumer, authOptions) @@ -181,6 +188,7 @@ func NewApp( skipFailingGenesisTxs bool, evsw events.EventSwitch, logger *slog.Logger, + minGasPrices string, ) (abci.Application, error) { var err error @@ -191,6 +199,7 @@ func NewApp( GenesisTxResultHandler: PanicOnFailingTxResultHandler, StdlibDir: filepath.Join(gnoenv.RootDir(), "gnovm", "stdlibs"), }, + MinGasPrices: minGasPrices, } if skipFailingGenesisTxs { cfg.GenesisTxResultHandler = NoopGenesisTxResultHandler diff --git a/gno.land/pkg/gnoland/app_test.go b/gno.land/pkg/gnoland/app_test.go index 375602cfa4a..56a15fed5a9 100644 --- a/gno.land/pkg/gnoland/app_test.go +++ b/gno.land/pkg/gnoland/app_test.go @@ -134,7 +134,7 @@ func TestNewApp(t *testing.T) { // NewApp should have good defaults and manage to run InitChain. 
td := t.TempDir() - app, err := NewApp(td, true, events.NewEventSwitch(), log.NewNoopLogger()) + app, err := NewApp(td, true, events.NewEventSwitch(), log.NewNoopLogger(), "") require.NoError(t, err, "NewApp should be successful") resp := app.InitChain(abci.RequestInitChain{ diff --git a/gno.land/pkg/gnoland/node_inmemory.go b/gno.land/pkg/gnoland/node_inmemory.go index f42166411c8..cc9e74a78d8 100644 --- a/gno.land/pkg/gnoland/node_inmemory.go +++ b/gno.land/pkg/gnoland/node_inmemory.go @@ -16,15 +16,16 @@ import ( "github.com/gnolang/gno/tm2/pkg/db" "github.com/gnolang/gno/tm2/pkg/db/memdb" "github.com/gnolang/gno/tm2/pkg/events" - "github.com/gnolang/gno/tm2/pkg/p2p" + "github.com/gnolang/gno/tm2/pkg/p2p/types" ) type InMemoryNodeConfig struct { - PrivValidator bft.PrivValidator // identity of the validator - Genesis *bft.GenesisDoc - TMConfig *tmcfg.Config - DB *memdb.MemDB // will be initialized if nil - VMOutput io.Writer // optional + PrivValidator bft.PrivValidator // identity of the validator + Genesis *bft.GenesisDoc + TMConfig *tmcfg.Config + DB db.DB // will be initialized if nil + VMOutput io.Writer // optional + SkipGenesisVerification bool // If StdlibDir not set, then it's filepath.Join(TMConfig.RootDir, "gnovm", "stdlibs") InitChainerConfig @@ -112,11 +113,12 @@ func NewInMemoryNode(logger *slog.Logger, cfg *InMemoryNodeConfig) (*node.Node, // Initialize the application with the provided options gnoApp, err := NewAppWithOptions(&AppOptions{ - Logger: logger, - DB: cfg.DB, - EventSwitch: evsw, - InitChainerConfig: cfg.InitChainerConfig, - VMOutput: cfg.VMOutput, + Logger: logger, + DB: cfg.DB, + EventSwitch: evsw, + InitChainerConfig: cfg.InitChainerConfig, + VMOutput: cfg.VMOutput, + SkipGenesisVerification: cfg.SkipGenesisVerification, }) if err != nil { return nil, fmt.Errorf("error initializing new app: %w", err) @@ -138,7 +140,7 @@ func NewInMemoryNode(logger *slog.Logger, cfg *InMemoryNodeConfig) (*node.Node, dbProvider := 
func(*node.DBContext) (db.DB, error) { return cfg.DB, nil } // Generate p2p node identity - nodekey := &p2p.NodeKey{PrivKey: ed25519.GenPrivKey()} + nodekey := &types.NodeKey{PrivKey: ed25519.GenPrivKey()} // Create and return the in-memory node instance return node.NewNode(cfg.TMConfig, diff --git a/gno.land/pkg/gnoweb/Makefile b/gno.land/pkg/gnoweb/Makefile index 61397fef54f..39c9d20ab10 100644 --- a/gno.land/pkg/gnoweb/Makefile +++ b/gno.land/pkg/gnoweb/Makefile @@ -13,6 +13,7 @@ input_css := frontend/css/input.css output_css := $(PUBLIC_DIR)/styles.css tw_version := 3.4.14 tw_config_path := frontend/css/tx.config.js +templates_files := $(shell find . -iname '*.gohtml') # static config src_dir_static := frontend/static @@ -42,8 +43,8 @@ all: generate generate: css ts static css: $(output_css) -$(output_css): $(input_css) - npx -y tailwindcss@$(tw_version) -c $(tw_config_path) -i $< -o $@ --minify # tailwind +$(output_css): $(input_css) $(templates_files) + npx -y tailwindcss@$(tw_version) -c $(tw_config_path) -i $(input_css) -o $@ --minify # tailwind touch $@ ts: $(output_js) diff --git a/gno.land/pkg/gnoweb/app.go b/gno.land/pkg/gnoweb/app.go index dc13253468e..1deea86644b 100644 --- a/gno.land/pkg/gnoweb/app.go +++ b/gno.land/pkg/gnoweb/app.go @@ -7,11 +7,12 @@ import ( "path" "strings" + markdown "github.com/yuin/goldmark-highlighting/v2" + "github.com/alecthomas/chroma/v2" chromahtml "github.com/alecthomas/chroma/v2/formatters/html" "github.com/alecthomas/chroma/v2/styles" "github.com/gnolang/gno/gno.land/pkg/gnoweb/components" - "github.com/gnolang/gno/gno.land/pkg/gnoweb/markdown" "github.com/gnolang/gno/tm2/pkg/bft/rpc/client" "github.com/yuin/goldmark" mdhtml "github.com/yuin/goldmark/renderer/html" diff --git a/gno.land/pkg/gnoweb/app_test.go b/gno.land/pkg/gnoweb/app_test.go index 78fe197a134..4fac6e0b971 100644 --- a/gno.land/pkg/gnoweb/app_test.go +++ b/gno.land/pkg/gnoweb/app_test.go @@ -73,6 +73,7 @@ func TestRoutes(t *testing.T) { for _, r := 
range routes { t.Run(fmt.Sprintf("test route %s", r.route), func(t *testing.T) { + t.Logf("input: %q", r.route) request := httptest.NewRequest(http.MethodGet, r.route, nil) response := httptest.NewRecorder() router.ServeHTTP(response, request) @@ -125,7 +126,7 @@ func TestAnalytics(t *testing.T) { request := httptest.NewRequest(http.MethodGet, route, nil) response := httptest.NewRecorder() router.ServeHTTP(response, request) - fmt.Println("HELLO:", response.Body.String()) + assert.Contains(t, response.Body.String(), "sa.gno.services") }) } @@ -143,6 +144,7 @@ func TestAnalytics(t *testing.T) { request := httptest.NewRequest(http.MethodGet, route, nil) response := httptest.NewRecorder() router.ServeHTTP(response, request) + assert.NotContains(t, response.Body.String(), "sa.gno.services") }) } diff --git a/gno.land/pkg/gnoweb/components/breadcrumb.go b/gno.land/pkg/gnoweb/components/breadcrumb.go index 9e7a97b2fae..8eda02a9f4d 100644 --- a/gno.land/pkg/gnoweb/components/breadcrumb.go +++ b/gno.land/pkg/gnoweb/components/breadcrumb.go @@ -6,11 +6,12 @@ import ( type BreadcrumbPart struct { Name string - Path string + URL string } type BreadcrumbData struct { Parts []BreadcrumbPart + Args string } func RenderBreadcrumpComponent(w io.Writer, data BreadcrumbData) error { diff --git a/gno.land/pkg/gnoweb/components/breadcrumb.gohtml b/gno.land/pkg/gnoweb/components/breadcrumb.gohtml index a3301cb037e..9295b17b6f5 100644 --- a/gno.land/pkg/gnoweb/components/breadcrumb.gohtml +++ b/gno.land/pkg/gnoweb/components/breadcrumb.gohtml @@ -1,12 +1,18 @@ {{ define "breadcrumb" }} -
      +
        {{- range $index, $part := .Parts }} {{- if $index }} -
      1. +
      2. {{- else }} -
      3. +
      4. {{- end }} - {{ $part.Name }}
      5. + {{ $part.Name }} + + {{- end }} + {{- if .Args }} +
      6. + {{ .Args }} +
      7. {{- end }}
      -{{ end }} \ No newline at end of file +{{ end }} diff --git a/gno.land/pkg/gnoweb/handler.go b/gno.land/pkg/gnoweb/handler.go index bc87f057e26..20f19e4405a 100644 --- a/gno.land/pkg/gnoweb/handler.go +++ b/gno.land/pkg/gnoweb/handler.go @@ -94,16 +94,16 @@ func (h *WebHandler) Get(w http.ResponseWriter, r *http.Request) { indexData.HeadData.Title = "gno.land - " + gnourl.Path // Header - indexData.HeaderData.RealmPath = gnourl.Path - indexData.HeaderData.Breadcrumb.Parts = generateBreadcrumbPaths(gnourl.Path) + indexData.HeaderData.RealmPath = gnourl.Encode(EncodePath | EncodeArgs | EncodeQuery | EncodeNoEscape) + indexData.HeaderData.Breadcrumb = generateBreadcrumbPaths(gnourl) indexData.HeaderData.WebQuery = gnourl.WebQuery // Render - switch gnourl.Kind() { - case KindRealm, KindPure: + switch { + case gnourl.IsRealm(), gnourl.IsPure(): status, err = h.renderPackage(&body, gnourl) default: - h.logger.Debug("invalid page kind", "kind", gnourl.Kind) + h.logger.Debug("invalid path: path is neither a pure package or a realm") status, err = http.StatusNotFound, components.RenderStatusComponent(&body, "page not found") } } @@ -129,10 +129,8 @@ func (h *WebHandler) Get(w http.ResponseWriter, r *http.Request) { func (h *WebHandler) renderPackage(w io.Writer, gnourl *GnoURL) (status int, err error) { h.logger.Info("component render", "path", gnourl.Path, "args", gnourl.Args) - kind := gnourl.Kind() - // Display realm help page? 
- if kind == KindRealm && gnourl.WebQuery.Has("help") { + if gnourl.WebQuery.Has("help") { return h.renderRealmHelp(w, gnourl) } @@ -140,26 +138,11 @@ func (h *WebHandler) renderPackage(w io.Writer, gnourl *GnoURL) (status int, err switch { case gnourl.WebQuery.Has("source"): return h.renderRealmSource(w, gnourl) - case kind == KindPure, - strings.HasSuffix(gnourl.Path, "/"), - isFile(gnourl.Path): - i := strings.LastIndexByte(gnourl.Path, '/') - if i < 0 { - return http.StatusInternalServerError, fmt.Errorf("unable to get ending slash for %q", gnourl.Path) - } - + case gnourl.IsFile(): // Fill webquery with file infos - gnourl.WebQuery.Set("source", "") // set source - - file := gnourl.Path[i+1:] - if file == "" { - return h.renderRealmDirectory(w, gnourl) - } - - gnourl.WebQuery.Set("file", file) - gnourl.Path = gnourl.Path[:i] - return h.renderRealmSource(w, gnourl) + case gnourl.IsDir(), gnourl.IsPure(): + return h.renderRealmDirectory(w, gnourl) } // Render content into the content buffer @@ -250,12 +233,16 @@ func (h *WebHandler) renderRealmSource(w io.Writer, gnourl *GnoURL) (status int, return http.StatusOK, components.RenderStatusComponent(w, "no files available") } + file := gnourl.WebQuery.Get("file") // webquery override file + if file == "" { + file = gnourl.File + } + var fileName string - file := gnourl.WebQuery.Get("file") if file == "" { - fileName = files[0] + fileName = files[0] // Default to the first file if none specified } else if slices.Contains(files, file) { - fileName = file + fileName = file // Use specified file if it exists } else { h.logger.Error("unable to render source", "file", file, "err", "file does not exist") return http.StatusInternalServerError, components.RenderStatusComponent(w, "internal error") @@ -352,28 +339,25 @@ func (h *WebHandler) highlightSource(fileName string, src []byte) ([]byte, error return buff.Bytes(), nil } -func generateBreadcrumbPaths(path string) []components.BreadcrumbPart { - split := 
strings.Split(path, "/") - parts := []components.BreadcrumbPart{} +func generateBreadcrumbPaths(url *GnoURL) components.BreadcrumbData { + split := strings.Split(url.Path, "/") + var data components.BreadcrumbData var name string for i := range split { if name = split[i]; name == "" { continue } - parts = append(parts, components.BreadcrumbPart{ + data.Parts = append(data.Parts, components.BreadcrumbPart{ Name: name, - Path: strings.Join(split[:i+1], "/"), + URL: strings.Join(split[:i+1], "/"), }) } - return parts -} + if args := url.EncodeArgs(); args != "" { + data.Args = args + } -// IsFile checks if the last element of the path is a file (has an extension) -func isFile(path string) bool { - base := filepath.Base(path) - ext := filepath.Ext(base) - return ext != "" + return data } diff --git a/gno.land/pkg/gnoweb/markdown/highlighting.go b/gno.land/pkg/gnoweb/markdown/highlighting.go deleted file mode 100644 index 51c66674df1..00000000000 --- a/gno.land/pkg/gnoweb/markdown/highlighting.go +++ /dev/null @@ -1,588 +0,0 @@ -// This file was copied from https://github.com/yuin/goldmark-highlighting -// -// package highlighting is an extension for the goldmark(http://github.com/yuin/goldmark). -// -// This extension adds syntax-highlighting to the fenced code blocks using -// chroma(https://github.com/alecthomas/chroma). -package markdown - -import ( - "bytes" - "io" - "strconv" - "strings" - - "github.com/yuin/goldmark" - "github.com/yuin/goldmark/ast" - "github.com/yuin/goldmark/parser" - "github.com/yuin/goldmark/renderer" - "github.com/yuin/goldmark/renderer/html" - "github.com/yuin/goldmark/text" - "github.com/yuin/goldmark/util" - - "github.com/alecthomas/chroma/v2" - chromahtml "github.com/alecthomas/chroma/v2/formatters/html" - "github.com/alecthomas/chroma/v2/lexers" - "github.com/alecthomas/chroma/v2/styles" -) - -// ImmutableAttributes is a read-only interface for ast.Attributes. 
-type ImmutableAttributes interface { - // Get returns (value, true) if an attribute associated with given - // name exists, otherwise (nil, false) - Get(name []byte) (interface{}, bool) - - // GetString returns (value, true) if an attribute associated with given - // name exists, otherwise (nil, false) - GetString(name string) (interface{}, bool) - - // All returns all attributes. - All() []ast.Attribute -} - -type immutableAttributes struct { - n ast.Node -} - -func (a *immutableAttributes) Get(name []byte) (interface{}, bool) { - return a.n.Attribute(name) -} - -func (a *immutableAttributes) GetString(name string) (interface{}, bool) { - return a.n.AttributeString(name) -} - -func (a *immutableAttributes) All() []ast.Attribute { - if a.n.Attributes() == nil { - return []ast.Attribute{} - } - return a.n.Attributes() -} - -// CodeBlockContext holds contextual information of code highlighting. -type CodeBlockContext interface { - // Language returns (language, true) if specified, otherwise (nil, false). - Language() ([]byte, bool) - - // Highlighted returns true if this code block can be highlighted, otherwise false. - Highlighted() bool - - // Attributes return attributes of the code block. - Attributes() ImmutableAttributes -} - -type codeBlockContext struct { - language []byte - highlighted bool - attributes ImmutableAttributes -} - -func newCodeBlockContext(language []byte, highlighted bool, attrs ImmutableAttributes) CodeBlockContext { - return &codeBlockContext{ - language: language, - highlighted: highlighted, - attributes: attrs, - } -} - -func (c *codeBlockContext) Language() ([]byte, bool) { - if c.language != nil { - return c.language, true - } - return nil, false -} - -func (c *codeBlockContext) Highlighted() bool { - return c.highlighted -} - -func (c *codeBlockContext) Attributes() ImmutableAttributes { - return c.attributes -} - -// WrapperRenderer renders wrapper elements like div, pre, etc. 
-type WrapperRenderer func(w util.BufWriter, context CodeBlockContext, entering bool) - -// CodeBlockOptions creates Chroma options per code block. -type CodeBlockOptions func(ctx CodeBlockContext) []chromahtml.Option - -// Config struct holds options for the extension. -type Config struct { - html.Config - - // Style is a highlighting style. - // Supported styles are defined under https://github.com/alecthomas/chroma/tree/master/formatters. - Style string - - // Pass in a custom Chroma style. If this is not nil, the Style string will be ignored - CustomStyle *chroma.Style - - // If set, will try to guess language if none provided. - // If the guessing fails, we will fall back to a text lexer. - // Note that while Chroma's API supports language guessing, the implementation - // is not there yet, so you will currently always get the basic text lexer. - GuessLanguage bool - - // FormatOptions is a option related to output formats. - // See https://github.com/alecthomas/chroma#the-html-formatter for details. - FormatOptions []chromahtml.Option - - // CSSWriter is an io.Writer that will be used as CSS data output buffer. - // If WithClasses() is enabled, you can get CSS data corresponds to the style. - CSSWriter io.Writer - - // CodeBlockOptions allows set Chroma options per code block. - CodeBlockOptions CodeBlockOptions - - // WrapperRenderer allows you to change wrapper elements. - WrapperRenderer WrapperRenderer -} - -// NewConfig returns a new Config with defaults. -func NewConfig() Config { - return Config{ - Config: html.NewConfig(), - Style: "github", - FormatOptions: []chromahtml.Option{}, - CSSWriter: nil, - WrapperRenderer: nil, - CodeBlockOptions: nil, - } -} - -// SetOption implements renderer.SetOptioner. 
-func (c *Config) SetOption(name renderer.OptionName, value interface{}) { - switch name { - case optStyle: - c.Style = value.(string) - case optCustomStyle: - c.CustomStyle = value.(*chroma.Style) - case optFormatOptions: - if value != nil { - c.FormatOptions = value.([]chromahtml.Option) - } - case optCSSWriter: - c.CSSWriter = value.(io.Writer) - case optWrapperRenderer: - c.WrapperRenderer = value.(WrapperRenderer) - case optCodeBlockOptions: - c.CodeBlockOptions = value.(CodeBlockOptions) - case optGuessLanguage: - c.GuessLanguage = value.(bool) - default: - c.Config.SetOption(name, value) - } -} - -// Option interface is a functional option interface for the extension. -type Option interface { - renderer.Option - // SetHighlightingOption sets given option to the extension. - SetHighlightingOption(*Config) -} - -type withHTMLOptions struct { - value []html.Option -} - -func (o *withHTMLOptions) SetConfig(c *renderer.Config) { - if o.value != nil { - for _, v := range o.value { - v.(renderer.Option).SetConfig(c) - } - } -} - -func (o *withHTMLOptions) SetHighlightingOption(c *Config) { - if o.value != nil { - for _, v := range o.value { - v.SetHTMLOption(&c.Config) - } - } -} - -// WithHTMLOptions is functional option that wraps goldmark HTMLRenderer options. 
-func WithHTMLOptions(opts ...html.Option) Option { - return &withHTMLOptions{opts} -} - -const ( - optStyle renderer.OptionName = "HighlightingStyle" - optCustomStyle renderer.OptionName = "HighlightingCustomStyle" -) - -var highlightLinesAttrName = []byte("hl_lines") - -var ( - styleAttrName = []byte("hl_style") - nohlAttrName = []byte("nohl") - linenosAttrName = []byte("linenos") - linenosTableAttrValue = []byte("table") - linenosInlineAttrValue = []byte("inline") - linenostartAttrName = []byte("linenostart") -) - -type withStyle struct { - value string -} - -func (o *withStyle) SetConfig(c *renderer.Config) { - c.Options[optStyle] = o.value -} - -func (o *withStyle) SetHighlightingOption(c *Config) { - c.Style = o.value -} - -// WithStyle is a functional option that changes highlighting style. -func WithStyle(style string) Option { - return &withStyle{style} -} - -type withCustomStyle struct { - value *chroma.Style -} - -func (o *withCustomStyle) SetConfig(c *renderer.Config) { - c.Options[optCustomStyle] = o.value -} - -func (o *withCustomStyle) SetHighlightingOption(c *Config) { - c.CustomStyle = o.value -} - -// WithStyle is a functional option that changes highlighting style. -func WithCustomStyle(style *chroma.Style) Option { - return &withCustomStyle{style} -} - -const optCSSWriter renderer.OptionName = "HighlightingCSSWriter" - -type withCSSWriter struct { - value io.Writer -} - -func (o *withCSSWriter) SetConfig(c *renderer.Config) { - c.Options[optCSSWriter] = o.value -} - -func (o *withCSSWriter) SetHighlightingOption(c *Config) { - c.CSSWriter = o.value -} - -// WithCSSWriter is a functional option that sets io.Writer for CSS data. 
-func WithCSSWriter(w io.Writer) Option { - return &withCSSWriter{w} -} - -const optGuessLanguage renderer.OptionName = "HighlightingGuessLanguage" - -type withGuessLanguage struct { - value bool -} - -func (o *withGuessLanguage) SetConfig(c *renderer.Config) { - c.Options[optGuessLanguage] = o.value -} - -func (o *withGuessLanguage) SetHighlightingOption(c *Config) { - c.GuessLanguage = o.value -} - -// WithGuessLanguage is a functional option that toggles language guessing -// if none provided. -func WithGuessLanguage(b bool) Option { - return &withGuessLanguage{value: b} -} - -const optWrapperRenderer renderer.OptionName = "HighlightingWrapperRenderer" - -type withWrapperRenderer struct { - value WrapperRenderer -} - -func (o *withWrapperRenderer) SetConfig(c *renderer.Config) { - c.Options[optWrapperRenderer] = o.value -} - -func (o *withWrapperRenderer) SetHighlightingOption(c *Config) { - c.WrapperRenderer = o.value -} - -// WithWrapperRenderer is a functional option that sets WrapperRenderer that -// renders wrapper elements like div, pre, etc. -func WithWrapperRenderer(w WrapperRenderer) Option { - return &withWrapperRenderer{w} -} - -const optCodeBlockOptions renderer.OptionName = "HighlightingCodeBlockOptions" - -type withCodeBlockOptions struct { - value CodeBlockOptions -} - -func (o *withCodeBlockOptions) SetConfig(c *renderer.Config) { - c.Options[optCodeBlockOptions] = o.value -} - -func (o *withCodeBlockOptions) SetHighlightingOption(c *Config) { - c.CodeBlockOptions = o.value -} - -// WithCodeBlockOptions is a functional option that sets CodeBlockOptions that -// allows setting Chroma options per code block. 
-func WithCodeBlockOptions(c CodeBlockOptions) Option { - return &withCodeBlockOptions{value: c} -} - -const optFormatOptions renderer.OptionName = "HighlightingFormatOptions" - -type withFormatOptions struct { - value []chromahtml.Option -} - -func (o *withFormatOptions) SetConfig(c *renderer.Config) { - if _, ok := c.Options[optFormatOptions]; !ok { - c.Options[optFormatOptions] = []chromahtml.Option{} - } - c.Options[optFormatOptions] = append(c.Options[optFormatOptions].([]chromahtml.Option), o.value...) -} - -func (o *withFormatOptions) SetHighlightingOption(c *Config) { - c.FormatOptions = append(c.FormatOptions, o.value...) -} - -// WithFormatOptions is a functional option that wraps chroma HTML formatter options. -func WithFormatOptions(opts ...chromahtml.Option) Option { - return &withFormatOptions{opts} -} - -// HTMLRenderer struct is a renderer.NodeRenderer implementation for the extension. -type HTMLRenderer struct { - Config -} - -// NewHTMLRenderer builds a new HTMLRenderer with given options and returns it. -func NewHTMLRenderer(opts ...Option) renderer.NodeRenderer { - r := &HTMLRenderer{ - Config: NewConfig(), - } - for _, opt := range opts { - opt.SetHighlightingOption(&r.Config) - } - return r -} - -// RegisterFuncs implements NodeRenderer.RegisterFuncs. 
-func (r *HTMLRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) { - reg.Register(ast.KindFencedCodeBlock, r.renderFencedCodeBlock) -} - -func getAttributes(node *ast.FencedCodeBlock, infostr []byte) ImmutableAttributes { - if node.Attributes() != nil { - return &immutableAttributes{node} - } - if infostr != nil { - attrStartIdx := -1 - - for idx, char := range infostr { - if char == '{' { - attrStartIdx = idx - break - } - } - if attrStartIdx > 0 { - n := ast.NewTextBlock() // dummy node for storing attributes - attrStr := infostr[attrStartIdx:] - if attrs, hasAttr := parser.ParseAttributes(text.NewReader(attrStr)); hasAttr { - for _, attr := range attrs { - n.SetAttribute(attr.Name, attr.Value) - } - return &immutableAttributes{n} - } - } - } - return nil -} - -func (r *HTMLRenderer) renderFencedCodeBlock(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) { - n := node.(*ast.FencedCodeBlock) - if !entering { - return ast.WalkContinue, nil - } - language := n.Language(source) - - chromaFormatterOptions := make([]chromahtml.Option, 0, len(r.FormatOptions)) - for _, opt := range r.FormatOptions { - chromaFormatterOptions = append(chromaFormatterOptions, opt) - } - - style := r.CustomStyle - if style == nil { - style = styles.Get(r.Style) - } - nohl := false - - var info []byte - if n.Info != nil { - info = n.Info.Segment.Value(source) - } - attrs := getAttributes(n, info) - if attrs != nil { - baseLineNumber := 1 - if linenostartAttr, ok := attrs.Get(linenostartAttrName); ok { - if linenostart, ok := linenostartAttr.(float64); ok { - baseLineNumber = int(linenostart) - chromaFormatterOptions = append( - chromaFormatterOptions, chromahtml.BaseLineNumber(baseLineNumber), - ) - } - } - if linesAttr, hasLinesAttr := attrs.Get(highlightLinesAttrName); hasLinesAttr { - if lines, ok := linesAttr.([]interface{}); ok { - var hlRanges [][2]int - for _, l := range lines { - if ln, ok := l.(float64); ok { - hlRanges = 
append(hlRanges, [2]int{int(ln) + baseLineNumber - 1, int(ln) + baseLineNumber - 1}) - } - if rng, ok := l.([]uint8); ok { - slices := strings.Split(string(rng), "-") - lhs, err := strconv.Atoi(slices[0]) - if err != nil { - continue - } - rhs := lhs - if len(slices) > 1 { - rhs, err = strconv.Atoi(slices[1]) - if err != nil { - continue - } - } - hlRanges = append(hlRanges, [2]int{lhs + baseLineNumber - 1, rhs + baseLineNumber - 1}) - } - } - chromaFormatterOptions = append(chromaFormatterOptions, chromahtml.HighlightLines(hlRanges)) - } - } - if styleAttr, hasStyleAttr := attrs.Get(styleAttrName); hasStyleAttr { - if st, ok := styleAttr.([]uint8); ok { - styleStr := string(st) - style = styles.Get(styleStr) - } - } - if _, hasNohlAttr := attrs.Get(nohlAttrName); hasNohlAttr { - nohl = true - } - - if linenosAttr, ok := attrs.Get(linenosAttrName); ok { - switch v := linenosAttr.(type) { - case bool: - chromaFormatterOptions = append(chromaFormatterOptions, chromahtml.WithLineNumbers(v)) - case []uint8: - if v != nil { - chromaFormatterOptions = append(chromaFormatterOptions, chromahtml.WithLineNumbers(true)) - } - if bytes.Equal(v, linenosTableAttrValue) { - chromaFormatterOptions = append(chromaFormatterOptions, chromahtml.LineNumbersInTable(true)) - } else if bytes.Equal(v, linenosInlineAttrValue) { - chromaFormatterOptions = append(chromaFormatterOptions, chromahtml.LineNumbersInTable(false)) - } - } - } - } - - var lexer chroma.Lexer - if language != nil { - lexer = lexers.Get(string(language)) - } - if !nohl && (lexer != nil || r.GuessLanguage) { - if style == nil { - style = styles.Fallback - } - var buffer bytes.Buffer - l := n.Lines().Len() - for i := 0; i < l; i++ { - line := n.Lines().At(i) - buffer.Write(line.Value(source)) - } - - if lexer == nil { - lexer = lexers.Analyse(buffer.String()) - if lexer == nil { - lexer = lexers.Fallback - } - language = []byte(strings.ToLower(lexer.Config().Name)) - } - lexer = chroma.Coalesce(lexer) - - iterator, err := 
lexer.Tokenise(nil, buffer.String()) - if err == nil { - c := newCodeBlockContext(language, true, attrs) - - if r.CodeBlockOptions != nil { - chromaFormatterOptions = append(chromaFormatterOptions, r.CodeBlockOptions(c)...) - } - formatter := chromahtml.New(chromaFormatterOptions...) - if r.WrapperRenderer != nil { - r.WrapperRenderer(w, c, true) - } - _ = formatter.Format(w, style, iterator) == nil - if r.WrapperRenderer != nil { - r.WrapperRenderer(w, c, false) - } - if r.CSSWriter != nil { - _ = formatter.WriteCSS(r.CSSWriter, style) - } - return ast.WalkContinue, nil - } - } - - var c CodeBlockContext - if r.WrapperRenderer != nil { - c = newCodeBlockContext(language, false, attrs) - r.WrapperRenderer(w, c, true) - } else { - _, _ = w.WriteString("
      ')
      -	}
      -	l := n.Lines().Len()
      -	for i := 0; i < l; i++ {
      -		line := n.Lines().At(i)
      -		r.Writer.RawWrite(w, line.Value(source))
      -	}
      -	if r.WrapperRenderer != nil {
      -		r.WrapperRenderer(w, c, false)
      -	} else {
      -		_, _ = w.WriteString("
      \n") - } - return ast.WalkContinue, nil -} - -type highlighting struct { - options []Option -} - -// Highlighting is a goldmark.Extender implementation. -var Highlighting = &highlighting{ - options: []Option{}, -} - -// NewHighlighting returns a new extension with given options. -func NewHighlighting(opts ...Option) goldmark.Extender { - return &highlighting{ - options: opts, - } -} - -// Extend implements goldmark.Extender. -func (e *highlighting) Extend(m goldmark.Markdown) { - m.Renderer().AddOptions(renderer.WithNodeRenderers( - util.Prioritized(NewHTMLRenderer(e.options...), 200), - )) -} diff --git a/gno.land/pkg/gnoweb/markdown/highlighting_test.go b/gno.land/pkg/gnoweb/markdown/highlighting_test.go deleted file mode 100644 index 25bc4fedd61..00000000000 --- a/gno.land/pkg/gnoweb/markdown/highlighting_test.go +++ /dev/null @@ -1,568 +0,0 @@ -// This file was copied from https://github.com/yuin/goldmark-highlighting - -package markdown - -import ( - "bytes" - "fmt" - "strings" - "testing" - - "github.com/alecthomas/chroma/v2" - chromahtml "github.com/alecthomas/chroma/v2/formatters/html" - "github.com/yuin/goldmark" - "github.com/yuin/goldmark/testutil" - "github.com/yuin/goldmark/util" -) - -func TestHighlighting(t *testing.T) { - var css bytes.Buffer - markdown := goldmark.New( - goldmark.WithExtensions( - NewHighlighting( - WithStyle("monokai"), - WithCSSWriter(&css), - WithFormatOptions( - chromahtml.WithClasses(true), - chromahtml.WithLineNumbers(false), - ), - WithWrapperRenderer(func(w util.BufWriter, c CodeBlockContext, entering bool) { - _, ok := c.Language() - if entering { - if !ok { - w.WriteString("
      ")
      -							return
      -						}
      -						w.WriteString(`
      `) - } else { - if !ok { - w.WriteString("
      ") - return - } - w.WriteString(`
`) - } - }), - WithCodeBlockOptions(func(c CodeBlockContext) []chromahtml.Option { - if language, ok := c.Language(); ok { - // Turn on line numbers for Go only. - if string(language) == "go" { - return []chromahtml.Option{ - chromahtml.WithLineNumbers(true), - } - } - } - return nil - }), - ), - ), - ) - var buffer bytes.Buffer - if err := markdown.Convert([]byte(` -Title -======= -`+"``` go\n"+`func main() { - fmt.Println("ok") -} -`+"```"+` -`), &buffer); err != nil { - t.Fatal(err) - } - - if strings.TrimSpace(buffer.String()) != strings.TrimSpace(` -

Title

-
1func main() {
-2    fmt.Println("ok")
-3}
-
-`) { - t.Errorf("failed to render HTML\n%s", buffer.String()) - } - - expected := strings.TrimSpace(`/* Background */ .bg { color: #f8f8f2; background-color: #272822; } -/* PreWrapper */ .chroma { color: #f8f8f2; background-color: #272822; } -/* LineNumbers targeted by URL anchor */ .chroma .ln:target { color: #f8f8f2; background-color: #3c3d38 } -/* LineNumbersTable targeted by URL anchor */ .chroma .lnt:target { color: #f8f8f2; background-color: #3c3d38 } -/* Error */ .chroma .err { color: #960050; background-color: #1e0010 } -/* LineLink */ .chroma .lnlinks { outline: none; text-decoration: none; color: inherit } -/* LineTableTD */ .chroma .lntd { vertical-align: top; padding: 0; margin: 0; border: 0; } -/* LineTable */ .chroma .lntable { border-spacing: 0; padding: 0; margin: 0; border: 0; } -/* LineHighlight */ .chroma .hl { background-color: #3c3d38 } -/* LineNumbersTable */ .chroma .lnt { white-space: pre; -webkit-user-select: none; user-select: none; margin-right: 0.4em; padding: 0 0.4em 0 0.4em;color: #7f7f7f } -/* LineNumbers */ .chroma .ln { white-space: pre; -webkit-user-select: none; user-select: none; margin-right: 0.4em; padding: 0 0.4em 0 0.4em;color: #7f7f7f } -/* Line */ .chroma .line { display: flex; } -/* Keyword */ .chroma .k { color: #66d9ef } -/* KeywordConstant */ .chroma .kc { color: #66d9ef } -/* KeywordDeclaration */ .chroma .kd { color: #66d9ef } -/* KeywordNamespace */ .chroma .kn { color: #f92672 } -/* KeywordPseudo */ .chroma .kp { color: #66d9ef } -/* KeywordReserved */ .chroma .kr { color: #66d9ef } -/* KeywordType */ .chroma .kt { color: #66d9ef } -/* NameAttribute */ .chroma .na { color: #a6e22e } -/* NameClass */ .chroma .nc { color: #a6e22e } -/* NameConstant */ .chroma .no { color: #66d9ef } -/* NameDecorator */ .chroma .nd { color: #a6e22e } -/* NameException */ .chroma .ne { color: #a6e22e } -/* NameFunction */ .chroma .nf { color: #a6e22e } -/* NameOther */ .chroma .nx { color: #a6e22e } -/* NameTag */ .chroma .nt { color: 
#f92672 } -/* Literal */ .chroma .l { color: #ae81ff } -/* LiteralDate */ .chroma .ld { color: #e6db74 } -/* LiteralString */ .chroma .s { color: #e6db74 } -/* LiteralStringAffix */ .chroma .sa { color: #e6db74 } -/* LiteralStringBacktick */ .chroma .sb { color: #e6db74 } -/* LiteralStringChar */ .chroma .sc { color: #e6db74 } -/* LiteralStringDelimiter */ .chroma .dl { color: #e6db74 } -/* LiteralStringDoc */ .chroma .sd { color: #e6db74 } -/* LiteralStringDouble */ .chroma .s2 { color: #e6db74 } -/* LiteralStringEscape */ .chroma .se { color: #ae81ff } -/* LiteralStringHeredoc */ .chroma .sh { color: #e6db74 } -/* LiteralStringInterpol */ .chroma .si { color: #e6db74 } -/* LiteralStringOther */ .chroma .sx { color: #e6db74 } -/* LiteralStringRegex */ .chroma .sr { color: #e6db74 } -/* LiteralStringSingle */ .chroma .s1 { color: #e6db74 } -/* LiteralStringSymbol */ .chroma .ss { color: #e6db74 } -/* LiteralNumber */ .chroma .m { color: #ae81ff } -/* LiteralNumberBin */ .chroma .mb { color: #ae81ff } -/* LiteralNumberFloat */ .chroma .mf { color: #ae81ff } -/* LiteralNumberHex */ .chroma .mh { color: #ae81ff } -/* LiteralNumberInteger */ .chroma .mi { color: #ae81ff } -/* LiteralNumberIntegerLong */ .chroma .il { color: #ae81ff } -/* LiteralNumberOct */ .chroma .mo { color: #ae81ff } -/* Operator */ .chroma .o { color: #f92672 } -/* OperatorWord */ .chroma .ow { color: #f92672 } -/* Comment */ .chroma .c { color: #75715e } -/* CommentHashbang */ .chroma .ch { color: #75715e } -/* CommentMultiline */ .chroma .cm { color: #75715e } -/* CommentSingle */ .chroma .c1 { color: #75715e } -/* CommentSpecial */ .chroma .cs { color: #75715e } -/* CommentPreproc */ .chroma .cp { color: #75715e } -/* CommentPreprocFile */ .chroma .cpf { color: #75715e } -/* GenericDeleted */ .chroma .gd { color: #f92672 } -/* GenericEmph */ .chroma .ge { font-style: italic } -/* GenericInserted */ .chroma .gi { color: #a6e22e } -/* GenericStrong */ .chroma .gs { font-weight: bold } -/* 
GenericSubheading */ .chroma .gu { color: #75715e }`) - - gotten := strings.TrimSpace(css.String()) - - if expected != gotten { - diff := testutil.DiffPretty([]byte(expected), []byte(gotten)) - t.Errorf("incorrect CSS.\n%s", string(diff)) - } -} - -func TestHighlighting2(t *testing.T) { - markdown := goldmark.New( - goldmark.WithExtensions( - Highlighting, - ), - ) - var buffer bytes.Buffer - if err := markdown.Convert([]byte(` -Title -======= -`+"```"+` -func main() { - fmt.Println("ok") -} -`+"```"+` -`), &buffer); err != nil { - t.Fatal(err) - } - - if strings.TrimSpace(buffer.String()) != strings.TrimSpace(` -

Title

-
func main() {
-    fmt.Println("ok")
-}
-
-`) { - t.Error("failed to render HTML") - } -} - -func TestHighlighting3(t *testing.T) { - markdown := goldmark.New( - goldmark.WithExtensions( - Highlighting, - ), - ) - var buffer bytes.Buffer - if err := markdown.Convert([]byte(` -Title -======= - -`+"```"+`cpp {hl_lines=[1,2]} -#include -int main() { - std::cout<< "hello" << std::endl; -} -`+"```"+` -`), &buffer); err != nil { - t.Fatal(err) - } - if strings.TrimSpace(buffer.String()) != strings.TrimSpace(` -

Title

-
#include <iostream>
-int main() {
-    std::cout<< "hello" << std::endl;
-}
-
-`) { - t.Errorf("failed to render HTML:\n%s", buffer.String()) - } -} - -func TestHighlightingCustom(t *testing.T) { - custom := chroma.MustNewStyle("custom", chroma.StyleEntries{ - chroma.Background: "#cccccc bg:#1d1d1d", - chroma.Comment: "#999999", - chroma.CommentSpecial: "#cd0000", - chroma.Keyword: "#cc99cd", - chroma.KeywordDeclaration: "#cc99cd", - chroma.KeywordNamespace: "#cc99cd", - chroma.KeywordType: "#cc99cd", - chroma.Operator: "#67cdcc", - chroma.OperatorWord: "#cdcd00", - chroma.NameClass: "#f08d49", - chroma.NameBuiltin: "#f08d49", - chroma.NameFunction: "#f08d49", - chroma.NameException: "bold #666699", - chroma.NameVariable: "#00cdcd", - chroma.LiteralString: "#7ec699", - chroma.LiteralNumber: "#f08d49", - chroma.LiteralStringBoolean: "#f08d49", - chroma.GenericHeading: "bold #000080", - chroma.GenericSubheading: "bold #800080", - chroma.GenericDeleted: "#e2777a", - chroma.GenericInserted: "#cc99cd", - chroma.GenericError: "#e2777a", - chroma.GenericEmph: "italic", - chroma.GenericStrong: "bold", - chroma.GenericPrompt: "bold #000080", - chroma.GenericOutput: "#888", - chroma.GenericTraceback: "#04D", - chroma.GenericUnderline: "underline", - chroma.Error: "border:#e2777a", - }) - - var css bytes.Buffer - markdown := goldmark.New( - goldmark.WithExtensions( - NewHighlighting( - WithStyle("monokai"), // to make sure it is overrided even if present - WithCustomStyle(custom), - WithCSSWriter(&css), - WithFormatOptions( - chromahtml.WithClasses(true), - chromahtml.WithLineNumbers(false), - ), - WithWrapperRenderer(func(w util.BufWriter, c CodeBlockContext, entering bool) { - _, ok := c.Language() - if entering { - if !ok { - w.WriteString("
")
-							return
-						}
-						w.WriteString(`
`) - } else { - if !ok { - w.WriteString("
") - return - } - w.WriteString(``) - } - }), - WithCodeBlockOptions(func(c CodeBlockContext) []chromahtml.Option { - if language, ok := c.Language(); ok { - // Turn on line numbers for Go only. - if string(language) == "go" { - return []chromahtml.Option{ - chromahtml.WithLineNumbers(true), - } - } - } - return nil - }), - ), - ), - ) - var buffer bytes.Buffer - if err := markdown.Convert([]byte(` -Title -======= -`+"``` go\n"+`func main() { - fmt.Println("ok") -} -`+"```"+` -`), &buffer); err != nil { - t.Fatal(err) - } - - if strings.TrimSpace(buffer.String()) != strings.TrimSpace(` -

Title

-
1func main() {
-2    fmt.Println("ok")
-3}
-
-`) { - t.Error("failed to render HTML", buffer.String()) - } - - expected := strings.TrimSpace(`/* Background */ .bg { color: #cccccc; background-color: #1d1d1d; } -/* PreWrapper */ .chroma { color: #cccccc; background-color: #1d1d1d; } -/* LineNumbers targeted by URL anchor */ .chroma .ln:target { color: #cccccc; background-color: #333333 } -/* LineNumbersTable targeted by URL anchor */ .chroma .lnt:target { color: #cccccc; background-color: #333333 } -/* Error */ .chroma .err { } -/* LineLink */ .chroma .lnlinks { outline: none; text-decoration: none; color: inherit } -/* LineTableTD */ .chroma .lntd { vertical-align: top; padding: 0; margin: 0; border: 0; } -/* LineTable */ .chroma .lntable { border-spacing: 0; padding: 0; margin: 0; border: 0; } -/* LineHighlight */ .chroma .hl { background-color: #333333 } -/* LineNumbersTable */ .chroma .lnt { white-space: pre; -webkit-user-select: none; user-select: none; margin-right: 0.4em; padding: 0 0.4em 0 0.4em;color: #666666 } -/* LineNumbers */ .chroma .ln { white-space: pre; -webkit-user-select: none; user-select: none; margin-right: 0.4em; padding: 0 0.4em 0 0.4em;color: #666666 } -/* Line */ .chroma .line { display: flex; } -/* Keyword */ .chroma .k { color: #cc99cd } -/* KeywordConstant */ .chroma .kc { color: #cc99cd } -/* KeywordDeclaration */ .chroma .kd { color: #cc99cd } -/* KeywordNamespace */ .chroma .kn { color: #cc99cd } -/* KeywordPseudo */ .chroma .kp { color: #cc99cd } -/* KeywordReserved */ .chroma .kr { color: #cc99cd } -/* KeywordType */ .chroma .kt { color: #cc99cd } -/* NameBuiltin */ .chroma .nb { color: #f08d49 } -/* NameClass */ .chroma .nc { color: #f08d49 } -/* NameException */ .chroma .ne { color: #666699; font-weight: bold } -/* NameFunction */ .chroma .nf { color: #f08d49 } -/* NameVariable */ .chroma .nv { color: #00cdcd } -/* LiteralString */ .chroma .s { color: #7ec699 } -/* LiteralStringAffix */ .chroma .sa { color: #7ec699 } -/* LiteralStringBacktick */ .chroma .sb { color: #7ec699 
} -/* LiteralStringChar */ .chroma .sc { color: #7ec699 } -/* LiteralStringDelimiter */ .chroma .dl { color: #7ec699 } -/* LiteralStringDoc */ .chroma .sd { color: #7ec699 } -/* LiteralStringDouble */ .chroma .s2 { color: #7ec699 } -/* LiteralStringEscape */ .chroma .se { color: #7ec699 } -/* LiteralStringHeredoc */ .chroma .sh { color: #7ec699 } -/* LiteralStringInterpol */ .chroma .si { color: #7ec699 } -/* LiteralStringOther */ .chroma .sx { color: #7ec699 } -/* LiteralStringRegex */ .chroma .sr { color: #7ec699 } -/* LiteralStringSingle */ .chroma .s1 { color: #7ec699 } -/* LiteralStringSymbol */ .chroma .ss { color: #7ec699 } -/* LiteralNumber */ .chroma .m { color: #f08d49 } -/* LiteralNumberBin */ .chroma .mb { color: #f08d49 } -/* LiteralNumberFloat */ .chroma .mf { color: #f08d49 } -/* LiteralNumberHex */ .chroma .mh { color: #f08d49 } -/* LiteralNumberInteger */ .chroma .mi { color: #f08d49 } -/* LiteralNumberIntegerLong */ .chroma .il { color: #f08d49 } -/* LiteralNumberOct */ .chroma .mo { color: #f08d49 } -/* Operator */ .chroma .o { color: #67cdcc } -/* OperatorWord */ .chroma .ow { color: #cdcd00 } -/* Comment */ .chroma .c { color: #999999 } -/* CommentHashbang */ .chroma .ch { color: #999999 } -/* CommentMultiline */ .chroma .cm { color: #999999 } -/* CommentSingle */ .chroma .c1 { color: #999999 } -/* CommentSpecial */ .chroma .cs { color: #cd0000 } -/* CommentPreproc */ .chroma .cp { color: #999999 } -/* CommentPreprocFile */ .chroma .cpf { color: #999999 } -/* GenericDeleted */ .chroma .gd { color: #e2777a } -/* GenericEmph */ .chroma .ge { font-style: italic } -/* GenericError */ .chroma .gr { color: #e2777a } -/* GenericHeading */ .chroma .gh { color: #000080; font-weight: bold } -/* GenericInserted */ .chroma .gi { color: #cc99cd } -/* GenericOutput */ .chroma .go { color: #888888 } -/* GenericPrompt */ .chroma .gp { color: #000080; font-weight: bold } -/* GenericStrong */ .chroma .gs { font-weight: bold } -/* GenericSubheading */ .chroma .gu 
{ color: #800080; font-weight: bold } -/* GenericTraceback */ .chroma .gt { color: #0044dd } -/* GenericUnderline */ .chroma .gl { text-decoration: underline }`) - - gotten := strings.TrimSpace(css.String()) - - if expected != gotten { - diff := testutil.DiffPretty([]byte(expected), []byte(gotten)) - t.Errorf("incorrect CSS.\n%s", string(diff)) - } -} - -func TestHighlightingHlLines(t *testing.T) { - markdown := goldmark.New( - goldmark.WithExtensions( - NewHighlighting( - WithFormatOptions( - chromahtml.WithClasses(true), - ), - ), - ), - ) - - for i, test := range []struct { - attributes string - expect []int - }{ - {`hl_lines=["2"]`, []int{2}}, - {`hl_lines=["2-3",5],linenostart=5`, []int{2, 3, 5}}, - {`hl_lines=["2-3"]`, []int{2, 3}}, - {`hl_lines=["2-3",5],linenostart="5"`, []int{2, 3}}, // linenostart must be a number. string values are ignored - } { - t.Run(fmt.Sprint(i), func(t *testing.T) { - var buffer bytes.Buffer - codeBlock := fmt.Sprintf(`bash {%s} -LINE1 -LINE2 -LINE3 -LINE4 -LINE5 -LINE6 -LINE7 -LINE8 -`, test.attributes) - - if err := markdown.Convert([]byte(` -`+"```"+codeBlock+"```"+` -`), &buffer); err != nil { - t.Fatal(err) - } - - for _, line := range test.expect { - expectStr := fmt.Sprintf("LINE%d\n", line) - if !strings.Contains(buffer.String(), expectStr) { - t.Fatal("got\n", buffer.String(), "\nexpected\n", expectStr) - } - } - }) - } -} - -type nopPreWrapper struct{} - -// Start is called to write a start
 element.
-func (nopPreWrapper) Start(code bool, styleAttr string) string { return "" }
-
-// End is called to write the end 
element. -func (nopPreWrapper) End(code bool) string { return "" } - -func TestHighlightingLinenos(t *testing.T) { - outputLineNumbersInTable := `
- -
-1 - -LINE1 -
-
` - - for i, test := range []struct { - attributes string - lineNumbers bool - lineNumbersInTable bool - expect string - }{ - {`linenos=true`, false, false, `1LINE1 -`}, - {`linenos=false`, false, false, `LINE1 -`}, - {``, true, false, `1LINE1 -`}, - {``, true, true, outputLineNumbersInTable}, - {`linenos=inline`, true, true, `1LINE1 -`}, - {`linenos=foo`, false, false, `1LINE1 -`}, - {`linenos=table`, false, false, outputLineNumbersInTable}, - } { - t.Run(fmt.Sprint(i), func(t *testing.T) { - markdown := goldmark.New( - goldmark.WithExtensions( - NewHighlighting( - WithFormatOptions( - chromahtml.WithLineNumbers(test.lineNumbers), - chromahtml.LineNumbersInTable(test.lineNumbersInTable), - chromahtml.WithPreWrapper(nopPreWrapper{}), - chromahtml.WithClasses(true), - ), - ), - ), - ) - - var buffer bytes.Buffer - codeBlock := fmt.Sprintf(`bash {%s} -LINE1 -`, test.attributes) - - content := "```" + codeBlock + "```" - - if err := markdown.Convert([]byte(content), &buffer); err != nil { - t.Fatal(err) - } - - s := strings.TrimSpace(buffer.String()) - - if s != test.expect { - t.Fatal("got\n", s, "\nexpected\n", test.expect) - } - }) - } -} - -func TestHighlightingGuessLanguage(t *testing.T) { - markdown := goldmark.New( - goldmark.WithExtensions( - NewHighlighting( - WithGuessLanguage(true), - WithFormatOptions( - chromahtml.WithClasses(true), - chromahtml.WithLineNumbers(true), - ), - ), - ), - ) - var buffer bytes.Buffer - if err := markdown.Convert([]byte("```"+` -LINE -`+"```"), &buffer); err != nil { - t.Fatal(err) - } - if strings.TrimSpace(buffer.String()) != strings.TrimSpace(` -
1LINE
-
-`) { - t.Errorf("render mismatch, got\n%s", buffer.String()) - } -} - -func TestCoalesceNeeded(t *testing.T) { - markdown := goldmark.New( - goldmark.WithExtensions( - NewHighlighting( - // WithGuessLanguage(true), - WithFormatOptions( - chromahtml.WithClasses(true), - chromahtml.WithLineNumbers(true), - ), - ), - ), - ) - var buffer bytes.Buffer - if err := markdown.Convert([]byte("```http"+` -GET /foo HTTP/1.1 -Content-Type: application/json -User-Agent: foo - -{ - "hello": "world" -} -`+"```"), &buffer); err != nil { - t.Fatal(err) - } - if strings.TrimSpace(buffer.String()) != strings.TrimSpace(` -
1GET /foo HTTP/1.1
-2Content-Type: application/json
-3User-Agent: foo
-4
-5{
-6  "hello": "world"
-7}
-
-`) { - t.Errorf("render mismatch, got\n%s", buffer.String()) - } -} diff --git a/gno.land/pkg/gnoweb/public/styles.css b/gno.land/pkg/gnoweb/public/styles.css index a1d7860c63e..a6771695454 100644 --- a/gno.land/pkg/gnoweb/public/styles.css +++ b/gno.land/pkg/gnoweb/public/styles.css @@ -1,3 +1,3 @@ @font-face{font-family:Roboto;font-style:normal;font-weight:900;font-display:swap;src:url(fonts/roboto/roboto-mono-normal.woff2) format("woff2"),url(fonts/roboto/roboto-mono-normal.woff) format("woff")}@font-face{font-family:Inter var;font-weight:100 900;font-display:block;font-style:oblique 0deg 10deg;src:url(fonts/intervar/Intervar.woff2) format("woff2")}*,:after,:before{--tw-border-spacing-x:0;--tw-border-spacing-y:0;--tw-translate-x:0;--tw-translate-y:0;--tw-rotate:0;--tw-skew-x:0;--tw-skew-y:0;--tw-scale-x:1;--tw-scale-y:1;--tw-pan-x: ;--tw-pan-y: ;--tw-pinch-zoom: ;--tw-scroll-snap-strictness:proximity;--tw-gradient-from-position: ;--tw-gradient-via-position: ;--tw-gradient-to-position: ;--tw-ordinal: ;--tw-slashed-zero: ;--tw-numeric-figure: ;--tw-numeric-spacing: ;--tw-numeric-fraction: ;--tw-ring-inset: ;--tw-ring-offset-width:0px;--tw-ring-offset-color:#fff;--tw-ring-color:rgba(59,130,246,.5);--tw-ring-offset-shadow:0 0 #0000;--tw-ring-shadow:0 0 #0000;--tw-shadow:0 0 #0000;--tw-shadow-colored:0 0 #0000;--tw-blur: ;--tw-brightness: ;--tw-contrast: ;--tw-grayscale: ;--tw-hue-rotate: ;--tw-invert: ;--tw-saturate: ;--tw-sepia: ;--tw-drop-shadow: ;--tw-backdrop-blur: ;--tw-backdrop-brightness: ;--tw-backdrop-contrast: ;--tw-backdrop-grayscale: ;--tw-backdrop-hue-rotate: ;--tw-backdrop-invert: ;--tw-backdrop-opacity: ;--tw-backdrop-saturate: ;--tw-backdrop-sepia: ;--tw-contain-size: ;--tw-contain-layout: ;--tw-contain-paint: ;--tw-contain-style: }::backdrop{--tw-border-spacing-x:0;--tw-border-spacing-y:0;--tw-translate-x:0;--tw-translate-y:0;--tw-rotate:0;--tw-skew-x:0;--tw-skew-y:0;--tw-scale-x:1;--tw-scale-y:1;--tw-pan-x: ;--tw-pan-y: ;--tw-pinch-zoom: 
;--tw-scroll-snap-strictness:proximity;--tw-gradient-from-position: ;--tw-gradient-via-position: ;--tw-gradient-to-position: ;--tw-ordinal: ;--tw-slashed-zero: ;--tw-numeric-figure: ;--tw-numeric-spacing: ;--tw-numeric-fraction: ;--tw-ring-inset: ;--tw-ring-offset-width:0px;--tw-ring-offset-color:#fff;--tw-ring-color:rgba(59,130,246,.5);--tw-ring-offset-shadow:0 0 #0000;--tw-ring-shadow:0 0 #0000;--tw-shadow:0 0 #0000;--tw-shadow-colored:0 0 #0000;--tw-blur: ;--tw-brightness: ;--tw-contrast: ;--tw-grayscale: ;--tw-hue-rotate: ;--tw-invert: ;--tw-saturate: ;--tw-sepia: ;--tw-drop-shadow: ;--tw-backdrop-blur: ;--tw-backdrop-brightness: ;--tw-backdrop-contrast: ;--tw-backdrop-grayscale: ;--tw-backdrop-hue-rotate: ;--tw-backdrop-invert: ;--tw-backdrop-opacity: ;--tw-backdrop-saturate: ;--tw-backdrop-sepia: ;--tw-contain-size: ;--tw-contain-layout: ;--tw-contain-paint: ;--tw-contain-style: } -/*! tailwindcss v3.4.14 | MIT License | https://tailwindcss.com*/*,:after,:before{box-sizing:border-box;border:0 solid #bdbdbd}:after,:before{--tw-content:""}:host,html{line-height:1.5;-webkit-text-size-adjust:100%;-moz-tab-size:4;-o-tab-size:4;tab-size:4;font-family:ui-sans-serif,system-ui,sans-serif,Apple Color Emoji,Segoe UI Emoji,Segoe UI Symbol,Noto Color Emoji;font-feature-settings:normal;font-variation-settings:normal;-webkit-tap-highlight-color:transparent}body{margin:0;line-height:inherit}hr{height:0;color:inherit;border-top-width:1px}abbr:where([title]){-webkit-text-decoration:underline dotted;text-decoration:underline dotted}h1,h2,h3,h4,h5,h6{font-size:inherit;font-weight:inherit}a{color:inherit;text-decoration:inherit}b,strong{font-weight:bolder}code,kbd,pre,samp{font-family:Roboto,Menlo,Consolas,Ubuntu Mono,Roboto Mono,DejaVu Sans 
Mono,monospace;;font-feature-settings:normal;font-variation-settings:normal;font-size:1em}small{font-size:80%}sub,sup{font-size:75%;line-height:0;position:relative;vertical-align:baseline}sub{bottom:-.25em}sup{top:-.5em}table{text-indent:0;border-color:inherit;border-collapse:collapse}button,input,optgroup,select,textarea{font-family:inherit;font-feature-settings:inherit;font-variation-settings:inherit;font-size:100%;font-weight:inherit;line-height:inherit;letter-spacing:inherit;color:inherit;margin:0;padding:0}button,select{text-transform:none}button,input:where([type=button]),input:where([type=reset]),input:where([type=submit]){-webkit-appearance:button;background-color:transparent;background-image:none}:-moz-focusring{outline:auto}:-moz-ui-invalid{box-shadow:none}progress{vertical-align:baseline}::-webkit-inner-spin-button,::-webkit-outer-spin-button{height:auto}[type=search]{-webkit-appearance:textfield;outline-offset:-2px}::-webkit-search-decoration{-webkit-appearance:none}::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}summary{display:list-item}blockquote,dd,dl,figure,h1,h2,h3,h4,h5,h6,hr,p,pre{margin:0}fieldset{margin:0}fieldset,legend{padding:0}menu,ol,ul{list-style:none;margin:0;padding:0}dialog{padding:0}textarea{resize:vertical}input::-moz-placeholder,textarea::-moz-placeholder{opacity:1;color:#7c7c7c}input::placeholder,textarea::placeholder{opacity:1;color:#7c7c7c}[role=button],button{cursor:pointer}:disabled{cursor:default}audio,canvas,embed,iframe,img,object,svg,video{display:block;vertical-align:middle}img,video{max-width:100%;height:auto}[hidden]:where(:not([hidden=until-found])){display:none}html{--tw-bg-opacity:1;background-color:rgb(255 255 255/var(--tw-bg-opacity));font-family:Inter var,-apple-system,BlinkMacSystemFont,Segoe UI,Roboto,Helvetica Neue,Arial,Noto Sans,Apple Color Emoji,Segoe UI Emoji,Segoe UI Symbol,Noto Color Emoji,sans-serif;font-size:1rem;--tw-text-opacity:1;color:rgb(84 89 
93/var(--tw-text-opacity));font-feature-settings:"kern" on,"liga" on,"calt" on,"zero" on;-webkit-font-feature-settings:"kern" on,"liga" on,"calt" on,"zero" on;-webkit-text-size-adjust:100%;-moz-text-size-adjust:100%;text-size-adjust:100%;-moz-osx-font-smoothing:grayscale;font-smoothing:antialiased;font-variant-ligatures:contextual common-ligatures;font-kerning:normal;text-rendering:optimizeLegibility}svg{max-height:100%;max-width:100%}form{margin-top:0;margin-bottom:0}.realm-content{overflow-wrap:break-word;padding-top:2.5rem;font-size:1rem}.realm-content>:first-child{margin-top:0!important}.realm-content a{font-weight:500;--tw-text-opacity:1;color:rgb(34 108 87/var(--tw-text-opacity))}.realm-content a:hover{text-decoration-line:underline}.realm-content h1,.realm-content h2,.realm-content h3,.realm-content h4{margin-top:3rem;line-height:1.25;--tw-text-opacity:1;color:rgb(8 8 9/var(--tw-text-opacity))}.realm-content h2,.realm-content h2 *{font-weight:700}.realm-content h3,.realm-content h3 *,.realm-content h4,.realm-content h4 *{font-weight:600}.realm-content h1+h2,.realm-content h2+h3,.realm-content h3+h4{margin-top:1rem}.realm-content h1{font-size:2.375rem;font-weight:700}.realm-content h2{font-size:1.5rem}.realm-content h3{margin-top:2.5rem;font-size:1.25rem}.realm-content h3,.realm-content h4{--tw-text-opacity:1;color:rgb(84 89 93/var(--tw-text-opacity))}.realm-content h4{margin-top:1.5rem;margin-bottom:1.5rem;font-size:1.125rem;font-weight:500}.realm-content p{margin-top:1.25rem;margin-bottom:1.25rem}.realm-content strong{font-weight:700;--tw-text-opacity:1;color:rgb(8 8 9/var(--tw-text-opacity))}.realm-content strong *{font-weight:700}.realm-content em{font-style:oblique 10deg}.realm-content blockquote{margin-top:1rem;margin-bottom:1rem;border-left-width:4px;--tw-border-opacity:1;border-color:rgb(153 153 153/var(--tw-border-opacity));padding-left:1rem;--tw-text-opacity:1;color:rgb(84 89 93/var(--tw-text-opacity));font-style:oblique 10deg}.realm-content 
ol,.realm-content ul{margin-top:1.5rem;margin-bottom:1.5rem;padding-left:1rem}.realm-content ol li,.realm-content ul li{margin-bottom:.5rem}.realm-content img{margin-top:2rem;margin-bottom:2rem;max-width:100%}.realm-content figure{margin-top:1.5rem;margin-bottom:1.5rem;text-align:center}.realm-content figcaption{font-size:.875rem;--tw-text-opacity:1;color:rgb(84 89 93/var(--tw-text-opacity))}.realm-content :not(pre)>code{border-radius:.25rem;background-color:rgb(226 226 226/var(--tw-bg-opacity));padding:.125rem .25rem;font-size:.96em}.realm-content :not(pre)>code,.realm-content pre{--tw-bg-opacity:1;font-family:Roboto,Menlo,Consolas,Ubuntu Mono,Roboto Mono,DejaVu Sans Mono,monospace;}.realm-content pre{overflow-x:auto;border-radius:.375rem;background-color:rgb(240 240 240/var(--tw-bg-opacity));padding:1rem}.realm-content hr{margin-top:2.5rem;margin-bottom:2.5rem;border-top-width:1px;--tw-border-opacity:1;border-color:rgb(226 226 226/var(--tw-border-opacity))}.realm-content table{margin-top:2rem;margin-bottom:2rem;width:100%;border-collapse:collapse}.realm-content td,.realm-content th{border-width:1px;--tw-border-opacity:1;border-color:rgb(153 153 153/var(--tw-border-opacity));padding:.5rem 1rem}.realm-content th{--tw-bg-opacity:1;background-color:rgb(226 226 226/var(--tw-bg-opacity));font-weight:700}.realm-content caption{margin-top:.5rem;text-align:left;font-size:.875rem;--tw-text-opacity:1;color:rgb(84 89 93/var(--tw-text-opacity))}.realm-content q{margin-top:1.5rem;margin-bottom:1.5rem;border-left-width:4px;--tw-border-opacity:1;border-left-color:rgb(204 204 204/var(--tw-border-opacity));padding-left:1rem;--tw-text-opacity:1;color:rgb(85 85 85/var(--tw-text-opacity));font-style:oblique 10deg;quotes:"“" "”" "‘" "’"}.realm-content q:after,.realm-content q:before{margin-right:.25rem;font-size:1.5rem;--tw-text-opacity:1;color:rgb(153 153 153/var(--tw-text-opacity));content:open-quote;vertical-align:-.4rem}.realm-content q:after{content:close-quote}.realm-content 
q:before{content:open-quote}.realm-content q:after{content:close-quote}.realm-content ol ol,.realm-content ol ul,.realm-content ul ol,.realm-content ul ul{margin-top:.75rem;margin-bottom:.5rem;padding-left:1rem}.realm-content ul{list-style-type:disc}.realm-content ol{list-style-type:decimal}.realm-content table th:first-child,.realm-content td:first-child{padding-left:0}.realm-content table th:last-child,.realm-content td:last-child{padding-right:0}.realm-content abbr[title]{cursor:help;border-bottom-width:1px;border-style:dotted}.realm-content details{margin-top:1.25rem;margin-bottom:1.25rem}.realm-content summary{cursor:pointer;font-weight:700}.realm-content a code{color:inherit}.realm-content video{margin-top:2rem;margin-bottom:2rem;max-width:100%}.realm-content math{font-family:Roboto,Menlo,Consolas,Ubuntu Mono,Roboto Mono,DejaVu Sans Mono,monospace;}.realm-content small{font-size:.875rem}.realm-content del{text-decoration-line:line-through}.realm-content sub{vertical-align:sub;font-size:.75rem}.realm-content sup{vertical-align:super;font-size:.75rem}.realm-content button,.realm-content input{border-width:1px;--tw-border-opacity:1;border-color:rgb(153 153 153/var(--tw-border-opacity));padding:.5rem 1rem}main :is(h1,h2,h3,h4){scroll-margin-top:6rem}::-moz-selection{--tw-bg-opacity:1;background-color:rgb(34 108 87/var(--tw-bg-opacity));--tw-text-opacity:1;color:rgb(255 255 255/var(--tw-text-opacity))}::selection{--tw-bg-opacity:1;background-color:rgb(34 108 87/var(--tw-bg-opacity));--tw-text-opacity:1;color:rgb(255 255 255/var(--tw-text-opacity))}.sidemenu .peer:checked+label>svg{--tw-text-opacity:1;color:rgb(34 108 87/var(--tw-text-opacity))}.toc-expend-btn:has(#toc-expend:checked)+nav{display:block}.toc-expend-btn:has(#toc-expend:checked) .toc-expend-btn_ico{--tw-rotate:180deg;transform:translate(var(--tw-translate-x),var(--tw-translate-y)) rotate(var(--tw-rotate)) skewX(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) 
scaleY(var(--tw-scale-y))}.main-header:has(#sidemenu-docs:checked)+main #sidebar #sidebar-docs,.main-header:has(#sidemenu-meta:checked)+main #sidebar #sidebar-meta,.main-header:has(#sidemenu-source:checked)+main #sidebar #sidebar-source,.main-header:has(#sidemenu-summary:checked)+main #sidebar #sidebar-summary{display:block}@media (min-width:40rem){:is(.main-header:has(#sidemenu-source:checked),.main-header:has(#sidemenu-docs:checked),.main-header:has(#sidemenu-meta:checked)) .main-navigation,:is(.main-header:has(#sidemenu-source:checked),.main-header:has(#sidemenu-docs:checked),.main-header:has(#sidemenu-meta:checked))+main .realm-content{grid-column:span 6/span 6}:is(.main-header:has(#sidemenu-source:checked),.main-header:has(#sidemenu-docs:checked),.main-header:has(#sidemenu-meta:checked)) .sidemenu,:is(.main-header:has(#sidemenu-source:checked),.main-header:has(#sidemenu-docs:checked),.main-header:has(#sidemenu-meta:checked))+main #sidebar{grid-column:span 4/span 4}}:is(.main-header:has(#sidemenu-source:checked),.main-header:has(#sidemenu-docs:checked),.main-header:has(#sidemenu-meta:checked))+main #sidebar:before{position:absolute;top:0;left:-1.75rem;z-index:-1;display:block;height:100%;width:50vw;--tw-bg-opacity:1;background-color:rgb(226 226 226/var(--tw-bg-opacity));--tw-content:"";content:var(--tw-content)}main :is(.source-code)>pre{overflow:scroll;border-radius:.375rem;--tw-bg-opacity:1!important;background-color:rgb(255 255 255/var(--tw-bg-opacity))!important;padding:1rem .25rem;font-family:Roboto,Menlo,Consolas,Ubuntu Mono,Roboto Mono,DejaVu Sans Mono,monospace;;font-size:.875rem}@media (min-width:40rem){main :is(.source-code)>pre{padding:2rem .75rem;font-size:1rem}}main .realm-content>pre a:hover{text-decoration-line:none}main :is(.realm-content,.source-code)>pre .chroma-ln:target{background-color:transparent!important}main :is(.realm-content,.source-code)>pre .chroma-line:has(.chroma-ln:target),main :is(.realm-content,.source-code)>pre 
.chroma-line:has(.chroma-ln:target) .chroma-cl,main :is(.realm-content,.source-code)>pre .chroma-line:has(.chroma-lnlinks:hover),main :is(.realm-content,.source-code)>pre .chroma-line:has(.chroma-lnlinks:hover) .chroma-cl{border-radius:.375rem;--tw-bg-opacity:1!important;background-color:rgb(226 226 226/var(--tw-bg-opacity))!important}main :is(.realm-content,.source-code)>pre .chroma-ln{scroll-margin-top:6rem}.absolute{position:absolute}.relative{position:relative}.sticky{position:sticky}.bottom-1{bottom:.25rem}.left-0{left:0}.right-2{right:.5rem}.right-3{right:.75rem}.top-0{top:0}.top-1\/2{top:50%}.top-14{top:3.5rem}.top-2{top:.5rem}.z-max{z-index:9999}.col-span-1{grid-column:span 1/span 1}.col-span-10{grid-column:span 10/span 10}.col-span-3{grid-column:span 3/span 3}.col-span-7{grid-column:span 7/span 7}.row-span-1{grid-row:span 1/span 1}.row-start-1{grid-row-start:1}.mx-auto{margin-left:auto;margin-right:auto}.mb-1{margin-bottom:.25rem}.mb-2{margin-bottom:.5rem}.mb-3{margin-bottom:.75rem}.mb-4{margin-bottom:1rem}.mb-8{margin-bottom:2rem}.mr-10{margin-right:2.5rem}.mt-1{margin-top:.25rem}.mt-10{margin-top:2.5rem}.mt-2{margin-top:.5rem}.mt-4{margin-top:1rem}.mt-6{margin-top:1.5rem}.mt-8{margin-top:2rem}.line-clamp-2{overflow:hidden;display:-webkit-box;-webkit-box-orient:vertical;-webkit-line-clamp:2}.block{display:block}.inline-block{display:inline-block}.inline{display:inline}.flex{display:flex}.grid{display:grid}.hidden{display:none}.h-10{height:2.5rem}.h-4{height:1rem}.h-5{height:1.25rem}.h-6{height:1.5rem}.h-full{height:100%}.max-h-screen{max-height:100vh}.min-h-full{min-height:100%}.min-h-screen{min-height:100vh}.w-10{width:2.5rem}.w-4{width:1rem}.w-5{width:1.25rem}.w-full{width:100%}.min-w-2{min-width:.5rem}.min-w-48{min-width:12rem}.max-w-screen-max{max-width:98.75rem}.shrink-0{flex-shrink:0}.grow-\[2\]{flex-grow:2}.-translate-y-1\/2{--tw-translate-y:-50%;transform:translate(var(--tw-translate-x),var(--tw-translate-y)) rotate(var(--tw-rotate)) 
skewX(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) scaleY(var(--tw-scale-y))}.cursor-pointer{cursor:pointer}.list-none{list-style-type:none}.appearance-none{-webkit-appearance:none;-moz-appearance:none;appearance:none}.grid-flow-dense{grid-auto-flow:dense}.auto-rows-min{grid-auto-rows:min-content}.grid-cols-1{grid-template-columns:repeat(1,minmax(0,1fr))}.grid-cols-10{grid-template-columns:repeat(10,minmax(0,1fr))}.flex-col{flex-direction:column}.items-start{align-items:flex-start}.items-center{align-items:center}.items-stretch{align-items:stretch}.justify-end{justify-content:flex-end}.justify-center{justify-content:center}.justify-between{justify-content:space-between}.gap-0\.5{gap:.125rem}.gap-1{gap:.25rem}.gap-1\.5{gap:.375rem}.gap-2{gap:.5rem}.gap-3{gap:.75rem}.gap-4{gap:1rem}.gap-8{gap:2rem}.gap-x-20{-moz-column-gap:5rem;column-gap:5rem}.gap-x-3{-moz-column-gap:.75rem;column-gap:.75rem}.gap-y-2{row-gap:.5rem}.space-y-2>:not([hidden])~:not([hidden]){--tw-space-y-reverse:0;margin-top:calc(.5rem*(1 - var(--tw-space-y-reverse)));margin-bottom:calc(.5rem*var(--tw-space-y-reverse))}.overflow-hidden{overflow:hidden}.overflow-scroll{overflow:scroll}.whitespace-pre-wrap{white-space:pre-wrap}.rounded{border-radius:.375rem}.rounded-sm{border-radius:.25rem}.border{border-width:1px}.border-b{border-bottom-width:1px}.border-l{border-left-width:1px}.border-t{border-top-width:1px}.border-gray-100{--tw-border-opacity:1;border-color:rgb(226 226 226/var(--tw-border-opacity))}.bg-gray-100{--tw-bg-opacity:1;background-color:rgb(226 226 226/var(--tw-bg-opacity))}.bg-gray-50{--tw-bg-opacity:1;background-color:rgb(240 240 240/var(--tw-bg-opacity))}.bg-light{--tw-bg-opacity:1;background-color:rgb(255 255 
255/var(--tw-bg-opacity))}.bg-transparent{background-color:transparent}.p-1\.5{padding:.375rem}.p-2{padding:.5rem}.p-4{padding:1rem}.px-1{padding-left:.25rem;padding-right:.25rem}.px-10{padding-left:2.5rem;padding-right:2.5rem}.px-2{padding-left:.5rem;padding-right:.5rem}.px-3{padding-left:.75rem;padding-right:.75rem}.px-4{padding-left:1rem;padding-right:1rem}.py-1{padding-top:.25rem;padding-bottom:.25rem}.py-1\.5{padding-top:.375rem;padding-bottom:.375rem}.py-2{padding-top:.5rem;padding-bottom:.5rem}.py-px{padding-top:1px;padding-bottom:1px}.pb-24{padding-bottom:6rem}.pb-3{padding-bottom:.75rem}.pb-4{padding-bottom:1rem}.pb-6{padding-bottom:1.5rem}.pb-8{padding-bottom:2rem}.pl-4{padding-left:1rem}.pr-10{padding-right:2.5rem}.pt-0\.5{padding-top:.125rem}.pt-2{padding-top:.5rem}.font-mono{font-family:Roboto,Menlo,Consolas,Ubuntu Mono,Roboto Mono,DejaVu Sans Mono,monospace;}.text-100{font-size:.875rem}.text-200{font-size:1rem}.text-50{font-size:.75rem}.text-600{font-size:1.5rem}.font-bold{font-weight:700}.font-medium{font-weight:500}.font-normal{font-weight:400}.font-semibold{font-weight:600}.capitalize{text-transform:capitalize}.leading-tight{line-height:1.25}.text-gray-300{--tw-text-opacity:1;color:rgb(153 153 153/var(--tw-text-opacity))}.text-gray-400{--tw-text-opacity:1;color:rgb(124 124 124/var(--tw-text-opacity))}.text-gray-600{--tw-text-opacity:1;color:rgb(84 89 93/var(--tw-text-opacity))}.text-gray-800{--tw-text-opacity:1;color:rgb(19 19 19/var(--tw-text-opacity))}.text-gray-900{--tw-text-opacity:1;color:rgb(8 8 9/var(--tw-text-opacity))}.text-green-600{--tw-text-opacity:1;color:rgb(34 108 87/var(--tw-text-opacity))}.outline-none{outline:2px solid 
transparent;outline-offset:2px}.text-stroke{-webkit-text-stroke:currentColor;-webkit-text-stroke-width:.6px}.no-scrollbar::-webkit-scrollbar{display:none}.no-scrollbar{-ms-overflow-style:none;scrollbar-width:none}.\*\:pl-0>*{padding-left:0}.before\:px-\[0\.18rem\]:before{content:var(--tw-content);padding-left:.18rem;padding-right:.18rem}.before\:text-gray-300:before{content:var(--tw-content);--tw-text-opacity:1;color:rgb(153 153 153/var(--tw-text-opacity))}.before\:content-\[\'\/\'\]:before{--tw-content:"/";content:var(--tw-content)}.before\:content-\[\'open\'\]:before{--tw-content:"open";content:var(--tw-content)}.after\:absolute:after{content:var(--tw-content);position:absolute}.after\:bottom-0:after{content:var(--tw-content);bottom:0}.after\:left-0:after{content:var(--tw-content);left:0}.after\:block:after{content:var(--tw-content);display:block}.after\:h-1:after{content:var(--tw-content);height:.25rem}.after\:w-full:after{content:var(--tw-content);width:100%}.after\:rounded-t-sm:after{content:var(--tw-content);border-top-left-radius:.25rem;border-top-right-radius:.25rem}.after\:bg-green-600:after{content:var(--tw-content);--tw-bg-opacity:1;background-color:rgb(34 108 87/var(--tw-bg-opacity))}.first\:border-t:first-child{border-top-width:1px}.hover\:border-gray-300:hover{--tw-border-opacity:1;border-color:rgb(153 153 153/var(--tw-border-opacity))}.hover\:bg-gray-100:hover{--tw-bg-opacity:1;background-color:rgb(226 226 226/var(--tw-bg-opacity))}.hover\:bg-gray-50:hover{--tw-bg-opacity:1;background-color:rgb(240 240 240/var(--tw-bg-opacity))}.hover\:bg-green-600:hover{--tw-bg-opacity:1;background-color:rgb(34 108 87/var(--tw-bg-opacity))}.hover\:text-gray-600:hover{--tw-text-opacity:1;color:rgb(84 89 93/var(--tw-text-opacity))}.hover\:text-green-600:hover{--tw-text-opacity:1;color:rgb(34 108 87/var(--tw-text-opacity))}.hover\:text-light:hover{--tw-text-opacity:1;color:rgb(255 255 
255/var(--tw-text-opacity))}.hover\:underline:hover{text-decoration-line:underline}.focus\:border-gray-300:focus{--tw-border-opacity:1;border-color:rgb(153 153 153/var(--tw-border-opacity))}.focus\:border-l-gray-300:focus{--tw-border-opacity:1;border-left-color:rgb(153 153 153/var(--tw-border-opacity))}.group:hover .group-hover\:border-gray-300{--tw-border-opacity:1;border-color:rgb(153 153 153/var(--tw-border-opacity))}.group:hover .group-hover\:border-l-gray-300{--tw-border-opacity:1;border-left-color:rgb(153 153 153/var(--tw-border-opacity))}.group.is-active .group-\[\.is-active\]\:text-green-600{--tw-text-opacity:1;color:rgb(34 108 87/var(--tw-text-opacity))}.peer:checked~.peer-checked\:before\:content-\[\'close\'\]:before{--tw-content:"close";content:var(--tw-content)}.peer:focus-within~.peer-focus-within\:hidden{display:none}.has-\[ul\:empty\]\:hidden:has(ul:empty){display:none}.has-\[\:focus-within\]\:border-gray-300:has(:focus-within){--tw-border-opacity:1;border-color:rgb(153 153 153/var(--tw-border-opacity))}.has-\[\:focus\]\:border-gray-300:has(:focus){--tw-border-opacity:1;border-color:rgb(153 153 153/var(--tw-border-opacity))}@media (min-width:30rem){.sm\:gap-6{gap:1.5rem}}@media (min-width:40rem){.md\:col-span-3{grid-column:span 3/span 3}.md\:mb-0{margin-bottom:0}.md\:h-4{height:1rem}.md\:grid-cols-4{grid-template-columns:repeat(4,minmax(0,1fr))}.md\:flex-row{flex-direction:row}.md\:items-center{align-items:center}.md\:gap-x-8{-moz-column-gap:2rem;column-gap:2rem}.md\:px-10{padding-left:2.5rem;padding-right:2.5rem}.md\:pb-0{padding-bottom:0}}@media (min-width:51.25rem){.lg\:order-2{order:2}.lg\:col-span-3{grid-column:span 3/span 3}.lg\:col-span-7{grid-column:span 7/span 7}.lg\:row-span-2{grid-row:span 2/span 
2}.lg\:row-start-1{grid-row-start:1}.lg\:row-start-2{grid-row-start:2}.lg\:mb-4{margin-bottom:1rem}.lg\:mt-0{margin-top:0}.lg\:mt-10{margin-top:2.5rem}.lg\:block{display:block}.lg\:hidden{display:none}.lg\:grid-cols-10{grid-template-columns:repeat(10,minmax(0,1fr))}.lg\:flex-row{flex-direction:row}.lg\:justify-start{justify-content:flex-start}.lg\:justify-between{justify-content:space-between}.lg\:gap-x-20{-moz-column-gap:5rem;column-gap:5rem}.lg\:border-none{border-style:none}.lg\:bg-transparent{background-color:transparent}.lg\:p-0{padding:0}.lg\:px-0{padding-left:0;padding-right:0}.lg\:px-2{padding-left:.5rem;padding-right:.5rem}.lg\:py-1\.5{padding-top:.375rem;padding-bottom:.375rem}.lg\:pb-28{padding-bottom:7rem}.lg\:pt-2{padding-top:.5rem}.lg\:text-200{font-size:1rem}.lg\:font-semibold{font-weight:600}.lg\:hover\:bg-transparent:hover{background-color:transparent}}@media (min-width:63.75rem){.xl\:inline{display:inline}.xl\:hidden{display:none}.xl\:grid-cols-10{grid-template-columns:repeat(10,minmax(0,1fr))}.xl\:flex-row{flex-direction:row}.xl\:items-center{align-items:center}.xl\:gap-20{gap:5rem}.xl\:gap-6{gap:1.5rem}.xl\:pt-0{padding-top:0}}@media (min-width:85.375rem){.xxl\:inline-block{display:inline-block}.xxl\:h-4{height:1rem}.xxl\:w-4{width:1rem}.xxl\:gap-20{gap:5rem}.xxl\:gap-x-32{-moz-column-gap:8rem;column-gap:8rem}.xxl\:pr-1{padding-right:.25rem}} \ No newline at end of file +/*! 
tailwindcss v3.4.14 | MIT License | https://tailwindcss.com*/*,:after,:before{box-sizing:border-box;border:0 solid #bdbdbd}:after,:before{--tw-content:""}:host,html{line-height:1.5;-webkit-text-size-adjust:100%;-moz-tab-size:4;-o-tab-size:4;tab-size:4;font-family:ui-sans-serif,system-ui,sans-serif,Apple Color Emoji,Segoe UI Emoji,Segoe UI Symbol,Noto Color Emoji;font-feature-settings:normal;font-variation-settings:normal;-webkit-tap-highlight-color:transparent}body{margin:0;line-height:inherit}hr{height:0;color:inherit;border-top-width:1px}abbr:where([title]){-webkit-text-decoration:underline dotted;text-decoration:underline dotted}h1,h2,h3,h4,h5,h6{font-size:inherit;font-weight:inherit}a{color:inherit;text-decoration:inherit}b,strong{font-weight:bolder}code,kbd,pre,samp{font-family:Roboto,Menlo,Consolas,Ubuntu Mono,Roboto Mono,DejaVu Sans Mono,monospace;;font-feature-settings:normal;font-variation-settings:normal;font-size:1em}small{font-size:80%}sub,sup{font-size:75%;line-height:0;position:relative;vertical-align:baseline}sub{bottom:-.25em}sup{top:-.5em}table{text-indent:0;border-color:inherit;border-collapse:collapse}button,input,optgroup,select,textarea{font-family:inherit;font-feature-settings:inherit;font-variation-settings:inherit;font-size:100%;font-weight:inherit;line-height:inherit;letter-spacing:inherit;color:inherit;margin:0;padding:0}button,select{text-transform:none}button,input:where([type=button]),input:where([type=reset]),input:where([type=submit]){-webkit-appearance:button;background-color:transparent;background-image:none}:-moz-focusring{outline:auto}:-moz-ui-invalid{box-shadow:none}progress{vertical-align:baseline}::-webkit-inner-spin-button,::-webkit-outer-spin-button{height:auto}[type=search]{-webkit-appearance:textfield;outline-offset:-2px}::-webkit-search-decoration{-webkit-appearance:none}::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}summary{display:list-item}blockquote,dd,dl,figure,h1,h2,h3,h4,h5,h6,hr,p,pre{margin:0}f
ieldset{margin:0}fieldset,legend{padding:0}menu,ol,ul{list-style:none;margin:0;padding:0}dialog{padding:0}textarea{resize:vertical}input::-moz-placeholder,textarea::-moz-placeholder{opacity:1;color:#7c7c7c}input::placeholder,textarea::placeholder{opacity:1;color:#7c7c7c}[role=button],button{cursor:pointer}:disabled{cursor:default}audio,canvas,embed,iframe,img,object,svg,video{display:block;vertical-align:middle}img,video{max-width:100%;height:auto}[hidden]:where(:not([hidden=until-found])){display:none}html{--tw-bg-opacity:1;background-color:rgb(255 255 255/var(--tw-bg-opacity));font-family:Inter var,-apple-system,BlinkMacSystemFont,Segoe UI,Roboto,Helvetica Neue,Arial,Noto Sans,Apple Color Emoji,Segoe UI Emoji,Segoe UI Symbol,Noto Color Emoji,sans-serif;font-size:1rem;--tw-text-opacity:1;color:rgb(84 89 93/var(--tw-text-opacity));font-feature-settings:"kern" on,"liga" on,"calt" on,"zero" on;-webkit-font-feature-settings:"kern" on,"liga" on,"calt" on,"zero" on;-webkit-text-size-adjust:100%;-moz-text-size-adjust:100%;text-size-adjust:100%;-moz-osx-font-smoothing:grayscale;font-smoothing:antialiased;font-variant-ligatures:contextual common-ligatures;font-kerning:normal;text-rendering:optimizeLegibility}svg{max-height:100%;max-width:100%}form{margin-top:0;margin-bottom:0}.realm-content{overflow-wrap:break-word;padding-top:2.5rem;font-size:1rem}.realm-content>:first-child{margin-top:0!important}.realm-content a{font-weight:500;--tw-text-opacity:1;color:rgb(34 108 87/var(--tw-text-opacity))}.realm-content a:hover{text-decoration-line:underline}.realm-content h1,.realm-content h2,.realm-content h3,.realm-content h4{margin-top:3rem;line-height:1.25;--tw-text-opacity:1;color:rgb(8 8 9/var(--tw-text-opacity))}.realm-content h2,.realm-content h2 *{font-weight:700}.realm-content h3,.realm-content h3 *,.realm-content h4,.realm-content h4 *{font-weight:600}.realm-content h1+h2,.realm-content h2+h3,.realm-content h3+h4{margin-top:1rem}.realm-content 
h1{font-size:2.375rem;font-weight:700}.realm-content h2{font-size:1.5rem}.realm-content h3{margin-top:2.5rem;font-size:1.25rem}.realm-content h3,.realm-content h4{--tw-text-opacity:1;color:rgb(84 89 93/var(--tw-text-opacity))}.realm-content h4{margin-top:1.5rem;margin-bottom:1.5rem;font-size:1.125rem;font-weight:500}.realm-content p{margin-top:1.25rem;margin-bottom:1.25rem}.realm-content strong{font-weight:700;--tw-text-opacity:1;color:rgb(8 8 9/var(--tw-text-opacity))}.realm-content strong *{font-weight:700}.realm-content em{font-style:oblique 10deg}.realm-content blockquote{margin-top:1rem;margin-bottom:1rem;border-left-width:4px;--tw-border-opacity:1;border-color:rgb(153 153 153/var(--tw-border-opacity));padding-left:1rem;--tw-text-opacity:1;color:rgb(84 89 93/var(--tw-text-opacity));font-style:oblique 10deg}.realm-content ol,.realm-content ul{margin-top:1.5rem;margin-bottom:1.5rem;padding-left:1rem}.realm-content ol li,.realm-content ul li{margin-bottom:.5rem}.realm-content img{margin-top:2rem;margin-bottom:2rem;max-width:100%}.realm-content figure{margin-top:1.5rem;margin-bottom:1.5rem;text-align:center}.realm-content figcaption{font-size:.875rem;--tw-text-opacity:1;color:rgb(84 89 93/var(--tw-text-opacity))}.realm-content :not(pre)>code{border-radius:.25rem;background-color:rgb(226 226 226/var(--tw-bg-opacity));padding:.125rem .25rem;font-size:.96em}.realm-content :not(pre)>code,.realm-content pre{--tw-bg-opacity:1;font-family:Roboto,Menlo,Consolas,Ubuntu Mono,Roboto Mono,DejaVu Sans Mono,monospace;}.realm-content pre{overflow-x:auto;border-radius:.375rem;background-color:rgb(240 240 240/var(--tw-bg-opacity));padding:1rem}.realm-content hr{margin-top:2.5rem;margin-bottom:2.5rem;border-top-width:1px;--tw-border-opacity:1;border-color:rgb(226 226 226/var(--tw-border-opacity))}.realm-content table{margin-top:2rem;margin-bottom:2rem;width:100%;border-collapse:collapse}.realm-content td,.realm-content th{border-width:1px;--tw-border-opacity:1;border-color:rgb(153 
153 153/var(--tw-border-opacity));padding:.5rem 1rem}.realm-content th{--tw-bg-opacity:1;background-color:rgb(226 226 226/var(--tw-bg-opacity));font-weight:700}.realm-content caption{margin-top:.5rem;text-align:left;font-size:.875rem;--tw-text-opacity:1;color:rgb(84 89 93/var(--tw-text-opacity))}.realm-content q{margin-top:1.5rem;margin-bottom:1.5rem;border-left-width:4px;--tw-border-opacity:1;border-left-color:rgb(204 204 204/var(--tw-border-opacity));padding-left:1rem;--tw-text-opacity:1;color:rgb(85 85 85/var(--tw-text-opacity));font-style:oblique 10deg;quotes:"“" "”" "‘" "’"}.realm-content q:after,.realm-content q:before{margin-right:.25rem;font-size:1.5rem;--tw-text-opacity:1;color:rgb(153 153 153/var(--tw-text-opacity));content:open-quote;vertical-align:-.4rem}.realm-content q:after{content:close-quote}.realm-content q:before{content:open-quote}.realm-content q:after{content:close-quote}.realm-content ol ol,.realm-content ol ul,.realm-content ul ol,.realm-content ul ul{margin-top:.75rem;margin-bottom:.5rem;padding-left:1rem}.realm-content ul{list-style-type:disc}.realm-content ol{list-style-type:decimal}.realm-content table th:first-child,.realm-content td:first-child{padding-left:0}.realm-content table th:last-child,.realm-content td:last-child{padding-right:0}.realm-content abbr[title]{cursor:help;border-bottom-width:1px;border-style:dotted}.realm-content details{margin-top:1.25rem;margin-bottom:1.25rem}.realm-content summary{cursor:pointer;font-weight:700}.realm-content a code{color:inherit}.realm-content video{margin-top:2rem;margin-bottom:2rem;max-width:100%}.realm-content math{font-family:Roboto,Menlo,Consolas,Ubuntu Mono,Roboto Mono,DejaVu Sans Mono,monospace;}.realm-content small{font-size:.875rem}.realm-content del{text-decoration-line:line-through}.realm-content sub{vertical-align:sub;font-size:.75rem}.realm-content sup{vertical-align:super;font-size:.75rem}.realm-content button,.realm-content 
input{border-width:1px;--tw-border-opacity:1;border-color:rgb(153 153 153/var(--tw-border-opacity));padding:.5rem 1rem}main :is(h1,h2,h3,h4){scroll-margin-top:6rem}::-moz-selection{--tw-bg-opacity:1;background-color:rgb(34 108 87/var(--tw-bg-opacity));--tw-text-opacity:1;color:rgb(255 255 255/var(--tw-text-opacity))}::selection{--tw-bg-opacity:1;background-color:rgb(34 108 87/var(--tw-bg-opacity));--tw-text-opacity:1;color:rgb(255 255 255/var(--tw-text-opacity))}.sidemenu .peer:checked+label>svg{--tw-text-opacity:1;color:rgb(34 108 87/var(--tw-text-opacity))}.toc-expend-btn:has(#toc-expend:checked)+nav{display:block}.toc-expend-btn:has(#toc-expend:checked) .toc-expend-btn_ico{--tw-rotate:180deg;transform:translate(var(--tw-translate-x),var(--tw-translate-y)) rotate(var(--tw-rotate)) skewX(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) scaleY(var(--tw-scale-y))}.main-header:has(#sidemenu-docs:checked)+main #sidebar #sidebar-docs,.main-header:has(#sidemenu-meta:checked)+main #sidebar #sidebar-meta,.main-header:has(#sidemenu-source:checked)+main #sidebar #sidebar-source,.main-header:has(#sidemenu-summary:checked)+main #sidebar #sidebar-summary{display:block}@media (min-width:40rem){:is(.main-header:has(#sidemenu-source:checked),.main-header:has(#sidemenu-docs:checked),.main-header:has(#sidemenu-meta:checked)) .main-navigation,:is(.main-header:has(#sidemenu-source:checked),.main-header:has(#sidemenu-docs:checked),.main-header:has(#sidemenu-meta:checked))+main .realm-content{grid-column:span 6/span 6}:is(.main-header:has(#sidemenu-source:checked),.main-header:has(#sidemenu-docs:checked),.main-header:has(#sidemenu-meta:checked)) .sidemenu,:is(.main-header:has(#sidemenu-source:checked),.main-header:has(#sidemenu-docs:checked),.main-header:has(#sidemenu-meta:checked))+main #sidebar{grid-column:span 4/span 4}}:is(.main-header:has(#sidemenu-source:checked),.main-header:has(#sidemenu-docs:checked),.main-header:has(#sidemenu-meta:checked))+main 
#sidebar:before{position:absolute;top:0;left:-1.75rem;z-index:-1;display:block;height:100%;width:50vw;--tw-bg-opacity:1;background-color:rgb(226 226 226/var(--tw-bg-opacity));--tw-content:"";content:var(--tw-content)}main :is(.source-code)>pre{overflow:scroll;border-radius:.375rem;--tw-bg-opacity:1!important;background-color:rgb(255 255 255/var(--tw-bg-opacity))!important;padding:1rem .25rem;font-family:Roboto,Menlo,Consolas,Ubuntu Mono,Roboto Mono,DejaVu Sans Mono,monospace;;font-size:.875rem}@media (min-width:40rem){main :is(.source-code)>pre{padding:2rem .75rem;font-size:1rem}}main .realm-content>pre a:hover{text-decoration-line:none}main :is(.realm-content,.source-code)>pre .chroma-ln:target{background-color:transparent!important}main :is(.realm-content,.source-code)>pre .chroma-line:has(.chroma-ln:target),main :is(.realm-content,.source-code)>pre .chroma-line:has(.chroma-ln:target) .chroma-cl,main :is(.realm-content,.source-code)>pre .chroma-line:has(.chroma-lnlinks:hover),main :is(.realm-content,.source-code)>pre .chroma-line:has(.chroma-lnlinks:hover) .chroma-cl{border-radius:.375rem;--tw-bg-opacity:1!important;background-color:rgb(226 226 226/var(--tw-bg-opacity))!important}main :is(.realm-content,.source-code)>pre .chroma-ln{scroll-margin-top:6rem}.absolute{position:absolute}.relative{position:relative}.sticky{position:sticky}.bottom-1{bottom:.25rem}.left-0{left:0}.right-2{right:.5rem}.right-3{right:.75rem}.top-0{top:0}.top-1\/2{top:50%}.top-14{top:3.5rem}.top-2{top:.5rem}.z-1{z-index:1}.z-max{z-index:9999}.col-span-1{grid-column:span 1/span 1}.col-span-10{grid-column:span 10/span 10}.col-span-3{grid-column:span 3/span 3}.col-span-7{grid-column:span 7/span 7}.row-span-1{grid-row:span 1/span 
1}.row-start-1{grid-row-start:1}.mx-auto{margin-left:auto;margin-right:auto}.mb-1{margin-bottom:.25rem}.mb-2{margin-bottom:.5rem}.mb-3{margin-bottom:.75rem}.mb-4{margin-bottom:1rem}.mb-8{margin-bottom:2rem}.mr-10{margin-right:2.5rem}.mt-1{margin-top:.25rem}.mt-10{margin-top:2.5rem}.mt-2{margin-top:.5rem}.mt-4{margin-top:1rem}.mt-6{margin-top:1.5rem}.mt-8{margin-top:2rem}.line-clamp-2{overflow:hidden;display:-webkit-box;-webkit-box-orient:vertical;-webkit-line-clamp:2}.block{display:block}.inline-block{display:inline-block}.inline{display:inline}.flex{display:flex}.grid{display:grid}.hidden{display:none}.h-10{height:2.5rem}.h-4{height:1rem}.h-5{height:1.25rem}.h-6{height:1.5rem}.h-full{height:100%}.max-h-screen{max-height:100vh}.min-h-full{min-height:100%}.min-h-screen{min-height:100vh}.w-10{width:2.5rem}.w-4{width:1rem}.w-5{width:1.25rem}.w-full{width:100%}.min-w-2{min-width:.5rem}.min-w-48{min-width:12rem}.max-w-screen-max{max-width:98.75rem}.shrink-0{flex-shrink:0}.grow-\[2\]{flex-grow:2}.-translate-y-1\/2{--tw-translate-y:-50%;transform:translate(var(--tw-translate-x),var(--tw-translate-y)) rotate(var(--tw-rotate)) skewX(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) 
scaleY(var(--tw-scale-y))}.cursor-pointer{cursor:pointer}.list-none{list-style-type:none}.appearance-none{-webkit-appearance:none;-moz-appearance:none;appearance:none}.grid-flow-dense{grid-auto-flow:dense}.auto-rows-min{grid-auto-rows:min-content}.grid-cols-1{grid-template-columns:repeat(1,minmax(0,1fr))}.grid-cols-10{grid-template-columns:repeat(10,minmax(0,1fr))}.flex-col{flex-direction:column}.items-start{align-items:flex-start}.items-center{align-items:center}.items-stretch{align-items:stretch}.justify-end{justify-content:flex-end}.justify-center{justify-content:center}.justify-between{justify-content:space-between}.gap-0\.5{gap:.125rem}.gap-1{gap:.25rem}.gap-1\.5{gap:.375rem}.gap-2{gap:.5rem}.gap-3{gap:.75rem}.gap-4{gap:1rem}.gap-8{gap:2rem}.gap-x-20{-moz-column-gap:5rem;column-gap:5rem}.gap-x-3{-moz-column-gap:.75rem;column-gap:.75rem}.gap-y-2{row-gap:.5rem}.space-y-2>:not([hidden])~:not([hidden]){--tw-space-y-reverse:0;margin-top:calc(.5rem*(1 - var(--tw-space-y-reverse)));margin-bottom:calc(.5rem*var(--tw-space-y-reverse))}.overflow-hidden{overflow:hidden}.overflow-scroll{overflow:scroll}.whitespace-pre-wrap{white-space:pre-wrap}.rounded{border-radius:.375rem}.rounded-sm{border-radius:.25rem}.border{border-width:1px}.border-b{border-bottom-width:1px}.border-l{border-left-width:1px}.border-t{border-top-width:1px}.border-gray-100{--tw-border-opacity:1;border-color:rgb(226 226 226/var(--tw-border-opacity))}.bg-gray-100{--tw-bg-opacity:1;background-color:rgb(226 226 226/var(--tw-bg-opacity))}.bg-gray-300{--tw-bg-opacity:1;background-color:rgb(153 153 153/var(--tw-bg-opacity))}.bg-gray-50{--tw-bg-opacity:1;background-color:rgb(240 240 240/var(--tw-bg-opacity))}.bg-light{--tw-bg-opacity:1;background-color:rgb(255 255 
255/var(--tw-bg-opacity))}.bg-transparent{background-color:transparent}.p-1\.5{padding:.375rem}.p-2{padding:.5rem}.p-4{padding:1rem}.px-1{padding-left:.25rem;padding-right:.25rem}.px-10{padding-left:2.5rem;padding-right:2.5rem}.px-2{padding-left:.5rem;padding-right:.5rem}.px-3{padding-left:.75rem;padding-right:.75rem}.px-4{padding-left:1rem;padding-right:1rem}.py-1{padding-top:.25rem;padding-bottom:.25rem}.py-1\.5{padding-top:.375rem;padding-bottom:.375rem}.py-2{padding-top:.5rem;padding-bottom:.5rem}.py-px{padding-top:1px;padding-bottom:1px}.pb-24{padding-bottom:6rem}.pb-3{padding-bottom:.75rem}.pb-4{padding-bottom:1rem}.pb-6{padding-bottom:1.5rem}.pb-8{padding-bottom:2rem}.pl-4{padding-left:1rem}.pr-10{padding-right:2.5rem}.pt-0\.5{padding-top:.125rem}.pt-2{padding-top:.5rem}.font-mono{font-family:Roboto,Menlo,Consolas,Ubuntu Mono,Roboto Mono,DejaVu Sans Mono,monospace;}.text-100{font-size:.875rem}.text-200{font-size:1rem}.text-50{font-size:.75rem}.text-600{font-size:1.5rem}.font-bold{font-weight:700}.font-medium{font-weight:500}.font-normal{font-weight:400}.font-semibold{font-weight:600}.capitalize{text-transform:capitalize}.leading-tight{line-height:1.25}.text-gray-300{--tw-text-opacity:1;color:rgb(153 153 153/var(--tw-text-opacity))}.text-gray-400{--tw-text-opacity:1;color:rgb(124 124 124/var(--tw-text-opacity))}.text-gray-600{--tw-text-opacity:1;color:rgb(84 89 93/var(--tw-text-opacity))}.text-gray-800{--tw-text-opacity:1;color:rgb(19 19 19/var(--tw-text-opacity))}.text-gray-900{--tw-text-opacity:1;color:rgb(8 8 9/var(--tw-text-opacity))}.text-green-600{--tw-text-opacity:1;color:rgb(34 108 87/var(--tw-text-opacity))}.text-light{--tw-text-opacity:1;color:rgb(255 255 255/var(--tw-text-opacity))}.outline-none{outline:2px solid 
transparent;outline-offset:2px}.text-stroke{-webkit-text-stroke:currentColor;-webkit-text-stroke-width:.6px}.no-scrollbar::-webkit-scrollbar{display:none}.no-scrollbar{-ms-overflow-style:none;scrollbar-width:none}.\*\:pl-0>*{padding-left:0}.before\:px-\[0\.18rem\]:before{content:var(--tw-content);padding-left:.18rem;padding-right:.18rem}.before\:text-gray-300:before{content:var(--tw-content);--tw-text-opacity:1;color:rgb(153 153 153/var(--tw-text-opacity))}.before\:content-\[\'\/\'\]:before{--tw-content:"/";content:var(--tw-content)}.before\:content-\[\'\:\'\]:before{--tw-content:":";content:var(--tw-content)}.before\:content-\[\'open\'\]:before{--tw-content:"open";content:var(--tw-content)}.after\:pointer-events-none:after{content:var(--tw-content);pointer-events:none}.after\:absolute:after{content:var(--tw-content);position:absolute}.after\:bottom-0:after{content:var(--tw-content);bottom:0}.after\:left-0:after{content:var(--tw-content);left:0}.after\:top-0:after{content:var(--tw-content);top:0}.after\:block:after{content:var(--tw-content);display:block}.after\:h-1:after{content:var(--tw-content);height:.25rem}.after\:h-full:after{content:var(--tw-content);height:100%}.after\:w-full:after{content:var(--tw-content);width:100%}.after\:rounded-t-sm:after{content:var(--tw-content);border-top-left-radius:.25rem;border-top-right-radius:.25rem}.after\:bg-gray-100:after{content:var(--tw-content);--tw-bg-opacity:1;background-color:rgb(226 226 226/var(--tw-bg-opacity))}.after\:bg-green-600:after{content:var(--tw-content);--tw-bg-opacity:1;background-color:rgb(34 108 87/var(--tw-bg-opacity))}.first\:border-t:first-child{border-top-width:1px}.hover\:border-gray-300:hover{--tw-border-opacity:1;border-color:rgb(153 153 153/var(--tw-border-opacity))}.hover\:bg-gray-100:hover{--tw-bg-opacity:1;background-color:rgb(226 226 226/var(--tw-bg-opacity))}.hover\:bg-gray-50:hover{--tw-bg-opacity:1;background-color:rgb(240 240 
240/var(--tw-bg-opacity))}.hover\:bg-green-600:hover{--tw-bg-opacity:1;background-color:rgb(34 108 87/var(--tw-bg-opacity))}.hover\:text-gray-600:hover{--tw-text-opacity:1;color:rgb(84 89 93/var(--tw-text-opacity))}.hover\:text-green-600:hover{--tw-text-opacity:1;color:rgb(34 108 87/var(--tw-text-opacity))}.hover\:text-light:hover{--tw-text-opacity:1;color:rgb(255 255 255/var(--tw-text-opacity))}.hover\:underline:hover{text-decoration-line:underline}.focus\:border-gray-300:focus{--tw-border-opacity:1;border-color:rgb(153 153 153/var(--tw-border-opacity))}.focus\:border-l-gray-300:focus{--tw-border-opacity:1;border-left-color:rgb(153 153 153/var(--tw-border-opacity))}.group:hover .group-hover\:border-gray-300{--tw-border-opacity:1;border-color:rgb(153 153 153/var(--tw-border-opacity))}.group:hover .group-hover\:border-l-gray-300{--tw-border-opacity:1;border-left-color:rgb(153 153 153/var(--tw-border-opacity))}.group.is-active .group-\[\.is-active\]\:text-green-600{--tw-text-opacity:1;color:rgb(34 108 87/var(--tw-text-opacity))}.peer:checked~.peer-checked\:before\:content-\[\'close\'\]:before{--tw-content:"close";content:var(--tw-content)}.peer:focus-within~.peer-focus-within\:hidden{display:none}.has-\[ul\:empty\]\:hidden:has(ul:empty){display:none}.has-\[\:focus-within\]\:border-gray-300:has(:focus-within){--tw-border-opacity:1;border-color:rgb(153 153 153/var(--tw-border-opacity))}.has-\[\:focus\]\:border-gray-300:has(:focus){--tw-border-opacity:1;border-color:rgb(153 153 153/var(--tw-border-opacity))}@media (min-width:30rem){.sm\:gap-6{gap:1.5rem}}@media (min-width:40rem){.md\:col-span-3{grid-column:span 3/span 3}.md\:mb-0{margin-bottom:0}.md\:h-4{height:1rem}.md\:grid-cols-4{grid-template-columns:repeat(4,minmax(0,1fr))}.md\:flex-row{flex-direction:row}.md\:items-center{align-items:center}.md\:gap-x-8{-moz-column-gap:2rem;column-gap:2rem}.md\:px-10{padding-left:2.5rem;padding-right:2.5rem}.md\:pb-0{padding-bottom:0}}@media 
(min-width:51.25rem){.lg\:order-2{order:2}.lg\:col-span-3{grid-column:span 3/span 3}.lg\:col-span-7{grid-column:span 7/span 7}.lg\:row-span-2{grid-row:span 2/span 2}.lg\:row-start-1{grid-row-start:1}.lg\:row-start-2{grid-row-start:2}.lg\:mb-4{margin-bottom:1rem}.lg\:mt-0{margin-top:0}.lg\:mt-10{margin-top:2.5rem}.lg\:block{display:block}.lg\:hidden{display:none}.lg\:grid-cols-10{grid-template-columns:repeat(10,minmax(0,1fr))}.lg\:flex-row{flex-direction:row}.lg\:justify-start{justify-content:flex-start}.lg\:justify-between{justify-content:space-between}.lg\:gap-x-20{-moz-column-gap:5rem;column-gap:5rem}.lg\:border-none{border-style:none}.lg\:bg-transparent{background-color:transparent}.lg\:p-0{padding:0}.lg\:px-0{padding-left:0;padding-right:0}.lg\:px-2{padding-left:.5rem;padding-right:.5rem}.lg\:py-1\.5{padding-top:.375rem;padding-bottom:.375rem}.lg\:pb-28{padding-bottom:7rem}.lg\:pt-2{padding-top:.5rem}.lg\:text-200{font-size:1rem}.lg\:font-semibold{font-weight:600}.lg\:hover\:bg-transparent:hover{background-color:transparent}}@media (min-width:63.75rem){.xl\:inline{display:inline}.xl\:hidden{display:none}.xl\:grid-cols-10{grid-template-columns:repeat(10,minmax(0,1fr))}.xl\:flex-row{flex-direction:row}.xl\:items-center{align-items:center}.xl\:gap-20{gap:5rem}.xl\:gap-6{gap:1.5rem}.xl\:pt-0{padding-top:0}}@media (min-width:85.375rem){.xxl\:inline-block{display:inline-block}.xxl\:h-4{height:1rem}.xxl\:w-4{width:1rem}.xxl\:gap-20{gap:5rem}.xxl\:gap-x-32{-moz-column-gap:8rem;column-gap:8rem}.xxl\:pr-1{padding-right:.25rem}} \ No newline at end of file diff --git a/gno.land/pkg/gnoweb/url.go b/gno.land/pkg/gnoweb/url.go index bc03f2182d9..9127225d490 100644 --- a/gno.land/pkg/gnoweb/url.go +++ b/gno.land/pkg/gnoweb/url.go @@ -4,145 +4,253 @@ import ( "errors" "fmt" "net/url" + "path/filepath" "regexp" + "slices" "strings" ) -type PathKind byte +var ErrURLInvalidPath = errors.New("invalid path") -const ( - KindInvalid PathKind = 0 - KindRealm PathKind = 'r' - KindPure 
PathKind = 'p' -) +// rePkgOrRealmPath matches and validates a flexible path. +var rePkgOrRealmPath = regexp.MustCompile(`^/[a-z][a-z0-9_/]*$`) // GnoURL decomposes the parts of an URL to query a realm. type GnoURL struct { // Example full path: - // gno.land/r/demo/users:jae$help&a=b?c=d + // gno.land/r/demo/users/render.gno:jae$help&a=b?c=d Domain string // gno.land Path string // /r/demo/users Args string // jae WebQuery url.Values // help&a=b Query url.Values // c=d + File string // render.gno } -func (url GnoURL) EncodeArgs() string { +// EncodeFlag is used to specify which URL components to encode. +type EncodeFlag int + +const ( + EncodeDomain EncodeFlag = 1 << iota // Encode the domain component + EncodePath // Encode the path component + EncodeArgs // Encode the arguments component + EncodeWebQuery // Encode the web query component + EncodeQuery // Encode the query component + EncodeNoEscape // Disable escaping of arguments +) + +// Encode constructs a URL string from the components of a GnoURL struct, +// encoding the specified components based on the provided EncodeFlag bitmask. +// +// The function selectively encodes the URL's path, arguments, web query, and +// query parameters, depending on the flags set in encodeFlags. +// +// Returns a string representing the encoded URL. +// +// Example: +// +// gnoURL := GnoURL{ +// Domain: "gno.land", +// Path: "/r/demo/users", +// Args: "john", +// File: "render.gno", +// } +// +// encodedURL := gnoURL.Encode(EncodePath | EncodeArgs) +// fmt.Println(encodedURL) // Output: /r/demo/users/render.gno:john +// +// URL components are encoded using url.PathEscape unless EncodeNoEscape is specified. +func (gnoURL GnoURL) Encode(encodeFlags EncodeFlag) string { var urlstr strings.Builder - if url.Args != "" { - urlstr.WriteString(url.Args) - } - if len(url.Query) > 0 { - urlstr.WriteString("?" 
+ url.Query.Encode()) - } + noEscape := encodeFlags.Has(EncodeNoEscape) - return urlstr.String() -} + if encodeFlags.Has(EncodeDomain) { + urlstr.WriteString(gnoURL.Domain) + } -func (url GnoURL) EncodePath() string { - var urlstr strings.Builder - urlstr.WriteString(url.Path) - if url.Args != "" { - urlstr.WriteString(":" + url.Args) + if encodeFlags.Has(EncodePath) { + path := gnoURL.Path + urlstr.WriteString(path) } - if len(url.Query) > 0 { - urlstr.WriteString("?" + url.Query.Encode()) + if len(gnoURL.File) > 0 { + urlstr.WriteRune('/') + urlstr.WriteString(gnoURL.File) } - return urlstr.String() -} + if encodeFlags.Has(EncodeArgs) && gnoURL.Args != "" { + if encodeFlags.Has(EncodePath) { + urlstr.WriteRune(':') + } -func (url GnoURL) EncodeWebPath() string { - var urlstr strings.Builder - urlstr.WriteString(url.Path) - if url.Args != "" { - pathEscape := escapeDollarSign(url.Args) - urlstr.WriteString(":" + pathEscape) + // XXX: Arguments should ideally always be escaped, + // but this may require changes in some realms. + args := gnoURL.Args + if !noEscape { + args = escapeDollarSign(url.PathEscape(args)) + } + + urlstr.WriteString(args) } - if len(url.WebQuery) > 0 { - urlstr.WriteString("$" + url.WebQuery.Encode()) + if encodeFlags.Has(EncodeWebQuery) && len(gnoURL.WebQuery) > 0 { + urlstr.WriteRune('$') + if noEscape { + urlstr.WriteString(NoEscapeQuery(gnoURL.WebQuery)) + } else { + urlstr.WriteString(gnoURL.WebQuery.Encode()) + } } - if len(url.Query) > 0 { - urlstr.WriteString("?" 
+ url.Query.Encode()) + if encodeFlags.Has(EncodeQuery) && len(gnoURL.Query) > 0 { + urlstr.WriteRune('?') + if noEscape { + urlstr.WriteString(NoEscapeQuery(gnoURL.Query)) + } else { + urlstr.WriteString(gnoURL.Query.Encode()) + } } return urlstr.String() } -func (url GnoURL) Kind() PathKind { - if len(url.Path) < 2 { - return KindInvalid - } - pk := PathKind(url.Path[1]) - switch pk { - case KindPure, KindRealm: - return pk - } - return KindInvalid +// Has checks if the EncodeFlag contains all the specified flags. +func (f EncodeFlag) Has(flags EncodeFlag) bool { + return f&flags != 0 } -var ( - ErrURLMalformedPath = errors.New("malformed URL path") - ErrURLInvalidPathKind = errors.New("invalid path kind") -) +func escapeDollarSign(s string) string { + return strings.ReplaceAll(s, "$", "%24") +} -// reRealName match a realm path -// - matches[1]: path -// - matches[2]: path args -var reRealmPath = regexp.MustCompile(`^` + - `(/(?:[a-zA-Z0-9_-]+)/` + // path kind - `[a-zA-Z][a-zA-Z0-9_-]*` + // First path segment - `(?:/[a-zA-Z][.a-zA-Z0-9_-]*)*/?)` + // Additional path segments - `([:$](?:.*))?$`, // Remaining portions args, separate by `$` or `:` -) +// EncodeArgs encodes the arguments and query parameters into a string. +// This function is intended to be passed as a realm `Render` argument. +func (gnoURL GnoURL) EncodeArgs() string { + return gnoURL.Encode(EncodeArgs | EncodeQuery | EncodeNoEscape) +} +// EncodeURL encodes the path, arguments, and query parameters into a string. +// This function provides the full representation of the URL without the web query. +func (gnoURL GnoURL) EncodeURL() string { + return gnoURL.Encode(EncodePath | EncodeArgs | EncodeQuery) +} + +// EncodeWebURL encodes the path, package arguments, web query, and query into a string. +// This function provides the full representation of the URL. 
+func (gnoURL GnoURL) EncodeWebURL() string { + return gnoURL.Encode(EncodePath | EncodeArgs | EncodeWebQuery | EncodeQuery) +} + +// IsPure checks if the URL path represents a pure path. +func (gnoURL GnoURL) IsPure() bool { + return strings.HasPrefix(gnoURL.Path, "/p/") +} + +// IsRealm checks if the URL path represents a realm path. +func (gnoURL GnoURL) IsRealm() bool { + return strings.HasPrefix(gnoURL.Path, "/r/") +} + +// IsFile checks if the URL path represents a file. +func (gnoURL GnoURL) IsFile() bool { + return gnoURL.File != "" +} + +// IsDir checks if the URL path represents a directory. +func (gnoURL GnoURL) IsDir() bool { + return !gnoURL.IsFile() && + len(gnoURL.Path) > 0 && gnoURL.Path[len(gnoURL.Path)-1] == '/' +} + +func (gnoURL GnoURL) IsValid() bool { + return rePkgOrRealmPath.MatchString(gnoURL.Path) +} + +// ParseGnoURL parses a URL into a GnoURL structure, extracting and validating its components. func ParseGnoURL(u *url.URL) (*GnoURL, error) { - matches := reRealmPath.FindStringSubmatch(u.EscapedPath()) - if len(matches) != 3 { - return nil, fmt.Errorf("%w: %s", ErrURLMalformedPath, u.Path) + var webargs string + path, args, found := strings.Cut(u.EscapedPath(), ":") + if found { + args, webargs, _ = strings.Cut(args, "$") + } else { + path, webargs, _ = strings.Cut(path, "$") + } + + upath, err := url.PathUnescape(path) + if err != nil { + return nil, fmt.Errorf("unable to unescape path %q: %w", path, err) } - path := matches[1] - args := matches[2] + var file string - if len(args) > 0 { - switch args[0] { - case ':': - args = args[1:] - case '$': - default: - return nil, fmt.Errorf("%w: %s", ErrURLMalformedPath, u.Path) + // A file is considered as one that either ends with an extension or + // contains an uppercase rune + ext := filepath.Ext(upath) + base := filepath.Base(upath) + if ext != "" || strings.ToLower(base) != base { + file = base + upath = strings.TrimSuffix(upath, base) + + // Trim last slash if any + if i := 
strings.LastIndexByte(upath, '/'); i > 0 { + upath = upath[:i] } } - var err error + if !rePkgOrRealmPath.MatchString(upath) { + return nil, fmt.Errorf("%w: %q", ErrURLInvalidPath, upath) + } + webquery := url.Values{} - args, webargs, found := strings.Cut(args, "$") - if found { - if webquery, err = url.ParseQuery(webargs); err != nil { - return nil, fmt.Errorf("unable to parse webquery %q: %w ", webquery, err) + if len(webargs) > 0 { + var parseErr error + if webquery, parseErr = url.ParseQuery(webargs); parseErr != nil { + return nil, fmt.Errorf("unable to parse webquery %q: %w", webargs, parseErr) } } uargs, err := url.PathUnescape(args) if err != nil { - return nil, fmt.Errorf("unable to unescape path %q: %w", args, err) + return nil, fmt.Errorf("unable to unescape args %q: %w", args, err) } return &GnoURL{ - Path: path, + Path: upath, Args: uargs, WebQuery: webquery, Query: u.Query(), Domain: u.Hostname(), + File: file, }, nil } -func escapeDollarSign(s string) string { - return strings.ReplaceAll(s, "$", "%24") +// NoEscapeQuery generates a URL-encoded query string from the given url.Values, +// without escaping the keys and values. The query parameters are sorted by key. +func NoEscapeQuery(v url.Values) string { + // Encode encodes the values into “URL encoded” form + // ("bar=baz&foo=quux") sorted by key. 
+ if len(v) == 0 { + return "" + } + var buf strings.Builder + keys := make([]string, 0, len(v)) + for k := range v { + keys = append(keys, k) + } + slices.Sort(keys) + for _, k := range keys { + vs := v[k] + keyEscaped := k + for _, v := range vs { + if buf.Len() > 0 { + buf.WriteByte('&') + } + buf.WriteString(keyEscaped) + buf.WriteByte('=') + buf.WriteString(v) + } + } + return buf.String() } diff --git a/gno.land/pkg/gnoweb/url_test.go b/gno.land/pkg/gnoweb/url_test.go index 73cfdda69bd..7a491eaa149 100644 --- a/gno.land/pkg/gnoweb/url_test.go +++ b/gno.land/pkg/gnoweb/url_test.go @@ -19,8 +19,9 @@ func TestParseGnoURL(t *testing.T) { Name: "malformed url", Input: "https://gno.land/r/dem)o:$?", Expected: nil, - Err: ErrURLMalformedPath, + Err: ErrURLInvalidPath, }, + { Name: "simple", Input: "https://gno.land/r/simple/test", @@ -30,8 +31,32 @@ func TestParseGnoURL(t *testing.T) { WebQuery: url.Values{}, Query: url.Values{}, }, - Err: nil, }, + + { + Name: "file", + Input: "https://gno.land/r/simple/test/encode.gno", + Expected: &GnoURL{ + Domain: "gno.land", + Path: "/r/simple/test", + WebQuery: url.Values{}, + Query: url.Values{}, + File: "encode.gno", + }, + }, + + { + Name: "complex file path", + Input: "https://gno.land/r/simple/test///...gno", + Expected: &GnoURL{ + Domain: "gno.land", + Path: "/r/simple/test//", + WebQuery: url.Values{}, + Query: url.Values{}, + File: "...gno", + }, + }, + { Name: "webquery + query", Input: "https://gno.land/r/demo/foo$help&func=Bar&name=Baz", @@ -46,7 +71,6 @@ func TestParseGnoURL(t *testing.T) { Query: url.Values{}, Domain: "gno.land", }, - Err: nil, }, { @@ -61,7 +85,6 @@ func TestParseGnoURL(t *testing.T) { Query: url.Values{}, Domain: "gno.land", }, - Err: nil, }, { @@ -78,7 +101,6 @@ func TestParseGnoURL(t *testing.T) { }, Domain: "gno.land", }, - Err: nil, }, { @@ -93,7 +115,6 @@ func TestParseGnoURL(t *testing.T) { }, Domain: "gno.land", }, - Err: nil, }, { @@ -108,22 +129,140 @@ func TestParseGnoURL(t 
*testing.T) { Query: url.Values{}, Domain: "gno.land", }, - Err: nil, }, - // XXX: more tests + { + Name: "unknown path kind", + Input: "https://gno.land/x/demo/foo", + Expected: &GnoURL{ + Path: "/x/demo/foo", + Args: "", + WebQuery: url.Values{}, + Query: url.Values{}, + Domain: "gno.land", + }, + }, + + { + Name: "empty path", + Input: "https://gno.land/r/", + Expected: &GnoURL{ + Path: "/r/", + Args: "", + WebQuery: url.Values{}, + Query: url.Values{}, + Domain: "gno.land", + }, + }, + + { + Name: "complex query", + Input: "https://gno.land/r/demo/foo$help?func=Bar&name=Baz&age=30", + Expected: &GnoURL{ + Path: "/r/demo/foo", + Args: "", + WebQuery: url.Values{ + "help": []string{""}, + }, + Query: url.Values{ + "func": []string{"Bar"}, + "name": []string{"Baz"}, + "age": []string{"30"}, + }, + Domain: "gno.land", + }, + }, + + { + Name: "multiple web queries", + Input: "https://gno.land/r/demo/foo$help&func=Bar$test=123", + Expected: &GnoURL{ + Path: "/r/demo/foo", + Args: "", + WebQuery: url.Values{ + "help": []string{""}, + "func": []string{"Bar$test=123"}, + }, + Query: url.Values{}, + Domain: "gno.land", + }, + }, + + { + Name: "webquery-args-webquery", + Input: "https://gno.land/r/demo/aaa$bbb:CCC&DDD$EEE", + Err: ErrURLInvalidPath, // `/r/demo/aaa$bbb` is an invalid path + }, + + { + Name: "args-webquery-args", + Input: "https://gno.land/r/demo/aaa:BBB$CCC&DDD:EEE", + Expected: &GnoURL{ + Domain: "gno.land", + Path: "/r/demo/aaa", + Args: "BBB", + WebQuery: url.Values{ + "CCC": []string{""}, + "DDD:EEE": []string{""}, + }, + Query: url.Values{}, + }, + }, + + { + Name: "escaped characters in args", + Input: "https://gno.land/r/demo/foo:example%20with%20spaces$tz=Europe/Paris", + Expected: &GnoURL{ + Path: "/r/demo/foo", + Args: "example with spaces", + WebQuery: url.Values{ + "tz": []string{"Europe/Paris"}, + }, + Query: url.Values{}, + Domain: "gno.land", + }, + }, + + { + Name: "file in path + args + query", + Input: 
"https://gno.land/r/demo/foo/render.gno:example$tz=Europe/Paris", + Expected: &GnoURL{ + Path: "/r/demo/foo", + File: "render.gno", + Args: "example", + WebQuery: url.Values{ + "tz": []string{"Europe/Paris"}, + }, + Query: url.Values{}, + Domain: "gno.land", + }, + }, + + { + Name: "no extension file", + Input: "https://gno.land/r/demo/lIcEnSe", + Expected: &GnoURL{ + Path: "/r/demo", + File: "lIcEnSe", + Args: "", + WebQuery: url.Values{}, + Query: url.Values{}, + Domain: "gno.land", + }, + }, } for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { + t.Logf("testing input: %q", tc.Input) + u, err := url.Parse(tc.Input) require.NoError(t, err) result, err := ParseGnoURL(u) if tc.Err == nil { require.NoError(t, err) - t.Logf("parsed: %s", result.EncodePath()) - t.Logf("parsed web: %s", result.EncodeWebPath()) + t.Logf("encoded web path: %q", result.EncodeWebURL()) } else { require.Error(t, err) require.ErrorIs(t, err, tc.Err) @@ -133,3 +272,191 @@ func TestParseGnoURL(t *testing.T) { }) } } + +func TestEncode(t *testing.T) { + testCases := []struct { + Name string + GnoURL GnoURL + EncodeFlags EncodeFlag + Expected string + }{ + { + Name: "encode domain", + GnoURL: GnoURL{ + Domain: "gno.land", + Path: "/r/demo/foo", + }, + EncodeFlags: EncodeDomain, + Expected: "gno.land", + }, + + { + Name: "encode web query without escape", + GnoURL: GnoURL{ + Domain: "gno.land", + Path: "/r/demo/foo", + WebQuery: url.Values{ + "help": []string{""}, + "fun$c": []string{"B$ ar"}, + }, + }, + EncodeFlags: EncodeWebQuery | EncodeNoEscape, + Expected: "$fun$c=B$ ar&help=", + }, + + { + Name: "encode domain and path", + GnoURL: GnoURL{ + Domain: "gno.land", + Path: "/r/demo/foo", + }, + EncodeFlags: EncodeDomain | EncodePath, + Expected: "gno.land/r/demo/foo", + }, + + { + Name: "Encode Path Only", + GnoURL: GnoURL{ + Path: "/r/demo/foo", + }, + EncodeFlags: EncodePath, + Expected: "/r/demo/foo", + }, + + { + Name: "Encode Path and File", + GnoURL: GnoURL{ + Path: 
"/r/demo/foo", + File: "render.gno", + }, + EncodeFlags: EncodePath, + Expected: "/r/demo/foo/render.gno", + }, + + { + Name: "Encode Path, File, and Args", + GnoURL: GnoURL{ + Path: "/r/demo/foo", + File: "render.gno", + Args: "example", + }, + EncodeFlags: EncodePath | EncodeArgs, + Expected: "/r/demo/foo/render.gno:example", + }, + + { + Name: "Encode Path and Args", + GnoURL: GnoURL{ + Path: "/r/demo/foo", + Args: "example", + }, + EncodeFlags: EncodePath | EncodeArgs, + Expected: "/r/demo/foo:example", + }, + + { + Name: "Encode Path, Args, and WebQuery", + GnoURL: GnoURL{ + Path: "/r/demo/foo", + Args: "example", + WebQuery: url.Values{ + "tz": []string{"Europe/Paris"}, + }, + }, + EncodeFlags: EncodePath | EncodeArgs | EncodeWebQuery, + Expected: "/r/demo/foo:example$tz=Europe%2FParis", + }, + + { + Name: "Encode Full URL", + GnoURL: GnoURL{ + Path: "/r/demo/foo", + Args: "example", + WebQuery: url.Values{ + "tz": []string{"Europe/Paris"}, + }, + Query: url.Values{ + "hello": []string{"42"}, + }, + }, + EncodeFlags: EncodePath | EncodeArgs | EncodeWebQuery | EncodeQuery, + Expected: "/r/demo/foo:example$tz=Europe%2FParis?hello=42", + }, + + { + Name: "Encode Args and Query", + GnoURL: GnoURL{ + Path: "/r/demo/foo", + Args: "hello Jo$ny", + Query: url.Values{ + "hello": []string{"42"}, + }, + }, + EncodeFlags: EncodeArgs | EncodeQuery, + Expected: "hello%20Jo%24ny?hello=42", + }, + + { + Name: "Encode Args and Query (No Escape)", + GnoURL: GnoURL{ + Path: "/r/demo/foo", + Args: "hello Jo$ny", + Query: url.Values{ + "hello": []string{"42"}, + }, + }, + EncodeFlags: EncodeArgs | EncodeQuery | EncodeNoEscape, + Expected: "hello Jo$ny?hello=42", + }, + + { + Name: "Encode Args and Query", + GnoURL: GnoURL{ + Path: "/r/demo/foo", + Args: "example", + Query: url.Values{ + "hello": []string{"42"}, + }, + }, + EncodeFlags: EncodeArgs | EncodeQuery, + Expected: "example?hello=42", + }, + + { + Name: "Encode with Escaped Characters", + GnoURL: GnoURL{ + Path: 
"/r/demo/foo", + Args: "example with spaces", + WebQuery: url.Values{ + "tz": []string{"Europe/Paris"}, + }, + Query: url.Values{ + "hello": []string{"42"}, + }, + }, + EncodeFlags: EncodePath | EncodeArgs | EncodeWebQuery | EncodeQuery, + Expected: "/r/demo/foo:example%20with%20spaces$tz=Europe%2FParis?hello=42", + }, + + { + Name: "Encode Path, Args, and Query", + GnoURL: GnoURL{ + Path: "/r/demo/foo", + Args: "example", + Query: url.Values{ + "hello": []string{"42"}, + }, + }, + EncodeFlags: EncodePath | EncodeArgs | EncodeQuery, + Expected: "/r/demo/foo:example?hello=42", + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + result := tc.GnoURL.Encode(tc.EncodeFlags) + require.True(t, tc.GnoURL.IsValid(), "gno url is not valid") + assert.Equal(t, tc.Expected, result) + }) + } +} diff --git a/gno.land/pkg/integration/doc.go b/gno.land/pkg/integration/doc.go index ef3ed9923da..3e09d627c9a 100644 --- a/gno.land/pkg/integration/doc.go +++ b/gno.land/pkg/integration/doc.go @@ -76,13 +76,6 @@ // // Input: // -// - LOG_LEVEL: -// The logging level to be used, which can be one of "error", "debug", "info", or an empty string. -// If empty, the log level defaults to "debug". -// -// - LOG_DIR: -// If set, logs will be directed to the specified directory. -// // - TESTWORK: // A boolean that, when enabled, retains working directories after tests for // inspection. If enabled, gnoland logs will be persisted inside this diff --git a/gno.land/pkg/integration/testing_node.go b/gno.land/pkg/integration/node_testing.go similarity index 86% rename from gno.land/pkg/integration/testing_node.go rename to gno.land/pkg/integration/node_testing.go index 7eaf3457b03..7965f228fc2 100644 --- a/gno.land/pkg/integration/testing_node.go +++ b/gno.land/pkg/integration/node_testing.go @@ -55,7 +55,8 @@ func TestingInMemoryNode(t TestingTS, logger *slog.Logger, config *gnoland.InMem // with default packages and genesis transactions already loaded. 
// It will return the default creator address of the loaded packages. func TestingNodeConfig(t TestingTS, gnoroot string, additionalTxs ...gnoland.TxWithMetadata) (*gnoland.InMemoryNodeConfig, bft.Address) { - cfg := TestingMinimalNodeConfig(t, gnoroot) + cfg := TestingMinimalNodeConfig(gnoroot) + cfg.SkipGenesisVerification = true creator := crypto.MustAddressFromString(DefaultAccount_Address) // test1 @@ -65,24 +66,24 @@ func TestingNodeConfig(t TestingTS, gnoroot string, additionalTxs ...gnoland.TxW txs = append(txs, LoadDefaultPackages(t, creator, gnoroot)...) txs = append(txs, additionalTxs...) - ggs := cfg.Genesis.AppState.(gnoland.GnoGenesisState) - ggs.Balances = balances - ggs.Txs = txs - ggs.Params = params - cfg.Genesis.AppState = ggs + cfg.Genesis.AppState = gnoland.GnoGenesisState{ + Balances: balances, + Txs: txs, + Params: params, + } return cfg, creator } // TestingMinimalNodeConfig constructs the default minimal in-memory node configuration for testing. -func TestingMinimalNodeConfig(t TestingTS, gnoroot string) *gnoland.InMemoryNodeConfig { +func TestingMinimalNodeConfig(gnoroot string) *gnoland.InMemoryNodeConfig { tmconfig := DefaultTestingTMConfig(gnoroot) // Create Mocked Identity pv := gnoland.NewMockedPrivValidator() // Generate genesis config - genesis := DefaultTestingGenesisConfig(t, gnoroot, pv.GetPubKey(), tmconfig) + genesis := DefaultTestingGenesisConfig(gnoroot, pv.GetPubKey(), tmconfig) return &gnoland.InMemoryNodeConfig{ PrivValidator: pv, @@ -96,16 +97,7 @@ func TestingMinimalNodeConfig(t TestingTS, gnoroot string) *gnoland.InMemoryNode } } -func DefaultTestingGenesisConfig(t TestingTS, gnoroot string, self crypto.PubKey, tmconfig *tmcfg.Config) *bft.GenesisDoc { - genState := gnoland.DefaultGenState() - genState.Balances = []gnoland.Balance{ - { - Address: crypto.MustAddressFromString(DefaultAccount_Address), - Amount: std.MustParseCoins(ugnot.ValueString(10000000000000)), - }, - } - genState.Txs = []gnoland.TxWithMetadata{} - 
genState.Params = []gnoland.Param{} +func DefaultTestingGenesisConfig(gnoroot string, self crypto.PubKey, tmconfig *tmcfg.Config) *bft.GenesisDoc { return &bft.GenesisDoc{ GenesisTime: time.Now(), ChainID: tmconfig.ChainID(), @@ -125,7 +117,16 @@ func DefaultTestingGenesisConfig(t TestingTS, gnoroot string, self crypto.PubKey Name: "self", }, }, - AppState: genState, + AppState: gnoland.GnoGenesisState{ + Balances: []gnoland.Balance{ + { + Address: crypto.MustAddressFromString(DefaultAccount_Address), + Amount: std.MustParseCoins(ugnot.ValueString(10_000_000_000_000)), + }, + }, + Txs: []gnoland.TxWithMetadata{}, + Params: []gnoland.Param{}, + }, } } @@ -178,8 +179,9 @@ func DefaultTestingTMConfig(gnoroot string) *tmcfg.Config { tmconfig := tmcfg.TestConfig().SetRootDir(gnoroot) tmconfig.Consensus.WALDisabled = true + tmconfig.Consensus.SkipTimeoutCommit = true tmconfig.Consensus.CreateEmptyBlocks = true - tmconfig.Consensus.CreateEmptyBlocksInterval = time.Duration(0) + tmconfig.Consensus.CreateEmptyBlocksInterval = time.Millisecond * 100 tmconfig.RPC.ListenAddress = defaultListner tmconfig.P2P.ListenAddress = defaultListner return tmconfig diff --git a/gno.land/pkg/integration/pkgloader.go b/gno.land/pkg/integration/pkgloader.go new file mode 100644 index 00000000000..e40e8ff1eb5 --- /dev/null +++ b/gno.land/pkg/integration/pkgloader.go @@ -0,0 +1,174 @@ +package integration + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/gnolang/gno/gno.land/pkg/gnoland" + "github.com/gnolang/gno/gno.land/pkg/sdk/vm" + "github.com/gnolang/gno/gnovm/pkg/gnolang" + "github.com/gnolang/gno/gnovm/pkg/gnomod" + "github.com/gnolang/gno/gnovm/pkg/packages" + "github.com/gnolang/gno/tm2/pkg/crypto" + "github.com/gnolang/gno/tm2/pkg/std" +) + +type PkgsLoader struct { + pkgs []gnomod.Pkg + visited map[string]struct{} + + // list of occurrences to patchs with the given value + // XXX: find a better way + patchs map[string]string +} + +func NewPkgsLoader() *PkgsLoader { + 
return &PkgsLoader{ + pkgs: make([]gnomod.Pkg, 0), + visited: make(map[string]struct{}), + patchs: make(map[string]string), + } +} + +func (pl *PkgsLoader) List() gnomod.PkgList { + return pl.pkgs +} + +func (pl *PkgsLoader) SetPatch(replace, with string) { + pl.patchs[replace] = with +} + +func (pl *PkgsLoader) LoadPackages(creatorKey crypto.PrivKey, fee std.Fee, deposit std.Coins) ([]gnoland.TxWithMetadata, error) { + pkgslist, err := pl.List().Sort() // sorts packages by their dependencies. + if err != nil { + return nil, fmt.Errorf("unable to sort packages: %w", err) + } + + txs := make([]gnoland.TxWithMetadata, len(pkgslist)) + for i, pkg := range pkgslist { + tx, err := gnoland.LoadPackage(pkg, creatorKey.PubKey().Address(), fee, deposit) + if err != nil { + return nil, fmt.Errorf("unable to load pkg %q: %w", pkg.Name, err) + } + + // If any replace value is specified, apply them + if len(pl.patchs) > 0 { + for _, msg := range tx.Msgs { + addpkg, ok := msg.(vm.MsgAddPackage) + if !ok { + continue + } + + if addpkg.Package == nil { + continue + } + + for _, file := range addpkg.Package.Files { + for replace, with := range pl.patchs { + file.Body = strings.ReplaceAll(file.Body, replace, with) + } + } + } + } + + txs[i] = gnoland.TxWithMetadata{ + Tx: tx, + } + } + + err = SignTxs(txs, creatorKey, "tendermint_test") + if err != nil { + return nil, fmt.Errorf("unable to sign txs: %w", err) + } + + return txs, nil +} + +func (pl *PkgsLoader) LoadAllPackagesFromDir(path string) error { + // list all packages from target path + pkgslist, err := gnomod.ListPkgs(path) + if err != nil { + return fmt.Errorf("listing gno packages: %w", err) + } + + for _, pkg := range pkgslist { + if !pl.exist(pkg) { + pl.add(pkg) + } + } + + return nil +} + +func (pl *PkgsLoader) LoadPackage(modroot string, path, name string) error { + // Initialize a queue with the root package + queue := []gnomod.Pkg{{Dir: path, Name: name}} + + for len(queue) > 0 { + // Dequeue the first package + 
currentPkg := queue[0] + queue = queue[1:] + + if currentPkg.Dir == "" { + return fmt.Errorf("no path specified for package") + } + + if currentPkg.Name == "" { + // Load `gno.mod` information + gnoModPath := filepath.Join(currentPkg.Dir, "gno.mod") + gm, err := gnomod.ParseGnoMod(gnoModPath) + if err != nil { + return fmt.Errorf("unable to load %q: %w", gnoModPath, err) + } + gm.Sanitize() + + // Override package info with mod infos + currentPkg.Name = gm.Module.Mod.Path + currentPkg.Draft = gm.Draft + + pkg, err := gnolang.ReadMemPackage(currentPkg.Dir, currentPkg.Name) + if err != nil { + return fmt.Errorf("unable to read package at %q: %w", currentPkg.Dir, err) + } + importsMap, err := packages.Imports(pkg, nil) + if err != nil { + return fmt.Errorf("unable to load package imports in %q: %w", currentPkg.Dir, err) + } + imports := importsMap.Merge(packages.FileKindPackageSource, packages.FileKindTest, packages.FileKindFiletest) + for _, imp := range imports { + if imp.PkgPath == currentPkg.Name || gnolang.IsStdlib(imp.PkgPath) { + continue + } + currentPkg.Imports = append(currentPkg.Imports, imp.PkgPath) + } + } + + if currentPkg.Draft { + continue // Skip draft package + } + + if pl.exist(currentPkg) { + continue + } + pl.add(currentPkg) + + // Add requirements to the queue + for _, pkgPath := range currentPkg.Imports { + fullPath := filepath.Join(modroot, pkgPath) + queue = append(queue, gnomod.Pkg{Dir: fullPath}) + } + } + + return nil +} + +func (pl *PkgsLoader) add(pkg gnomod.Pkg) { + pl.visited[pkg.Name] = struct{}{} + pl.pkgs = append(pl.pkgs, pkg) +} + +func (pl *PkgsLoader) exist(pkg gnomod.Pkg) (ok bool) { + _, ok = pl.visited[pkg.Name] + return +} diff --git a/gno.land/pkg/integration/process.go b/gno.land/pkg/integration/process.go new file mode 100644 index 00000000000..839004ca1f3 --- /dev/null +++ b/gno.land/pkg/integration/process.go @@ -0,0 +1,451 @@ +package integration + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + 
"fmt" + "io" + "log/slog" + "os" + "os/exec" + "os/signal" + "slices" + "sync" + "testing" + "time" + + "github.com/gnolang/gno/gno.land/pkg/gnoland" + "github.com/gnolang/gno/tm2/pkg/amino" + tmcfg "github.com/gnolang/gno/tm2/pkg/bft/config" + bft "github.com/gnolang/gno/tm2/pkg/bft/types" + "github.com/gnolang/gno/tm2/pkg/crypto/ed25519" + "github.com/gnolang/gno/tm2/pkg/db" + "github.com/gnolang/gno/tm2/pkg/db/goleveldb" + "github.com/gnolang/gno/tm2/pkg/db/memdb" + "github.com/stretchr/testify/require" +) + +const gracefulShutdown = time.Second * 5 + +type ProcessNodeConfig struct { + ValidatorKey ed25519.PrivKeyEd25519 `json:"priv"` + Verbose bool `json:"verbose"` + DBDir string `json:"dbdir"` + RootDir string `json:"rootdir"` + Genesis *MarshalableGenesisDoc `json:"genesis"` + TMConfig *tmcfg.Config `json:"tm"` +} + +type ProcessConfig struct { + Node *ProcessNodeConfig + + // These parameters are not meant to be passed to the process + CoverDir string + Stderr, Stdout io.Writer +} + +func (i ProcessConfig) validate() error { + if i.Node.TMConfig == nil { + return errors.New("no tm config set") + } + + if i.Node.Genesis == nil { + return errors.New("no genesis is set") + } + + return nil +} + +// RunNode initializes and runs a gnoaland node with the provided configuration. 
+func RunNode(ctx context.Context, pcfg *ProcessNodeConfig, stdout, stderr io.Writer) error { + // Setup logger based on verbosity + var handler slog.Handler + if pcfg.Verbose { + handler = slog.NewTextHandler(stdout, &slog.HandlerOptions{Level: slog.LevelDebug}) + } else { + handler = slog.NewTextHandler(stdout, &slog.HandlerOptions{Level: slog.LevelError}) + } + logger := slog.New(handler) + + // Initialize database + db, err := initDatabase(pcfg.DBDir) + if err != nil { + return err + } + defer db.Close() // ensure db is close + + nodecfg := TestingMinimalNodeConfig(pcfg.RootDir) + + // Configure validator if provided + if len(pcfg.ValidatorKey) > 0 && !isAllZero(pcfg.ValidatorKey) { + nodecfg.PrivValidator = bft.NewMockPVWithParams(pcfg.ValidatorKey, false, false) + } + pv := nodecfg.PrivValidator.GetPubKey() + + // Setup node configuration + nodecfg.DB = db + nodecfg.TMConfig.DBPath = pcfg.DBDir + nodecfg.TMConfig = pcfg.TMConfig + nodecfg.Genesis = pcfg.Genesis.ToGenesisDoc() + nodecfg.Genesis.Validators = []bft.GenesisValidator{ + { + Address: pv.Address(), + PubKey: pv, + Power: 10, + Name: "self", + }, + } + + // Create and start the node + node, err := gnoland.NewInMemoryNode(logger, nodecfg) + if err != nil { + return fmt.Errorf("failed to create new in-memory node: %w", err) + } + + if err := node.Start(); err != nil { + return fmt.Errorf("failed to start node: %w", err) + } + defer node.Stop() + + // Determine if the node is a validator + ourAddress := nodecfg.PrivValidator.GetPubKey().Address() + isValidator := slices.ContainsFunc(nodecfg.Genesis.Validators, func(val bft.GenesisValidator) bool { + return val.Address == ourAddress + }) + + lisnAddress := node.Config().RPC.ListenAddress + if isValidator { + select { + case <-ctx.Done(): + return fmt.Errorf("waiting for the node to start: %w", ctx.Err()) + case <-node.Ready(): + } + } + + // Write READY signal to stdout + signalWriteReady(stdout, lisnAddress) + + <-ctx.Done() + return node.Stop() +} + 
+type NodeProcess interface { + Stop() error + Address() string +} + +type nodeProcess struct { + cmd *exec.Cmd + address string + + stopOnce sync.Once + stopErr error +} + +func (n *nodeProcess) Address() string { + return n.address +} + +func (n *nodeProcess) Stop() error { + n.stopOnce.Do(func() { + // Send SIGTERM to the process + if err := n.cmd.Process.Signal(os.Interrupt); err != nil { + n.stopErr = fmt.Errorf("error sending `SIGINT` to the node: %w", err) + return + } + + // Optionally wait for the process to exit + if _, err := n.cmd.Process.Wait(); err != nil { + n.stopErr = fmt.Errorf("process exited with error: %w", err) + return + } + }) + + return n.stopErr +} + +// RunNodeProcess runs the binary at the given path with the provided configuration. +func RunNodeProcess(ctx context.Context, cfg ProcessConfig, name string, args ...string) (NodeProcess, error) { + if cfg.Stdout == nil { + cfg.Stdout = os.Stdout + } + + if cfg.Stderr == nil { + cfg.Stderr = os.Stderr + } + + if err := cfg.validate(); err != nil { + return nil, err + } + + // Marshal the configuration to JSON + nodeConfigData, err := json.Marshal(cfg.Node) + if err != nil { + return nil, fmt.Errorf("failed to marshal config to JSON: %w", err) + } + + // Create and configure the command to execute the binary + cmd := exec.Command(name, args...) 
+ cmd.Env = os.Environ() + cmd.Stdin = bytes.NewReader(nodeConfigData) + + if cfg.CoverDir != "" { + cmd.Env = append(cmd.Env, "GOCOVERDIR="+cfg.CoverDir) + } + + // Redirect all errors into a buffer + cmd.Stderr = os.Stderr + if cfg.Stderr != nil { + cmd.Stderr = cfg.Stderr + } + + // Create pipes for stdout + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stdout pipe: %w", err) + } + + // Start the command + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("failed to start command: %w", err) + } + + address, err := waitForProcessReady(ctx, stdoutPipe, cfg.Stdout) + if err != nil { + return nil, fmt.Errorf("waiting for readiness: %w", err) + } + + return &nodeProcess{ + cmd: cmd, + address: address, + }, nil +} + +type nodeInMemoryProcess struct { + address string + + stopOnce sync.Once + stopErr error + stop context.CancelFunc + ccNodeError chan error +} + +func (n *nodeInMemoryProcess) Address() string { + return n.address +} + +func (n *nodeInMemoryProcess) Stop() error { + n.stopOnce.Do(func() { + n.stop() + var err error + select { + case err = <-n.ccNodeError: + case <-time.After(time.Second * 5): + err = fmt.Errorf("timeout while waiting for node to stop") + } + + if err != nil { + n.stopErr = fmt.Errorf("unable to node gracefully: %w", err) + } + }) + + return n.stopErr +} + +func RunInMemoryProcess(ctx context.Context, cfg ProcessConfig) (NodeProcess, error) { + ctx, cancel := context.WithCancel(ctx) + + out, in := io.Pipe() + ccStopErr := make(chan error, 1) + go func() { + defer close(ccStopErr) + defer cancel() + + err := RunNode(ctx, cfg.Node, in, cfg.Stderr) + if err != nil { + fmt.Fprintf(cfg.Stderr, "run node failed: %v", err) + } + + ccStopErr <- err + }() + + address, err := waitForProcessReady(ctx, out, cfg.Stdout) + if err == nil { // ok + return &nodeInMemoryProcess{ + address: address, + stop: cancel, + ccNodeError: ccStopErr, + }, nil + } + + cancel() + + select { + case err = 
<-ccStopErr: // return node error in priority + default: + } + + return nil, err +} + +func RunMain(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer) error { + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) + defer stop() + + // Read the configuration from standard input + configData, err := io.ReadAll(stdin) + if err != nil { + // log.Fatalf("error reading stdin: %v", err) + return fmt.Errorf("error reading stdin: %w", err) + } + + // Unmarshal the JSON configuration + var cfg ProcessNodeConfig + if err := json.Unmarshal(configData, &cfg); err != nil { + return fmt.Errorf("error unmarshaling JSON: %w", err) + // log.Fatalf("error unmarshaling JSON: %v", err) + } + + // Run the node + ccErr := make(chan error, 1) + go func() { + ccErr <- RunNode(ctx, &cfg, stdout, stderr) + close(ccErr) + }() + + // Wait for the node to gracefully terminate + <-ctx.Done() + + // Attempt graceful shutdown + select { + case <-time.After(gracefulShutdown): + return fmt.Errorf("unable to gracefully stop the node, exiting now") + case err = <-ccErr: // done + } + + return err +} + +func runTestingNodeProcess(t TestingTS, ctx context.Context, pcfg ProcessConfig) NodeProcess { + bin, err := os.Executable() + require.NoError(t, err) + args := []string{ + "-test.run=^$", + "-run-node-process", + } + + if pcfg.CoverDir != "" && testing.CoverMode() != "" { + args = append(args, "-test.gocoverdir="+pcfg.CoverDir) + } + + node, err := RunNodeProcess(ctx, pcfg, bin, args...) + require.NoError(t, err) + + return node +} + +// initDatabase initializes the database based on the provided directory configuration. 
+func initDatabase(dbDir string) (db.DB, error) { + if dbDir == "" { + return memdb.NewMemDB(), nil + } + + data, err := goleveldb.NewGoLevelDB("testdb", dbDir) + if err != nil { + return nil, fmt.Errorf("unable to init database in %q: %w", dbDir, err) + } + + return data, nil +} + +func signalWriteReady(w io.Writer, address string) error { + _, err := fmt.Fprintf(w, "READY:%s\n", address) + return err +} + +func signalReadReady(line string) (string, bool) { + var address string + if _, err := fmt.Sscanf(line, "READY:%s", &address); err == nil { + return address, true + } + return "", false +} + +// waitForProcessReady waits for the process to signal readiness and returns the address. +func waitForProcessReady(ctx context.Context, stdoutPipe io.Reader, out io.Writer) (string, error) { + var address string + + cReady := make(chan error, 2) + go func() { + defer close(cReady) + + scanner := bufio.NewScanner(stdoutPipe) + ready := false + for scanner.Scan() { + line := scanner.Text() + + if !ready { + if addr, ok := signalReadReady(line); ok { + address = addr + ready = true + cReady <- nil + } + } + + fmt.Fprintln(out, line) + } + + if err := scanner.Err(); err != nil { + cReady <- fmt.Errorf("error reading stdout: %w", err) + } else { + cReady <- fmt.Errorf("process exited without 'READY'") + } + }() + + select { + case err := <-cReady: + return address, err + case <-ctx.Done(): + return "", ctx.Err() + } +} + +// isAllZero checks if a 64-byte key consists entirely of zeros. 
+func isAllZero(key [64]byte) bool { + for _, v := range key { + if v != 0 { + return false + } + } + return true +} + +type MarshalableGenesisDoc bft.GenesisDoc + +func NewMarshalableGenesisDoc(doc *bft.GenesisDoc) *MarshalableGenesisDoc { + m := MarshalableGenesisDoc(*doc) + return &m +} + +func (m *MarshalableGenesisDoc) MarshalJSON() ([]byte, error) { + doc := (*bft.GenesisDoc)(m) + return amino.MarshalJSON(doc) +} + +func (m *MarshalableGenesisDoc) UnmarshalJSON(data []byte) (err error) { + doc, err := bft.GenesisDocFromJSON(data) + if err != nil { + return err + } + + *m = MarshalableGenesisDoc(*doc) + return +} + +// Cast back to the original bft.GenesisDoc. +func (m *MarshalableGenesisDoc) ToGenesisDoc() *bft.GenesisDoc { + return (*bft.GenesisDoc)(m) +} diff --git a/gno.land/pkg/integration/process/main.go b/gno.land/pkg/integration/process/main.go new file mode 100644 index 00000000000..bcd52e6fd44 --- /dev/null +++ b/gno.land/pkg/integration/process/main.go @@ -0,0 +1,20 @@ +package main + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/gnolang/gno/gno.land/pkg/integration" +) + +func main() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + + if err := integration.RunMain(ctx, os.Stdin, os.Stdout, os.Stderr); err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } +} diff --git a/gno.land/pkg/integration/process_test.go b/gno.land/pkg/integration/process_test.go new file mode 100644 index 00000000000..b8768ad0e63 --- /dev/null +++ b/gno.land/pkg/integration/process_test.go @@ -0,0 +1,144 @@ +package integration + +import ( + "bytes" + "context" + "flag" + "fmt" + "os" + "path/filepath" + "testing" + "time" + + "github.com/gnolang/gno/gnovm/pkg/gnoenv" + "github.com/gnolang/gno/tm2/pkg/bft/rpc/client" + "github.com/gnolang/gno/tm2/pkg/crypto/ed25519" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Define a flag to indicate whether to run the 
embedded command
+var runCommand = flag.Bool("run-node-process", false, "execute the embedded command")
+
+func TestMain(m *testing.M) {
+	flag.Parse()
+
+	// Check if the embedded command should be executed
+	if !*runCommand {
+		fmt.Println("Running tests...")
+		os.Exit(m.Run())
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
+	defer cancel()
+
+	if err := RunMain(ctx, os.Stdin, os.Stdout, os.Stderr); err != nil {
+		fmt.Fprintln(os.Stderr, err.Error())
+		os.Exit(1)
+	}
+}
+
+// TestNodeProcess tests starting and stopping a Gnoland node run as a subprocess.
+func TestNodeProcess(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*20)
+	defer cancel()
+
+	gnoRootDir := gnoenv.RootDir()
+
+	// Define paths for the build directory and the gnoland binary
+	gnolandDBDir := filepath.Join(t.TempDir(), "db")
+
+	// Prepare a minimal node configuration for testing
+	cfg := TestingMinimalNodeConfig(gnoRootDir)
+
+	var stdio bytes.Buffer
+	defer func() {
+		t.Log("node output:")
+		t.Log(stdio.String())
+	}()
+
+	start := time.Now()
+	node := runTestingNodeProcess(t, ctx, ProcessConfig{
+		Stderr: &stdio, Stdout: &stdio,
+		Node: &ProcessNodeConfig{
+			Verbose:      true,
+			ValidatorKey: ed25519.GenPrivKey(),
+			DBDir:        gnolandDBDir,
+			RootDir:      gnoRootDir,
+			TMConfig:     cfg.TMConfig,
+			Genesis:      NewMarshalableGenesisDoc(cfg.Genesis),
+		},
+	})
+	t.Logf("time to start the node: %v", time.Since(start).String())
+
+	// Create a new HTTP client to interact with the integration node
+	cli, err := client.NewHTTPClient(node.Address())
+	require.NoError(t, err)
+
+	// Retrieve node info
+	info, err := cli.ABCIInfo()
+	require.NoError(t, err)
+	assert.NotEmpty(t, info.Response.Data)
+
+	// Attempt to stop the node
+	err = node.Stop()
+	require.NoError(t, err)
+
+	// Attempt to stop the node a second time, should not fail
+	err = node.Stop()
+	require.NoError(t, err)
+}
+
+// TestInMemoryNodeProcess tests starting and stopping a Gnoland node process run in-memory.
+func TestInMemoryNodeProcess(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) + defer cancel() + + gnoRootDir := gnoenv.RootDir() + + // Define paths for the build directory and the gnoland binary + gnolandDBDir := filepath.Join(t.TempDir(), "db") + + // Prepare a minimal node configuration for testing + cfg := TestingMinimalNodeConfig(gnoRootDir) + + var stdio bytes.Buffer + defer func() { + t.Log("node output:") + t.Log(stdio.String()) + }() + + start := time.Now() + node, err := RunInMemoryProcess(ctx, ProcessConfig{ + Stderr: &stdio, Stdout: &stdio, + Node: &ProcessNodeConfig{ + Verbose: true, + ValidatorKey: ed25519.GenPrivKey(), + DBDir: gnolandDBDir, + RootDir: gnoRootDir, + TMConfig: cfg.TMConfig, + Genesis: NewMarshalableGenesisDoc(cfg.Genesis), + }, + }) + require.NoError(t, err) + t.Logf("time to start the node: %v", time.Since(start).String()) + + // Create a new HTTP client to interact with the integration node + cli, err := client.NewHTTPClient(node.Address()) + require.NoError(t, err) + + // Retrieve node info + info, err := cli.ABCIInfo() + require.NoError(t, err) + assert.NotEmpty(t, info.Response.Data) + + // Attempt to stop the node + err = node.Stop() + require.NoError(t, err) + + // Attempt to stop the node a second time, should not fail + err = node.Stop() + require.NoError(t, err) +} diff --git a/gno.land/pkg/integration/signer.go b/gno.land/pkg/integration/signer.go new file mode 100644 index 00000000000..b32cd9c59bc --- /dev/null +++ b/gno.land/pkg/integration/signer.go @@ -0,0 +1,33 @@ +package integration + +import ( + "fmt" + + "github.com/gnolang/gno/gno.land/pkg/gnoland" + + "github.com/gnolang/gno/tm2/pkg/crypto" + "github.com/gnolang/gno/tm2/pkg/std" +) + +// SignTxs will sign all txs passed as argument using the private key +// this signature is only valid for genesis transactions as accountNumber and sequence are 0 +func SignTxs(txs []gnoland.TxWithMetadata, privKey crypto.PrivKey, chainID 
string) error { + for index, tx := range txs { + bytes, err := tx.Tx.GetSignBytes(chainID, 0, 0) + if err != nil { + return fmt.Errorf("unable to get sign bytes for transaction, %w", err) + } + signature, err := privKey.Sign(bytes) + if err != nil { + return fmt.Errorf("unable to sign transaction, %w", err) + } + + txs[index].Tx.Signatures = []std.Signature{ + { + PubKey: privKey.PubKey(), + Signature: signature, + }, + } + } + return nil +} diff --git a/gno.land/pkg/integration/testdata/event_multi_msg.txtar b/gno.land/pkg/integration/testdata/event_multi_msg.txtar index 84afe3cc6a4..13a448e7f8c 100644 --- a/gno.land/pkg/integration/testdata/event_multi_msg.txtar +++ b/gno.land/pkg/integration/testdata/event_multi_msg.txtar @@ -11,16 +11,19 @@ stdout 'data: {' stdout ' "BaseAccount": {' stdout ' "address": "g1jg8mtutu9khhfwc4nxmuhcpftf0pajdhfvsqf5",' stdout ' "coins": "[0-9]*ugnot",' # dynamic -stdout ' "public_key": null,' +stdout ' "public_key": {' +stdout ' "@type": "/tm.PubKeySecp256k1",' +stdout ' "value": "A\+FhNtsXHjLfSJk1lB8FbiL4mGPjc50Kt81J7EKDnJ2y"' +stdout ' },' stdout ' "account_number": "0",' -stdout ' "sequence": "0"' +stdout ' "sequence": "1"' stdout ' }' stdout '}' ! stderr '.+' # empty ## sign -gnokey sign -tx-path $WORK/multi/multi_msg.tx -chainid=tendermint_test -account-number 0 -account-sequence 0 test1 +gnokey sign -tx-path $WORK/multi/multi_msg.tx -chainid=tendermint_test -account-number 0 -account-sequence 1 test1 stdout 'Tx successfully signed and saved to ' ## broadcast diff --git a/gno.land/pkg/integration/testdata/gnokey_simulate.txtar b/gno.land/pkg/integration/testdata/gnokey_simulate.txtar index 8db2c7302fc..db3cd527eb3 100644 --- a/gno.land/pkg/integration/testdata/gnokey_simulate.txtar +++ b/gno.land/pkg/integration/testdata/gnokey_simulate.txtar @@ -7,41 +7,41 @@ gnoland start # Initial state: assert that sequence == 0. 
gnokey query auth/accounts/$USER_ADDR_test1 -stdout '"sequence": "0"' +stdout '"sequence": "1"' # attempt adding the "test" package. # the package has a syntax error; simulation should catch this ahead of time and prevent the tx. # -simulate test ! gnokey maketx addpkg -pkgdir $WORK/test -pkgpath gno.land/r/test -gas-fee 1000000ugnot -gas-wanted 2000000 -broadcast -chainid=tendermint_test -simulate test test1 gnokey query auth/accounts/$USER_ADDR_test1 -stdout '"sequence": "0"' +stdout '"sequence": "1"' # -simulate only ! gnokey maketx addpkg -pkgdir $WORK/test -pkgpath gno.land/r/test -gas-fee 1000000ugnot -gas-wanted 2000000 -broadcast -chainid=tendermint_test -simulate only test1 gnokey query auth/accounts/$USER_ADDR_test1 -stdout '"sequence": "0"' +stdout '"sequence": "1"' # -simulate skip ! gnokey maketx addpkg -pkgdir $WORK/test -pkgpath gno.land/r/test -gas-fee 1000000ugnot -gas-wanted 2000000 -broadcast -chainid=tendermint_test -simulate skip test1 gnokey query auth/accounts/$USER_ADDR_test1 -stdout '"sequence": "1"' +stdout '"sequence": "2"' # attempt calling hello.SetName correctly. # -simulate test and skip should do it successfully, -simulate only should not. # -simulate test gnokey maketx call -pkgpath gno.land/r/hello -func SetName -args John -gas-wanted 2000000 -gas-fee 1000000ugnot -broadcast -chainid tendermint_test -simulate test test1 gnokey query auth/accounts/$USER_ADDR_test1 -stdout '"sequence": "2"' +stdout '"sequence": "3"' gnokey query vm/qeval --data "gno.land/r/hello.Hello()" stdout 'Hello, John!' # -simulate only gnokey maketx call -pkgpath gno.land/r/hello -func SetName -args Paul -gas-wanted 2000000 -gas-fee 1000000ugnot -broadcast -chainid tendermint_test -simulate only test1 gnokey query auth/accounts/$USER_ADDR_test1 -stdout '"sequence": "2"' +stdout '"sequence": "3"' gnokey query vm/qeval --data "gno.land/r/hello.Hello()" stdout 'Hello, John!' 
# -simulate skip gnokey maketx call -pkgpath gno.land/r/hello -func SetName -args George -gas-wanted 2000000 -gas-fee 1000000ugnot -broadcast -chainid tendermint_test -simulate skip test1 gnokey query auth/accounts/$USER_ADDR_test1 -stdout '"sequence": "3"' +stdout '"sequence": "4"' gnokey query vm/qeval --data "gno.land/r/hello.Hello()" stdout 'Hello, George!' @@ -51,19 +51,19 @@ stdout 'Hello, George!' # -simulate test ! gnokey maketx call -pkgpath gno.land/r/hello -func Grumpy -gas-wanted 2000000 -gas-fee 1000000ugnot -broadcast -chainid tendermint_test -simulate test test1 gnokey query auth/accounts/$USER_ADDR_test1 -stdout '"sequence": "3"' +stdout '"sequence": "4"' gnokey query vm/qeval --data "gno.land/r/hello.Hello()" stdout 'Hello, George!' # -simulate only ! gnokey maketx call -pkgpath gno.land/r/hello -func Grumpy -gas-wanted 2000000 -gas-fee 1000000ugnot -broadcast -chainid tendermint_test -simulate only test1 gnokey query auth/accounts/$USER_ADDR_test1 -stdout '"sequence": "3"' +stdout '"sequence": "4"' gnokey query vm/qeval --data "gno.land/r/hello.Hello()" stdout 'Hello, George!' # -simulate skip ! gnokey maketx call -pkgpath gno.land/r/hello -func Grumpy -gas-wanted 2000000 -gas-fee 1000000ugnot -broadcast -chainid tendermint_test -simulate skip test1 gnokey query auth/accounts/$USER_ADDR_test1 -stdout '"sequence": "4"' +stdout '"sequence": "5"' gnokey query vm/qeval --data "gno.land/r/hello.Hello()" stdout 'Hello, George!' 
diff --git a/gno.land/pkg/integration/testdata/gnoland.txtar b/gno.land/pkg/integration/testdata/gnoland.txtar index 78bdc9cae4e..83c8fe9c9a5 100644 --- a/gno.land/pkg/integration/testdata/gnoland.txtar +++ b/gno.land/pkg/integration/testdata/gnoland.txtar @@ -28,7 +28,7 @@ cmp stderr gnoland-already-stop.stderr.golden -- gnoland-no-arguments.stdout.golden -- -- gnoland-no-arguments.stderr.golden -- -"gnoland" error: syntax: gnoland [start|stop|restart] +"gnoland" error: no command provided -- gnoland-start.stdout.golden -- node started successfully -- gnoland-start.stderr.golden -- diff --git a/gno.land/pkg/integration/testdata/gnoweb_airgapped.txtar b/gno.land/pkg/integration/testdata/gnoweb_airgapped.txtar index 3ed35a1b1d3..02bd8058214 100644 --- a/gno.land/pkg/integration/testdata/gnoweb_airgapped.txtar +++ b/gno.land/pkg/integration/testdata/gnoweb_airgapped.txtar @@ -14,9 +14,12 @@ stdout 'data: {' stdout ' "BaseAccount": {' stdout ' "address": "g1jg8mtutu9khhfwc4nxmuhcpftf0pajdhfvsqf5",' stdout ' "coins": "[0-9]*ugnot",' # dynamic -stdout ' "public_key": null,' +stdout ' "public_key": {' +stdout ' "@type": "/tm.PubKeySecp256k1",' +stdout ' "value": "A\+FhNtsXHjLfSJk1lB8FbiL4mGPjc50Kt81J7EKDnJ2y"' +stdout ' },' stdout ' "account_number": "0",' -stdout ' "sequence": "0"' +stdout ' "sequence": "4"' stdout ' }' stdout '}' ! 
stderr '.+' # empty @@ -26,7 +29,7 @@ gnokey maketx call -pkgpath "gno.land/r/demo/echo" -func "Render" -gas-fee 10000 cp stdout call.tx # Sign -gnokey sign -tx-path $WORK/call.tx -chainid "tendermint_test" -account-number 0 -account-sequence 0 test1 +gnokey sign -tx-path $WORK/call.tx -chainid "tendermint_test" -account-number 0 -account-sequence 4 test1 cmpenv stdout sign.stdout.golden gnokey broadcast $WORK/call.tx diff --git a/gno.land/pkg/integration/testdata/loadpkg_example.txtar b/gno.land/pkg/integration/testdata/loadpkg_example.txtar index c05bedfef65..f7be500f3b6 100644 --- a/gno.land/pkg/integration/testdata/loadpkg_example.txtar +++ b/gno.land/pkg/integration/testdata/loadpkg_example.txtar @@ -4,11 +4,11 @@ loadpkg gno.land/p/demo/ufmt ## start a new node gnoland start -gnokey maketx addpkg -pkgdir $WORK -pkgpath gno.land/r/importtest -gas-fee 1000000ugnot -gas-wanted 9000000 -broadcast -chainid=tendermint_test test1 +gnokey maketx addpkg -pkgdir $WORK -pkgpath gno.land/r/importtest -gas-fee 1000000ugnot -gas-wanted 10000000 -broadcast -chainid=tendermint_test test1 stdout OK! ## execute Render -gnokey maketx call -pkgpath gno.land/r/importtest -func Render -gas-fee 1000000ugnot -gas-wanted 9000000 -args '' -broadcast -chainid=tendermint_test test1 +gnokey maketx call -pkgpath gno.land/r/importtest -func Render -gas-fee 1000000ugnot -gas-wanted 10000000 -args '' -broadcast -chainid=tendermint_test test1 stdout '("92054" string)' stdout OK! 
diff --git a/gno.land/pkg/integration/testdata/restart.txtar b/gno.land/pkg/integration/testdata/restart.txtar index 8a63713a214..5571aa9fa66 100644 --- a/gno.land/pkg/integration/testdata/restart.txtar +++ b/gno.land/pkg/integration/testdata/restart.txtar @@ -4,12 +4,12 @@ loadpkg gno.land/r/demo/counter $WORK gnoland start -gnokey maketx call -pkgpath gno.land/r/demo/counter -func Incr -gas-fee 1000000ugnot -gas-wanted 150000 -broadcast -chainid tendermint_test test1 +gnokey maketx call -pkgpath gno.land/r/demo/counter -func Incr -gas-fee 1000000ugnot -gas-wanted 200000 -broadcast -chainid tendermint_test test1 stdout '\(1 int\)' gnoland restart -gnokey maketx call -pkgpath gno.land/r/demo/counter -func Incr -gas-fee 1000000ugnot -gas-wanted 150000 -broadcast -chainid tendermint_test test1 +gnokey maketx call -pkgpath gno.land/r/demo/counter -func Incr -gas-fee 1000000ugnot -gas-wanted 200000 -broadcast -chainid tendermint_test test1 stdout '\(2 int\)' -- counter.gno -- diff --git a/gno.land/pkg/integration/testdata/restart_missing_type.txtar b/gno.land/pkg/integration/testdata/restart_missing_type.txtar index b02acc16d96..09e1a27d6f4 100644 --- a/gno.land/pkg/integration/testdata/restart_missing_type.txtar +++ b/gno.land/pkg/integration/testdata/restart_missing_type.txtar @@ -5,15 +5,15 @@ loadpkg gno.land/p/demo/avl gnoland start -gnokey sign -tx-path $WORK/tx1.tx -chainid tendermint_test -account-sequence 0 test1 +gnokey sign -tx-path $WORK/tx1.tx -chainid tendermint_test -account-sequence 1 test1 ! gnokey broadcast $WORK/tx1.tx stderr 'out of gas' -gnokey sign -tx-path $WORK/tx2.tx -chainid tendermint_test -account-sequence 1 test1 +gnokey sign -tx-path $WORK/tx2.tx -chainid tendermint_test -account-sequence 2 test1 gnokey broadcast $WORK/tx2.tx stdout 'OK!' 
-gnokey sign -tx-path $WORK/tx3.tx -chainid tendermint_test -account-sequence 2 test1 +gnokey sign -tx-path $WORK/tx3.tx -chainid tendermint_test -account-sequence 3 test1 gnokey broadcast $WORK/tx3.tx stdout 'OK!' diff --git a/gno.land/pkg/integration/testdata/simulate_gas.txtar b/gno.land/pkg/integration/testdata/simulate_gas.txtar index 8550419f205..4c5213da345 100644 --- a/gno.land/pkg/integration/testdata/simulate_gas.txtar +++ b/gno.land/pkg/integration/testdata/simulate_gas.txtar @@ -6,11 +6,11 @@ gnoland start # simulate only gnokey maketx call -pkgpath gno.land/r/simulate -func Hello -gas-fee 1000000ugnot -gas-wanted 2000000 -broadcast -chainid=tendermint_test -simulate only test1 -stdout 'GAS USED: 96411' +stdout 'GAS USED: 99015' # simulate skip gnokey maketx call -pkgpath gno.land/r/simulate -func Hello -gas-fee 1000000ugnot -gas-wanted 2000000 -broadcast -chainid=tendermint_test -simulate skip test1 -stdout 'GAS USED: 96411' # same as simulate only +stdout 'GAS USED: 99015' # same as simulate only -- package/package.gno -- diff --git a/gno.land/pkg/integration/testdata_test.go b/gno.land/pkg/integration/testdata_test.go new file mode 100644 index 00000000000..ba4d5176df1 --- /dev/null +++ b/gno.land/pkg/integration/testdata_test.go @@ -0,0 +1,67 @@ +package integration + +import ( + "os" + "strconv" + "testing" + + gno_integration "github.com/gnolang/gno/gnovm/pkg/integration" + "github.com/rogpeppe/go-internal/testscript" + "github.com/stretchr/testify/require" +) + +func TestTestdata(t *testing.T) { + t.Parallel() + + flagInMemoryTS, _ := strconv.ParseBool(os.Getenv("INMEMORY_TS")) + flagNoSeqTS, _ := strconv.ParseBool(os.Getenv("NO_SEQ_TS")) + + p := gno_integration.NewTestingParams(t, "testdata") + + if coverdir, ok := gno_integration.ResolveCoverageDir(); ok { + err := gno_integration.SetupTestscriptsCoverage(&p, coverdir) + require.NoError(t, err) + } + + // Set up gnoland for testscript + err := SetupGnolandTestscript(t, &p) + require.NoError(t, 
err) + + mode := commandKindTesting + if flagInMemoryTS { + mode = commandKindInMemory + } + + origSetup := p.Setup + p.Setup = func(env *testscript.Env) error { + env.Values[envKeyExecCommand] = mode + if origSetup != nil { + if err := origSetup(env); err != nil { + return err + } + } + + return nil + } + + if flagInMemoryTS && !flagNoSeqTS { + testscript.RunT(tSeqShim{t}, p) + } else { + testscript.Run(t, p) + } +} + +type tSeqShim struct{ *testing.T } + +// noop Parallel method allow us to run test sequentially +func (tSeqShim) Parallel() {} + +func (t tSeqShim) Run(name string, f func(testscript.T)) { + t.T.Run(name, func(t *testing.T) { + f(tSeqShim{t}) + }) +} + +func (t tSeqShim) Verbose() bool { + return testing.Verbose() +} diff --git a/gno.land/pkg/integration/testing_integration.go b/gno.land/pkg/integration/testing_integration.go deleted file mode 100644 index 0a181950bb3..00000000000 --- a/gno.land/pkg/integration/testing_integration.go +++ /dev/null @@ -1,795 +0,0 @@ -package integration - -import ( - "context" - "errors" - "flag" - "fmt" - "hash/crc32" - "log/slog" - "os" - "path/filepath" - "strconv" - "strings" - "testing" - - "github.com/gnolang/gno/gno.land/pkg/gnoland" - "github.com/gnolang/gno/gno.land/pkg/gnoland/ugnot" - "github.com/gnolang/gno/gno.land/pkg/keyscli" - "github.com/gnolang/gno/gno.land/pkg/log" - "github.com/gnolang/gno/gno.land/pkg/sdk/vm" - "github.com/gnolang/gno/gnovm/pkg/gnoenv" - "github.com/gnolang/gno/gnovm/pkg/gnolang" - "github.com/gnolang/gno/gnovm/pkg/gnomod" - "github.com/gnolang/gno/gnovm/pkg/packages" - "github.com/gnolang/gno/tm2/pkg/bft/node" - bft "github.com/gnolang/gno/tm2/pkg/bft/types" - "github.com/gnolang/gno/tm2/pkg/commands" - "github.com/gnolang/gno/tm2/pkg/crypto" - "github.com/gnolang/gno/tm2/pkg/crypto/bip39" - "github.com/gnolang/gno/tm2/pkg/crypto/keys" - "github.com/gnolang/gno/tm2/pkg/crypto/keys/client" - "github.com/gnolang/gno/tm2/pkg/db/memdb" - tm2Log "github.com/gnolang/gno/tm2/pkg/log" - 
"github.com/gnolang/gno/tm2/pkg/std" - "github.com/rogpeppe/go-internal/testscript" - "go.uber.org/zap/zapcore" -) - -const ( - envKeyGenesis int = iota - envKeyLogger - envKeyPkgsLoader -) - -type tSeqShim struct{ *testing.T } - -// noop Parallel method allow us to run test sequentially -func (tSeqShim) Parallel() {} - -func (t tSeqShim) Run(name string, f func(testscript.T)) { - t.T.Run(name, func(t *testing.T) { - f(tSeqShim{t}) - }) -} - -func (t tSeqShim) Verbose() bool { - return testing.Verbose() -} - -// RunGnolandTestscripts sets up and runs txtar integration tests for gnoland nodes. -// It prepares an in-memory gnoland node and initializes the necessary environment and custom commands. -// The function adapts the test setup for use with the testscript package, enabling -// the execution of gnoland and gnokey commands within txtar scripts. -// -// Refer to package documentation in doc.go for more information on commands and example txtar scripts. -func RunGnolandTestscripts(t *testing.T, txtarDir string) { - t.Helper() - - p := setupGnolandTestScript(t, txtarDir) - if deadline, ok := t.Deadline(); ok && p.Deadline.IsZero() { - p.Deadline = deadline - } - - testscript.RunT(tSeqShim{t}, p) -} - -type testNode struct { - *node.Node - cfg *gnoland.InMemoryNodeConfig - nGnoKeyExec uint // Counter for execution of gnokey. -} - -func setupGnolandTestScript(t *testing.T, txtarDir string) testscript.Params { - t.Helper() - - tmpdir := t.TempDir() - - // `gnoRootDir` should point to the local location of the gno repository. - // It serves as the gno equivalent of GOROOT. - gnoRootDir := gnoenv.RootDir() - - // `gnoHomeDir` should be the local directory where gnokey stores keys. - gnoHomeDir := filepath.Join(tmpdir, "gno") - - // Testscripts run concurrently by default, so we need to be prepared for that. 
- nodes := map[string]*testNode{} - - updateScripts, _ := strconv.ParseBool(os.Getenv("UPDATE_SCRIPTS")) - persistWorkDir, _ := strconv.ParseBool(os.Getenv("TESTWORK")) - return testscript.Params{ - UpdateScripts: updateScripts, - TestWork: persistWorkDir, - Dir: txtarDir, - Setup: func(env *testscript.Env) error { - kb, err := keys.NewKeyBaseFromDir(gnoHomeDir) - if err != nil { - return err - } - - // create sessions ID - var sid string - { - works := env.Getenv("WORK") - sum := crc32.ChecksumIEEE([]byte(works)) - sid = strconv.FormatUint(uint64(sum), 16) - env.Setenv("SID", sid) - } - - // setup logger - var logger *slog.Logger - { - logger = tm2Log.NewNoopLogger() - if persistWorkDir || os.Getenv("LOG_PATH_DIR") != "" { - logname := fmt.Sprintf("txtar-gnoland-%s.log", sid) - logger, err = getTestingLogger(env, logname) - if err != nil { - return fmt.Errorf("unable to setup logger: %w", err) - } - } - - env.Values[envKeyLogger] = logger - } - - // Track new user balances added via the `adduser` - // command and packages added with the `loadpkg` command. - // This genesis will be use when node is started. - - genesis := gnoland.DefaultGenState() - genesis.Balances = LoadDefaultGenesisBalanceFile(t, gnoRootDir) - genesis.Params = LoadDefaultGenesisParamFile(t, gnoRootDir) - genesis.Auth.Params.InitialGasPrice = std.GasPrice{Gas: 0, Price: std.Coin{Amount: 0, Denom: "ugnot"}} - genesis.Txs = []gnoland.TxWithMetadata{} - - // test1 must be created outside of the loop below because it is already included in genesis so - // attempting to recreate results in it getting overwritten and breaking existing tests that - // rely on its address being static. 
- kb.CreateAccount(DefaultAccount_Name, DefaultAccount_Seed, "", "", 0, 0) - env.Setenv("USER_SEED_"+DefaultAccount_Name, DefaultAccount_Seed) - env.Setenv("USER_ADDR_"+DefaultAccount_Name, DefaultAccount_Address) - - env.Values[envKeyGenesis] = &genesis - env.Values[envKeyPkgsLoader] = newPkgsLoader() - - env.Setenv("GNOROOT", gnoRootDir) - env.Setenv("GNOHOME", gnoHomeDir) - - return nil - }, - Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){ - "gnoland": func(ts *testscript.TestScript, neg bool, args []string) { - if len(args) == 0 { - tsValidateError(ts, "gnoland", neg, fmt.Errorf("syntax: gnoland [start|stop|restart]")) - return - } - - logger := ts.Value(envKeyLogger).(*slog.Logger) // grab logger - sid := getNodeSID(ts) // grab session id - - var cmd string - cmd, args = args[0], args[1:] - - var err error - switch cmd { - case "start": - if nodeIsRunning(nodes, sid) { - err = fmt.Errorf("node already started") - break - } - - // parse flags - fs := flag.NewFlagSet("start", flag.ContinueOnError) - nonVal := fs.Bool("non-validator", false, "set up node as a non-validator") - if err := fs.Parse(args); err != nil { - ts.Fatalf("unable to parse `gnoland start` flags: %s", err) - } - - // get packages - pkgs := ts.Value(envKeyPkgsLoader).(*pkgsLoader) // grab logger - creator := crypto.MustAddressFromString(DefaultAccount_Address) // test1 - defaultFee := std.NewFee(50000, std.MustParseCoin(ugnot.ValueString(1000000))) - // we need to define a new err1 otherwise the out err would be shadowed in the case "start": - pkgsTxs, loadErr := pkgs.LoadPackages(creator, defaultFee, nil) - - if loadErr != nil { - ts.Fatalf("unable to load packages txs: %s", err) - } - - // Warp up `ts` so we can pass it to other testing method - t := TSTestingT(ts) - - // Generate config and node - cfg := TestingMinimalNodeConfig(t, gnoRootDir) - genesis := ts.Value(envKeyGenesis).(*gnoland.GnoGenesisState) - genesis.Txs = append(pkgsTxs, genesis.Txs...) 
- - // setup genesis state - cfg.Genesis.AppState = *genesis - if *nonVal { - // re-create cfg.Genesis.Validators with a throwaway pv, so we start as a - // non-validator. - pv := gnoland.NewMockedPrivValidator() - cfg.Genesis.Validators = []bft.GenesisValidator{ - { - Address: pv.GetPubKey().Address(), - PubKey: pv.GetPubKey(), - Power: 10, - Name: "none", - }, - } - } - cfg.DB = memdb.NewMemDB() // so it can be reused when restarting. - - n, remoteAddr := TestingInMemoryNode(t, logger, cfg) - - // Register cleanup - nodes[sid] = &testNode{Node: n, cfg: cfg} - - // Add default environments - ts.Setenv("RPC_ADDR", remoteAddr) - - fmt.Fprintln(ts.Stdout(), "node started successfully") - case "restart": - n, ok := nodes[sid] - if !ok { - err = fmt.Errorf("node must be started before being restarted") - break - } - - if stopErr := n.Stop(); stopErr != nil { - err = fmt.Errorf("error stopping node: %w", stopErr) - break - } - - // Create new node with same config. - newNode, newRemoteAddr := TestingInMemoryNode(t, logger, n.cfg) - - // Update testNode and environment variables. - n.Node = newNode - ts.Setenv("RPC_ADDR", newRemoteAddr) - - fmt.Fprintln(ts.Stdout(), "node restarted successfully") - case "stop": - n, ok := nodes[sid] - if !ok { - err = fmt.Errorf("node not started cannot be stopped") - break - } - if err = n.Stop(); err == nil { - delete(nodes, sid) - - // Unset gnoland environments - ts.Setenv("RPC_ADDR", "") - fmt.Fprintln(ts.Stdout(), "node stopped successfully") - } - default: - err = fmt.Errorf("invalid gnoland subcommand: %q", cmd) - } - - tsValidateError(ts, "gnoland "+cmd, neg, err) - }, - "gnokey": func(ts *testscript.TestScript, neg bool, args []string) { - logger := ts.Value(envKeyLogger).(*slog.Logger) // grab logger - sid := ts.Getenv("SID") // grab session id - - // Unquote args enclosed in `"` to correctly handle `\n` or similar escapes. 
- args, err := unquote(args) - if err != nil { - tsValidateError(ts, "gnokey", neg, err) - } - - // Setup IO command - io := commands.NewTestIO() - io.SetOut(commands.WriteNopCloser(ts.Stdout())) - io.SetErr(commands.WriteNopCloser(ts.Stderr())) - cmd := keyscli.NewRootCmd(io, client.DefaultBaseOptions) - - io.SetIn(strings.NewReader("\n")) // Inject empty password to stdin. - defaultArgs := []string{ - "-home", gnoHomeDir, - "-insecure-password-stdin=true", // There no use to not have this param by default. - } - - if n, ok := nodes[sid]; ok { - if raddr := n.Config().RPC.ListenAddress; raddr != "" { - defaultArgs = append(defaultArgs, "-remote", raddr) - } - - n.nGnoKeyExec++ - headerlog := fmt.Sprintf("%.02d!EXEC_GNOKEY", n.nGnoKeyExec) - - // Log the command inside gnoland logger, so we can better scope errors. - logger.Info(headerlog, "args", strings.Join(args, " ")) - defer logger.Info(headerlog, "delimiter", "END") - } - - // Inject default argument, if duplicate - // arguments, it should be override by the ones - // user provided. - args = append(defaultArgs, args...) - - err = cmd.ParseAndRun(context.Background(), args) - tsValidateError(ts, "gnokey", neg, err) - }, - // adduser command must be executed before starting the node; it errors out otherwise. 
- "adduser": func(ts *testscript.TestScript, neg bool, args []string) { - if nodeIsRunning(nodes, getNodeSID(ts)) { - tsValidateError(ts, "adduser", neg, errors.New("adduser must be used before starting node")) - return - } - - if len(args) == 0 { - ts.Fatalf("new user name required") - } - - kb, err := keys.NewKeyBaseFromDir(gnoHomeDir) - if err != nil { - ts.Fatalf("unable to get keybase") - } - - balance, err := createAccount(ts, kb, args[0]) - if err != nil { - ts.Fatalf("error creating account %s: %s", args[0], err) - } - - // Add balance to genesis - genesis := ts.Value(envKeyGenesis).(*gnoland.GnoGenesisState) - genesis.Balances = append(genesis.Balances, balance) - }, - // adduserfrom commands must be executed before starting the node; it errors out otherwise. - "adduserfrom": func(ts *testscript.TestScript, neg bool, args []string) { - if nodeIsRunning(nodes, getNodeSID(ts)) { - tsValidateError(ts, "adduserfrom", neg, errors.New("adduserfrom must be used before starting node")) - return - } - - var account, index uint64 - var err error - - switch len(args) { - case 2: - // expected user input - // adduserfrom 'username 'menmonic' - // no need to do anything - - case 4: - // expected user input - // adduserfrom 'username 'menmonic' 'account' 'index' - - // parse 'index' first, then fallghrough to `case 3` to parse 'account' - index, err = strconv.ParseUint(args[3], 10, 32) - if err != nil { - ts.Fatalf("invalid index number %s", args[3]) - } - - fallthrough // parse 'account' - case 3: - // expected user input - // adduserfrom 'username 'menmonic' 'account' - - account, err = strconv.ParseUint(args[2], 10, 32) - if err != nil { - ts.Fatalf("invalid account number %s", args[2]) - } - default: - ts.Fatalf("to create account from metadatas, user name and mnemonic are required ( account and index are optional )") - } - - kb, err := keys.NewKeyBaseFromDir(gnoHomeDir) - if err != nil { - ts.Fatalf("unable to get keybase") - } - - balance, err := 
createAccountFrom(ts, kb, args[0], args[1], uint32(account), uint32(index)) - if err != nil { - ts.Fatalf("error creating wallet %s", err) - } - - // Add balance to genesis - genesis := ts.Value(envKeyGenesis).(*gnoland.GnoGenesisState) - genesis.Balances = append(genesis.Balances, balance) - - fmt.Fprintf(ts.Stdout(), "Added %s(%s) to genesis", args[0], balance.Address) - }, - // `patchpkg` Patch any loaded files by packages by replacing all occurrences of the - // first argument with the second. - // This is mostly use to replace hardcoded address inside txtar file. - "patchpkg": func(ts *testscript.TestScript, neg bool, args []string) { - args, err := unquote(args) - if err != nil { - tsValidateError(ts, "patchpkg", neg, err) - } - - if len(args) != 2 { - ts.Fatalf("`patchpkg`: should have exactly 2 arguments") - } - - pkgs := ts.Value(envKeyPkgsLoader).(*pkgsLoader) - replace, with := args[0], args[1] - pkgs.SetPatch(replace, with) - }, - // `loadpkg` load a specific package from the 'examples' or working directory. - "loadpkg": func(ts *testscript.TestScript, neg bool, args []string) { - // special dirs - workDir := ts.Getenv("WORK") - examplesDir := filepath.Join(gnoRootDir, "examples") - - pkgs := ts.Value(envKeyPkgsLoader).(*pkgsLoader) - - var path, name string - switch len(args) { - case 2: - name = args[0] - path = filepath.Clean(args[1]) - case 1: - path = filepath.Clean(args[0]) - case 0: - ts.Fatalf("`loadpkg`: no arguments specified") - default: - ts.Fatalf("`loadpkg`: too many arguments specified") - } - - // If `all` is specified, fully load 'examples' directory. - // NOTE: In 99% of cases, this is not needed, and - // packages should be loaded individually. 
- if path == "all" { - ts.Logf("warning: loading all packages") - if err := pkgs.LoadAllPackagesFromDir(examplesDir); err != nil { - ts.Fatalf("unable to load packages from %q: %s", examplesDir, err) - } - - return - } - - if !strings.HasPrefix(path, workDir) { - path = filepath.Join(examplesDir, path) - } - - if err := pkgs.LoadPackage(examplesDir, path, name); err != nil { - ts.Fatalf("`loadpkg` unable to load package(s) from %q: %s", args[0], err) - } - - ts.Logf("%q package was added to genesis", args[0]) - }, - }, - } -} - -// `unquote` takes a slice of strings, resulting from splitting a string block by spaces, and -// processes them. The function handles quoted phrases and escape characters within these strings. -func unquote(args []string) ([]string, error) { - const quote = '"' - - parts := []string{} - var inQuote bool - - var part strings.Builder - for _, arg := range args { - var escaped bool - for _, c := range arg { - if escaped { - // If the character is meant to be escaped, it is processed with Unquote. - // We use `Unquote` here for two main reasons: - // 1. It will validate that the escape sequence is correct - // 2. It converts the escaped string to its corresponding raw character. - // For example, "\\t" becomes '\t'. - uc, err := strconv.Unquote(`"\` + string(c) + `"`) - if err != nil { - return nil, fmt.Errorf("unhandled escape sequence `\\%c`: %w", c, err) - } - - part.WriteString(uc) - escaped = false - continue - } - - // If we are inside a quoted string and encounter an escape character, - // flag the next character as `escaped` - if inQuote && c == '\\' { - escaped = true - continue - } - - // Detect quote and toggle inQuote state - if c == quote { - inQuote = !inQuote - continue - } - - // Handle regular character - part.WriteRune(c) - } - - // If we're inside a quote, add a single space. - // It reflects one or multiple spaces between args in the original string. 
- if inQuote { - part.WriteRune(' ') - continue - } - - // Finalize part, add to parts, and reset for next part - parts = append(parts, part.String()) - part.Reset() - } - - // Check if a quote is left open - if inQuote { - return nil, errors.New("unfinished quote") - } - - return parts, nil -} - -func getNodeSID(ts *testscript.TestScript) string { - return ts.Getenv("SID") -} - -func nodeIsRunning(nodes map[string]*testNode, sid string) bool { - _, ok := nodes[sid] - return ok -} - -func getTestingLogger(env *testscript.Env, logname string) (*slog.Logger, error) { - var path string - - if logdir := os.Getenv("LOG_PATH_DIR"); logdir != "" { - if err := os.MkdirAll(logdir, 0o755); err != nil { - return nil, fmt.Errorf("unable to make log directory %q", logdir) - } - - var err error - if path, err = filepath.Abs(filepath.Join(logdir, logname)); err != nil { - return nil, fmt.Errorf("unable to get absolute path of logdir %q", logdir) - } - } else if workdir := env.Getenv("WORK"); workdir != "" { - path = filepath.Join(workdir, logname) - } else { - return tm2Log.NewNoopLogger(), nil - } - - f, err := os.Create(path) - if err != nil { - return nil, fmt.Errorf("unable to create log file %q: %w", path, err) - } - - env.Defer(func() { - if err := f.Close(); err != nil { - panic(fmt.Errorf("unable to close log file %q: %w", path, err)) - } - }) - - // Initialize the logger - logLevel, err := zapcore.ParseLevel(strings.ToLower(os.Getenv("LOG_LEVEL"))) - if err != nil { - return nil, fmt.Errorf("unable to parse log level, %w", err) - } - - // Build zap logger for testing - zapLogger := log.NewZapTestingLogger(f, logLevel) - env.Defer(func() { zapLogger.Sync() }) - - env.T().Log("starting logger", path) - return log.ZapLoggerToSlog(zapLogger), nil -} - -func tsValidateError(ts *testscript.TestScript, cmd string, neg bool, err error) { - if err != nil { - fmt.Fprintf(ts.Stderr(), "%q error: %+v\n", cmd, err) - if !neg { - ts.Fatalf("unexpected %q command failure: %s", cmd, 
err) - } - } else { - if neg { - ts.Fatalf("unexpected %q command success", cmd) - } - } -} - -type envSetter interface { - Setenv(key, value string) -} - -// createAccount creates a new account with the given name and adds it to the keybase. -func createAccount(env envSetter, kb keys.Keybase, accountName string) (gnoland.Balance, error) { - var balance gnoland.Balance - entropy, err := bip39.NewEntropy(256) - if err != nil { - return balance, fmt.Errorf("error creating entropy: %w", err) - } - - mnemonic, err := bip39.NewMnemonic(entropy) - if err != nil { - return balance, fmt.Errorf("error generating mnemonic: %w", err) - } - - var keyInfo keys.Info - if keyInfo, err = kb.CreateAccount(accountName, mnemonic, "", "", 0, 0); err != nil { - return balance, fmt.Errorf("unable to create account: %w", err) - } - - address := keyInfo.GetAddress() - env.Setenv("USER_SEED_"+accountName, mnemonic) - env.Setenv("USER_ADDR_"+accountName, address.String()) - - return gnoland.Balance{ - Address: address, - Amount: std.Coins{std.NewCoin(ugnot.Denom, 10e6)}, - }, nil -} - -// createAccountFrom creates a new account with the given metadata and adds it to the keybase. 
-func createAccountFrom(env envSetter, kb keys.Keybase, accountName, mnemonic string, account, index uint32) (gnoland.Balance, error) { - var balance gnoland.Balance - - // check if mnemonic is valid - if !bip39.IsMnemonicValid(mnemonic) { - return balance, fmt.Errorf("invalid mnemonic") - } - - keyInfo, err := kb.CreateAccount(accountName, mnemonic, "", "", account, index) - if err != nil { - return balance, fmt.Errorf("unable to create account: %w", err) - } - - address := keyInfo.GetAddress() - env.Setenv("USER_SEED_"+accountName, mnemonic) - env.Setenv("USER_ADDR_"+accountName, address.String()) - - return gnoland.Balance{ - Address: address, - Amount: std.Coins{std.NewCoin(ugnot.Denom, 10e6)}, - }, nil -} - -type pkgsLoader struct { - pkgs []gnomod.Pkg - visited map[string]struct{} - - // list of occurrences to patchs with the given value - // XXX: find a better way - patchs map[string]string -} - -func newPkgsLoader() *pkgsLoader { - return &pkgsLoader{ - pkgs: make([]gnomod.Pkg, 0), - visited: make(map[string]struct{}), - patchs: make(map[string]string), - } -} - -func (pl *pkgsLoader) List() gnomod.PkgList { - return pl.pkgs -} - -func (pl *pkgsLoader) SetPatch(replace, with string) { - pl.patchs[replace] = with -} - -func (pl *pkgsLoader) LoadPackages(creator bft.Address, fee std.Fee, deposit std.Coins) ([]gnoland.TxWithMetadata, error) { - pkgslist, err := pl.List().Sort() // sorts packages by their dependencies. 
- if err != nil { - return nil, fmt.Errorf("unable to sort packages: %w", err) - } - - txs := make([]gnoland.TxWithMetadata, len(pkgslist)) - for i, pkg := range pkgslist { - tx, err := gnoland.LoadPackage(pkg, creator, fee, deposit) - if err != nil { - return nil, fmt.Errorf("unable to load pkg %q: %w", pkg.Name, err) - } - - // If any replace value is specified, apply them - if len(pl.patchs) > 0 { - for _, msg := range tx.Msgs { - addpkg, ok := msg.(vm.MsgAddPackage) - if !ok { - continue - } - - if addpkg.Package == nil { - continue - } - - for _, file := range addpkg.Package.Files { - for replace, with := range pl.patchs { - file.Body = strings.ReplaceAll(file.Body, replace, with) - } - } - } - } - - txs[i] = gnoland.TxWithMetadata{ - Tx: tx, - } - } - - return txs, nil -} - -func (pl *pkgsLoader) LoadAllPackagesFromDir(path string) error { - // list all packages from target path - pkgslist, err := gnomod.ListPkgs(path) - if err != nil { - return fmt.Errorf("listing gno packages: %w", err) - } - - for _, pkg := range pkgslist { - if !pl.exist(pkg) { - pl.add(pkg) - } - } - - return nil -} - -func (pl *pkgsLoader) LoadPackage(modroot string, path, name string) error { - // Initialize a queue with the root package - queue := []gnomod.Pkg{{Dir: path, Name: name}} - - for len(queue) > 0 { - // Dequeue the first package - currentPkg := queue[0] - queue = queue[1:] - - if currentPkg.Dir == "" { - return fmt.Errorf("no path specified for package") - } - - if currentPkg.Name == "" { - // Load `gno.mod` information - gnoModPath := filepath.Join(currentPkg.Dir, "gno.mod") - gm, err := gnomod.ParseGnoMod(gnoModPath) - if err != nil { - return fmt.Errorf("unable to load %q: %w", gnoModPath, err) - } - gm.Sanitize() - - // Override package info with mod infos - currentPkg.Name = gm.Module.Mod.Path - currentPkg.Draft = gm.Draft - - pkg, err := gnolang.ReadMemPackage(currentPkg.Dir, currentPkg.Name) - if err != nil { - return fmt.Errorf("unable to read package at %q: %w", 
currentPkg.Dir, err) - } - imports, err := packages.Imports(pkg, nil) - if err != nil { - return fmt.Errorf("unable to load package imports in %q: %w", currentPkg.Dir, err) - } - for _, imp := range imports { - if imp.PkgPath == currentPkg.Name || gnolang.IsStdlib(imp.PkgPath) { - continue - } - currentPkg.Imports = append(currentPkg.Imports, imp.PkgPath) - } - } - - if currentPkg.Draft { - continue // Skip draft package - } - - if pl.exist(currentPkg) { - continue - } - pl.add(currentPkg) - - // Add requirements to the queue - for _, pkgPath := range currentPkg.Imports { - fullPath := filepath.Join(modroot, pkgPath) - queue = append(queue, gnomod.Pkg{Dir: fullPath}) - } - } - - return nil -} - -func (pl *pkgsLoader) add(pkg gnomod.Pkg) { - pl.visited[pkg.Name] = struct{}{} - pl.pkgs = append(pl.pkgs, pkg) -} - -func (pl *pkgsLoader) exist(pkg gnomod.Pkg) (ok bool) { - _, ok = pl.visited[pkg.Name] - return -} diff --git a/gno.land/pkg/integration/testscript_gnoland.go b/gno.land/pkg/integration/testscript_gnoland.go new file mode 100644 index 00000000000..9781799ea7d --- /dev/null +++ b/gno.land/pkg/integration/testscript_gnoland.go @@ -0,0 +1,789 @@ +package integration + +import ( + "bytes" + "context" + "errors" + "flag" + "fmt" + "hash/crc32" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/gnolang/gno/gno.land/pkg/gnoland" + "github.com/gnolang/gno/gno.land/pkg/gnoland/ugnot" + "github.com/gnolang/gno/gno.land/pkg/keyscli" + "github.com/gnolang/gno/gnovm/pkg/gnoenv" + bft "github.com/gnolang/gno/tm2/pkg/bft/types" + "github.com/gnolang/gno/tm2/pkg/commands" + "github.com/gnolang/gno/tm2/pkg/crypto" + "github.com/gnolang/gno/tm2/pkg/crypto/bip39" + "github.com/gnolang/gno/tm2/pkg/crypto/ed25519" + "github.com/gnolang/gno/tm2/pkg/crypto/hd" + "github.com/gnolang/gno/tm2/pkg/crypto/keys" + "github.com/gnolang/gno/tm2/pkg/crypto/keys/client" + "github.com/gnolang/gno/tm2/pkg/crypto/secp256k1" + 
"github.com/gnolang/gno/tm2/pkg/std" + "github.com/rogpeppe/go-internal/testscript" + "github.com/stretchr/testify/require" +) + +const nodeMaxLifespan = time.Second * 30 + +type envKey int + +const ( + envKeyGenesis envKey = iota + envKeyLogger + envKeyPkgsLoader + envKeyPrivValKey + envKeyExecCommand + envKeyExecBin +) + +type commandkind int + +const ( + // commandKindBin builds and uses an integration binary to run the testscript + // in a separate process. This should be used for any external package that + // wants to use test scripts. + commandKindBin commandkind = iota + // commandKindTesting uses the current testing binary to run the testscript + // in a separate process. This command cannot be used outside this package. + commandKindTesting + // commandKindInMemory runs testscripts in memory. + commandKindInMemory +) + +type tNodeProcess struct { + NodeProcess + cfg *gnoland.InMemoryNodeConfig + nGnoKeyExec uint // Counter for execution of gnokey. +} + +// NodesManager manages access to the nodes map with synchronization. +type NodesManager struct { + nodes map[string]*tNodeProcess + mu sync.RWMutex +} + +// NewNodesManager creates a new instance of NodesManager. +func NewNodesManager() *NodesManager { + return &NodesManager{ + nodes: make(map[string]*tNodeProcess), + } +} + +func (nm *NodesManager) IsNodeRunning(sid string) bool { + nm.mu.RLock() + defer nm.mu.RUnlock() + + _, ok := nm.nodes[sid] + return ok +} + +// Get retrieves a node by its SID. +func (nm *NodesManager) Get(sid string) (*tNodeProcess, bool) { + nm.mu.RLock() + defer nm.mu.RUnlock() + node, exists := nm.nodes[sid] + return node, exists +} + +// Set adds or updates a node in the map. +func (nm *NodesManager) Set(sid string, node *tNodeProcess) { + nm.mu.Lock() + defer nm.mu.Unlock() + nm.nodes[sid] = node +} + +// Delete removes a node from the map. 
+func (nm *NodesManager) Delete(sid string) { + nm.mu.Lock() + defer nm.mu.Unlock() + delete(nm.nodes, sid) +} + +func SetupGnolandTestscript(t *testing.T, p *testscript.Params) error { + t.Helper() + + gnoRootDir := gnoenv.RootDir() + + nodesManager := NewNodesManager() + + defaultPK, err := GeneratePrivKeyFromMnemonic(DefaultAccount_Seed, "", 0, 0) + require.NoError(t, err) + + var buildOnce sync.Once + var gnolandBin string + + // Store the original setup scripts for potential wrapping + origSetup := p.Setup + p.Setup = func(env *testscript.Env) error { + // If there's an original setup, execute it + if origSetup != nil { + if err := origSetup(env); err != nil { + return err + } + } + + cmd, isSet := env.Values[envKeyExecCommand].(commandkind) + switch { + case !isSet: + cmd = commandKindBin // fallback on commandKindBin + fallthrough + case cmd == commandKindBin: + buildOnce.Do(func() { + t.Logf("building the gnoland integration node") + start := time.Now() + gnolandBin = buildGnoland(t, gnoRootDir) + t.Logf("time to build the node: %v", time.Since(start).String()) + }) + + env.Values[envKeyExecBin] = gnolandBin + } + + tmpdir, dbdir := t.TempDir(), t.TempDir() + gnoHomeDir := filepath.Join(tmpdir, "gno") + + kb, err := keys.NewKeyBaseFromDir(gnoHomeDir) + if err != nil { + return err + } + + kb.ImportPrivKey(DefaultAccount_Name, defaultPK, "") + env.Setenv("USER_SEED_"+DefaultAccount_Name, DefaultAccount_Seed) + env.Setenv("USER_ADDR_"+DefaultAccount_Name, DefaultAccount_Address) + + // New private key + env.Values[envKeyPrivValKey] = ed25519.GenPrivKey() + env.Setenv("GNO_DBDIR", dbdir) + + // Generate node short id + var sid string + { + works := env.Getenv("WORK") + sum := crc32.ChecksumIEEE([]byte(works)) + sid = strconv.FormatUint(uint64(sum), 16) + env.Setenv("SID", sid) + } + + balanceFile := LoadDefaultGenesisBalanceFile(t, gnoRootDir) + genesisParamFile := LoadDefaultGenesisParamFile(t, gnoRootDir) + + // Track new user balances added via the 
`adduser` + // command and packages added with the `loadpkg` command. + // This genesis will be use when node is started. + genesis := &gnoland.GnoGenesisState{ + Balances: balanceFile, + Params: genesisParamFile, + Txs: []gnoland.TxWithMetadata{}, + } + + env.Values[envKeyGenesis] = genesis + env.Values[envKeyPkgsLoader] = NewPkgsLoader() + + env.Setenv("GNOROOT", gnoRootDir) + env.Setenv("GNOHOME", gnoHomeDir) + + env.Defer(func() { + // Gracefully stop the node, if any + n, exist := nodesManager.Get(sid) + if !exist { + return + } + + if err := n.Stop(); err != nil { + err = fmt.Errorf("unable to stop the node gracefully: %w", err) + env.T().Fatal(err.Error()) + } + }) + + return nil + } + + cmds := map[string]func(ts *testscript.TestScript, neg bool, args []string){ + "gnoland": gnolandCmd(t, nodesManager, gnoRootDir), + "gnokey": gnokeyCmd(nodesManager), + "adduser": adduserCmd(nodesManager), + "adduserfrom": adduserfromCmd(nodesManager), + "patchpkg": patchpkgCmd(), + "loadpkg": loadpkgCmd(gnoRootDir), + } + + // Initialize cmds map if needed + if p.Cmds == nil { + p.Cmds = make(map[string]func(ts *testscript.TestScript, neg bool, args []string)) + } + + // Register gnoland command + for cmd, call := range cmds { + if _, exist := p.Cmds[cmd]; exist { + panic(fmt.Errorf("unable register %q: command already exist", cmd)) + } + + p.Cmds[cmd] = call + } + + return nil +} + +func gnolandCmd(t *testing.T, nodesManager *NodesManager, gnoRootDir string) func(ts *testscript.TestScript, neg bool, args []string) { + t.Helper() + + defaultPK, err := GeneratePrivKeyFromMnemonic(DefaultAccount_Seed, "", 0, 0) + require.NoError(t, err) + + return func(ts *testscript.TestScript, neg bool, args []string) { + sid := getNodeSID(ts) + + cmd, cmdargs := "", []string{} + if len(args) > 0 { + cmd, cmdargs = args[0], args[1:] + } + + var err error + switch cmd { + case "": + err = errors.New("no command provided") + case "start": + if nodesManager.IsNodeRunning(sid) { + err = 
fmt.Errorf("node already started") + break + } + + // XXX: this is a bit hacky, we should consider moving + // gnoland into his own package to be able to use it + // directly or use the config command for this. + fs := flag.NewFlagSet("start", flag.ContinueOnError) + nonVal := fs.Bool("non-validator", false, "set up node as a non-validator") + if err := fs.Parse(cmdargs); err != nil { + ts.Fatalf("unable to parse `gnoland start` flags: %s", err) + } + + pkgs := ts.Value(envKeyPkgsLoader).(*PkgsLoader) + defaultFee := std.NewFee(50000, std.MustParseCoin(ugnot.ValueString(1000000))) + pkgsTxs, err := pkgs.LoadPackages(defaultPK, defaultFee, nil) + if err != nil { + ts.Fatalf("unable to load packages txs: %s", err) + } + + cfg := TestingMinimalNodeConfig(gnoRootDir) + genesis := ts.Value(envKeyGenesis).(*gnoland.GnoGenesisState) + genesis.Txs = append(pkgsTxs, genesis.Txs...) + + cfg.Genesis.AppState = *genesis + if *nonVal { + pv := gnoland.NewMockedPrivValidator() + cfg.Genesis.Validators = []bft.GenesisValidator{ + { + Address: pv.GetPubKey().Address(), + PubKey: pv.GetPubKey(), + Power: 10, + Name: "none", + }, + } + } + + ctx, cancel := context.WithTimeout(context.Background(), nodeMaxLifespan) + ts.Defer(cancel) + + dbdir := ts.Getenv("GNO_DBDIR") + priv := ts.Value(envKeyPrivValKey).(ed25519.PrivKeyEd25519) + nodep := setupNode(ts, ctx, &ProcessNodeConfig{ + ValidatorKey: priv, + DBDir: dbdir, + RootDir: gnoRootDir, + TMConfig: cfg.TMConfig, + Genesis: NewMarshalableGenesisDoc(cfg.Genesis), + }) + + nodesManager.Set(sid, &tNodeProcess{NodeProcess: nodep, cfg: cfg}) + + ts.Setenv("RPC_ADDR", nodep.Address()) + fmt.Fprintln(ts.Stdout(), "node started successfully") + + case "restart": + node, exists := nodesManager.Get(sid) + if !exists { + err = fmt.Errorf("node must be started before being restarted") + break + } + + if err := node.Stop(); err != nil { + err = fmt.Errorf("unable to stop the node gracefully: %w", err) + break + } + + ctx, cancel := 
context.WithTimeout(context.Background(), nodeMaxLifespan) + ts.Defer(cancel) + + priv := ts.Value(envKeyPrivValKey).(ed25519.PrivKeyEd25519) + dbdir := ts.Getenv("GNO_DBDIR") + nodep := setupNode(ts, ctx, &ProcessNodeConfig{ + ValidatorKey: priv, + DBDir: dbdir, + RootDir: gnoRootDir, + TMConfig: node.cfg.TMConfig, + Genesis: NewMarshalableGenesisDoc(node.cfg.Genesis), + }) + + ts.Setenv("RPC_ADDR", nodep.Address()) + nodesManager.Set(sid, &tNodeProcess{NodeProcess: nodep, cfg: node.cfg}) + + fmt.Fprintln(ts.Stdout(), "node restarted successfully") + + case "stop": + node, exists := nodesManager.Get(sid) + if !exists { + err = fmt.Errorf("node not started cannot be stopped") + break + } + + if err := node.Stop(); err != nil { + err = fmt.Errorf("unable to stop the node gracefully: %w", err) + break + } + + fmt.Fprintln(ts.Stdout(), "node stopped successfully") + nodesManager.Delete(sid) + + default: + err = fmt.Errorf("not supported command: %q", cmd) + // XXX: support gnoland other commands + } + + tsValidateError(ts, strings.TrimSpace("gnoland "+cmd), neg, err) + } +} + +func gnokeyCmd(nodes *NodesManager) func(ts *testscript.TestScript, neg bool, args []string) { + return func(ts *testscript.TestScript, neg bool, args []string) { + gnoHomeDir := ts.Getenv("GNOHOME") + + sid := getNodeSID(ts) + + args, err := unquote(args) + if err != nil { + tsValidateError(ts, "gnokey", neg, err) + } + + io := commands.NewTestIO() + io.SetOut(commands.WriteNopCloser(ts.Stdout())) + io.SetErr(commands.WriteNopCloser(ts.Stderr())) + cmd := keyscli.NewRootCmd(io, client.DefaultBaseOptions) + + io.SetIn(strings.NewReader("\n")) + defaultArgs := []string{ + "-home", gnoHomeDir, + "-insecure-password-stdin=true", + } + + if n, ok := nodes.Get(sid); ok { + if raddr := n.Address(); raddr != "" { + defaultArgs = append(defaultArgs, "-remote", raddr) + } + + n.nGnoKeyExec++ + } + + args = append(defaultArgs, args...) 
+ + err = cmd.ParseAndRun(context.Background(), args) + tsValidateError(ts, "gnokey", neg, err) + } +} + +func adduserCmd(nodesManager *NodesManager) func(ts *testscript.TestScript, neg bool, args []string) { + return func(ts *testscript.TestScript, neg bool, args []string) { + gnoHomeDir := ts.Getenv("GNOHOME") + + sid := getNodeSID(ts) + if nodesManager.IsNodeRunning(sid) { + tsValidateError(ts, "adduser", neg, errors.New("adduser must be used before starting node")) + return + } + + if len(args) == 0 { + ts.Fatalf("new user name required") + } + + kb, err := keys.NewKeyBaseFromDir(gnoHomeDir) + if err != nil { + ts.Fatalf("unable to get keybase") + } + + balance, err := createAccount(ts, kb, args[0]) + if err != nil { + ts.Fatalf("error creating account %s: %s", args[0], err) + } + + genesis := ts.Value(envKeyGenesis).(*gnoland.GnoGenesisState) + genesis.Balances = append(genesis.Balances, balance) + } +} + +func adduserfromCmd(nodesManager *NodesManager) func(ts *testscript.TestScript, neg bool, args []string) { + return func(ts *testscript.TestScript, neg bool, args []string) { + gnoHomeDir := ts.Getenv("GNOHOME") + + sid := getNodeSID(ts) + if nodesManager.IsNodeRunning(sid) { + tsValidateError(ts, "adduserfrom", neg, errors.New("adduserfrom must be used before starting node")) + return + } + + var account, index uint64 + var err error + + switch len(args) { + case 2: + case 4: + index, err = strconv.ParseUint(args[3], 10, 32) + if err != nil { + ts.Fatalf("invalid index number %s", args[3]) + } + fallthrough + case 3: + account, err = strconv.ParseUint(args[2], 10, 32) + if err != nil { + ts.Fatalf("invalid account number %s", args[2]) + } + default: + ts.Fatalf("to create account from metadatas, user name and mnemonic are required ( account and index are optional )") + } + + kb, err := keys.NewKeyBaseFromDir(gnoHomeDir) + if err != nil { + ts.Fatalf("unable to get keybase") + } + + balance, err := createAccountFrom(ts, kb, args[0], args[1], uint32(account), 
uint32(index)) + if err != nil { + ts.Fatalf("error creating wallet %s", err) + } + + genesis := ts.Value(envKeyGenesis).(*gnoland.GnoGenesisState) + genesis.Balances = append(genesis.Balances, balance) + + fmt.Fprintf(ts.Stdout(), "Added %s(%s) to genesis", args[0], balance.Address) + } +} + +func patchpkgCmd() func(ts *testscript.TestScript, neg bool, args []string) { + return func(ts *testscript.TestScript, neg bool, args []string) { + args, err := unquote(args) + if err != nil { + tsValidateError(ts, "patchpkg", neg, err) + } + + if len(args) != 2 { + ts.Fatalf("`patchpkg`: should have exactly 2 arguments") + } + + pkgs := ts.Value(envKeyPkgsLoader).(*PkgsLoader) + replace, with := args[0], args[1] + pkgs.SetPatch(replace, with) + } +} + +func loadpkgCmd(gnoRootDir string) func(ts *testscript.TestScript, neg bool, args []string) { + return func(ts *testscript.TestScript, neg bool, args []string) { + workDir := ts.Getenv("WORK") + examplesDir := filepath.Join(gnoRootDir, "examples") + + pkgs := ts.Value(envKeyPkgsLoader).(*PkgsLoader) + + var path, name string + switch len(args) { + case 2: + name = args[0] + path = filepath.Clean(args[1]) + case 1: + path = filepath.Clean(args[0]) + case 0: + ts.Fatalf("`loadpkg`: no arguments specified") + default: + ts.Fatalf("`loadpkg`: too many arguments specified") + } + + if path == "all" { + ts.Logf("warning: loading all packages") + if err := pkgs.LoadAllPackagesFromDir(examplesDir); err != nil { + ts.Fatalf("unable to load packages from %q: %s", examplesDir, err) + } + + return + } + + if !strings.HasPrefix(path, workDir) { + path = filepath.Join(examplesDir, path) + } + + if err := pkgs.LoadPackage(examplesDir, path, name); err != nil { + ts.Fatalf("`loadpkg` unable to load package(s) from %q: %s", args[0], err) + } + + ts.Logf("%q package was added to genesis", args[0]) + } +} + +type tsLogWriter struct { + ts *testscript.TestScript +} + +func (l *tsLogWriter) Write(p []byte) (n int, err error) { + 
l.ts.Logf(string(p)) + return len(p), nil +} + +func setupNode(ts *testscript.TestScript, ctx context.Context, cfg *ProcessNodeConfig) NodeProcess { + pcfg := ProcessConfig{ + Node: cfg, + Stdout: &tsLogWriter{ts}, + Stderr: ts.Stderr(), + } + + // Setup coverdir provided + if coverdir := ts.Getenv("GOCOVERDIR"); coverdir != "" { + pcfg.CoverDir = coverdir + } + + val := ts.Value(envKeyExecCommand) + + switch cmd := val.(commandkind); cmd { + case commandKindInMemory: + nodep, err := RunInMemoryProcess(ctx, pcfg) + if err != nil { + ts.Fatalf("unable to start in memory node: %s", err) + } + + return nodep + + case commandKindTesting: + if !testing.Testing() { + ts.Fatalf("unable to invoke testing process while not testing") + } + + return runTestingNodeProcess(&testingTS{ts}, ctx, pcfg) + + case commandKindBin: + bin := ts.Value(envKeyExecBin).(string) + nodep, err := RunNodeProcess(ctx, pcfg, bin) + if err != nil { + ts.Fatalf("unable to start process node: %s", err) + } + + return nodep + + default: + ts.Fatalf("unknown command kind: %+v", cmd) + } + + return nil +} + +// `unquote` takes a slice of strings, resulting from splitting a string block by spaces, and +// processes them. The function handles quoted phrases and escape characters within these strings. +func unquote(args []string) ([]string, error) { + const quote = '"' + + parts := []string{} + var inQuote bool + + var part strings.Builder + for _, arg := range args { + var escaped bool + for _, c := range arg { + if escaped { + // If the character is meant to be escaped, it is processed with Unquote. + // We use `Unquote` here for two main reasons: + // 1. It will validate that the escape sequence is correct + // 2. It converts the escaped string to its corresponding raw character. + // For example, "\\t" becomes '\t'. 
+ uc, err := strconv.Unquote(`"\` + string(c) + `"`) + if err != nil { + return nil, fmt.Errorf("unhandled escape sequence `\\%c`: %w", c, err) + } + + part.WriteString(uc) + escaped = false + continue + } + + // If we are inside a quoted string and encounter an escape character, + // flag the next character as `escaped` + if inQuote && c == '\\' { + escaped = true + continue + } + + // Detect quote and toggle inQuote state + if c == quote { + inQuote = !inQuote + continue + } + + // Handle regular character + part.WriteRune(c) + } + + // If we're inside a quote, add a single space. + // It reflects one or multiple spaces between args in the original string. + if inQuote { + part.WriteRune(' ') + continue + } + + // Finalize part, add to parts, and reset for next part + parts = append(parts, part.String()) + part.Reset() + } + + // Check if a quote is left open + if inQuote { + return nil, errors.New("unfinished quote") + } + + return parts, nil +} + +func getNodeSID(ts *testscript.TestScript) string { + return ts.Getenv("SID") +} + +func tsValidateError(ts *testscript.TestScript, cmd string, neg bool, err error) { + if err != nil { + fmt.Fprintf(ts.Stderr(), "%q error: %+v\n", cmd, err) + if !neg { + ts.Fatalf("unexpected %q command failure: %s", cmd, err) + } + } else { + if neg { + ts.Fatalf("unexpected %q command success", cmd) + } + } +} + +type envSetter interface { + Setenv(key, value string) +} + +// createAccount creates a new account with the given name and adds it to the keybase. 
+func createAccount(env envSetter, kb keys.Keybase, accountName string) (gnoland.Balance, error) { + var balance gnoland.Balance + entropy, err := bip39.NewEntropy(256) + if err != nil { + return balance, fmt.Errorf("error creating entropy: %w", err) + } + + mnemonic, err := bip39.NewMnemonic(entropy) + if err != nil { + return balance, fmt.Errorf("error generating mnemonic: %w", err) + } + + var keyInfo keys.Info + if keyInfo, err = kb.CreateAccount(accountName, mnemonic, "", "", 0, 0); err != nil { + return balance, fmt.Errorf("unable to create account: %w", err) + } + + address := keyInfo.GetAddress() + env.Setenv("USER_SEED_"+accountName, mnemonic) + env.Setenv("USER_ADDR_"+accountName, address.String()) + + return gnoland.Balance{ + Address: address, + Amount: std.Coins{std.NewCoin(ugnot.Denom, 10e6)}, + }, nil +} + +// createAccountFrom creates a new account with the given metadata and adds it to the keybase. +func createAccountFrom(env envSetter, kb keys.Keybase, accountName, mnemonic string, account, index uint32) (gnoland.Balance, error) { + var balance gnoland.Balance + + // check if mnemonic is valid + if !bip39.IsMnemonicValid(mnemonic) { + return balance, fmt.Errorf("invalid mnemonic") + } + + keyInfo, err := kb.CreateAccount(accountName, mnemonic, "", "", account, index) + if err != nil { + return balance, fmt.Errorf("unable to create account: %w", err) + } + + address := keyInfo.GetAddress() + env.Setenv("USER_SEED_"+accountName, mnemonic) + env.Setenv("USER_ADDR_"+accountName, address.String()) + + return gnoland.Balance{ + Address: address, + Amount: std.Coins{std.NewCoin(ugnot.Denom, 10e6)}, + }, nil +} + +func buildGnoland(t *testing.T, rootdir string) string { + t.Helper() + + bin := filepath.Join(t.TempDir(), "gnoland-test") + + t.Log("building gnoland integration binary...") + + // Build a fresh gno binary in a temp directory + gnoArgsBuilder := []string{"build", "-o", bin} + + os.Executable() + + // Forward `-covermode` settings if set + if 
coverMode := testing.CoverMode(); coverMode != "" { + gnoArgsBuilder = append(gnoArgsBuilder, + "-covermode", coverMode, + ) + } + + // Append the path to the gno command source + gnoArgsBuilder = append(gnoArgsBuilder, filepath.Join(rootdir, + "gno.land", "pkg", "integration", "process")) + + t.Logf("build command: %s", strings.Join(gnoArgsBuilder, " ")) + + cmd := exec.Command("go", gnoArgsBuilder...) + + var buff bytes.Buffer + cmd.Stderr, cmd.Stdout = &buff, &buff + defer buff.Reset() + + if err := cmd.Run(); err != nil { + require.FailNowf(t, "unable to build binary", "%q\n%s", + err.Error(), buff.String()) + } + + return bin +} + +// GeneratePrivKeyFromMnemonic generates a crypto.PrivKey from a mnemonic. +func GeneratePrivKeyFromMnemonic(mnemonic, bip39Passphrase string, account, index uint32) (crypto.PrivKey, error) { + // Generate Seed from Mnemonic + seed, err := bip39.NewSeedWithErrorChecking(mnemonic, bip39Passphrase) + if err != nil { + return nil, fmt.Errorf("failed to generate seed: %w", err) + } + + // Derive Private Key + coinType := crypto.CoinType // ensure this is set correctly in your context + hdPath := hd.NewFundraiserParams(account, coinType, index) + masterPriv, ch := hd.ComputeMastersFromSeed(seed) + derivedPriv, err := hd.DerivePrivateKeyForPath(masterPriv, ch, hdPath.String()) + if err != nil { + return nil, fmt.Errorf("failed to derive private key: %w", err) + } + + // Convert to secp256k1 private key + privKey := secp256k1.PrivKeySecp256k1(derivedPriv) + return privKey, nil +} diff --git a/gno.land/pkg/integration/integration_test.go b/gno.land/pkg/integration/testscript_gnoland_test.go similarity index 93% rename from gno.land/pkg/integration/integration_test.go rename to gno.land/pkg/integration/testscript_gnoland_test.go index 99a3e6c7eca..2c301064969 100644 --- a/gno.land/pkg/integration/integration_test.go +++ b/gno.land/pkg/integration/testscript_gnoland_test.go @@ -8,12 +8,6 @@ import ( "github.com/stretchr/testify/require" ) 
-func TestTestdata(t *testing.T) { - t.Parallel() - - RunGnolandTestscripts(t, "testdata") -} - func TestUnquote(t *testing.T) { t.Parallel() diff --git a/gno.land/pkg/integration/testing.go b/gno.land/pkg/integration/testscript_testing.go similarity index 95% rename from gno.land/pkg/integration/testing.go rename to gno.land/pkg/integration/testscript_testing.go index 0cd3152d888..9eed180dd8b 100644 --- a/gno.land/pkg/integration/testing.go +++ b/gno.land/pkg/integration/testscript_testing.go @@ -2,6 +2,7 @@ package integration import ( "errors" + "testing" "github.com/rogpeppe/go-internal/testscript" "github.com/stretchr/testify/assert" @@ -16,6 +17,7 @@ var errFailNow = errors.New("fail now!") //nolint:stylecheck var ( _ require.TestingT = (*testingTS)(nil) _ assert.TestingT = (*testingTS)(nil) + _ TestingTS = &testing.T{} ) type TestingTS = require.TestingT diff --git a/gnovm/cmd/gno/download_deps.go b/gnovm/cmd/gno/download_deps.go index 5a8c50be20b..4e638eb4970 100644 --- a/gnovm/cmd/gno/download_deps.go +++ b/gnovm/cmd/gno/download_deps.go @@ -25,10 +25,11 @@ func downloadDeps(io commands.IO, pkgDir string, gnoMod *gnomod.File, fetcher pk if err != nil { return fmt.Errorf("read package at %q: %w", pkgDir, err) } - imports, err := packages.Imports(pkg, nil) + importsMap, err := packages.Imports(pkg, nil) if err != nil { return fmt.Errorf("read imports at %q: %w", pkgDir, err) } + imports := importsMap.Merge(packages.FileKindPackageSource, packages.FileKindTest, packages.FileKindXTest) for _, pkgPath := range imports { resolved := gnoMod.Resolve(module.Version{Path: pkgPath.PkgPath}) diff --git a/gnovm/cmd/gno/testdata_test.go b/gnovm/cmd/gno/testdata_test.go index 6b1bbd1d459..c5cb0def04e 100644 --- a/gnovm/cmd/gno/testdata_test.go +++ b/gnovm/cmd/gno/testdata_test.go @@ -3,7 +3,6 @@ package main import ( "os" "path/filepath" - "strconv" "testing" "github.com/gnolang/gno/gnovm/pkg/integration" @@ -18,25 +17,23 @@ func Test_Scripts(t *testing.T) { testdirs, 
err := os.ReadDir(testdata) require.NoError(t, err) + homeDir, buildDir := t.TempDir(), t.TempDir() for _, dir := range testdirs { if !dir.IsDir() { continue } name := dir.Name() + t.Logf("testing: %s", name) t.Run(name, func(t *testing.T) { - updateScripts, _ := strconv.ParseBool(os.Getenv("UPDATE_SCRIPTS")) - p := testscript.Params{ - UpdateScripts: updateScripts, - Dir: filepath.Join(testdata, name), - } - + testdir := filepath.Join(testdata, name) + p := integration.NewTestingParams(t, testdir) if coverdir, ok := integration.ResolveCoverageDir(); ok { err := integration.SetupTestscriptsCoverage(&p, coverdir) require.NoError(t, err) } - err := integration.SetupGno(&p, t.TempDir()) + err := integration.SetupGno(&p, homeDir, buildDir) require.NoError(t, err) testscript.Run(t, p) diff --git a/gnovm/pkg/doc/dirs.go b/gnovm/pkg/doc/dirs.go index b287fd20708..4c481324e9a 100644 --- a/gnovm/pkg/doc/dirs.go +++ b/gnovm/pkg/doc/dirs.go @@ -105,11 +105,12 @@ func packageImportsRecursive(root string, pkgPath string) []string { pkg = &gnovm.MemPackage{} } - resRaw, err := packages.Imports(pkg, nil) + importsMap, err := packages.Imports(pkg, nil) if err != nil { // ignore packages with invalid imports - resRaw = nil + importsMap = nil } + resRaw := importsMap.Merge(packages.FileKindPackageSource, packages.FileKindTest, packages.FileKindXTest) res := make([]string, len(resRaw)) for idx, imp := range resRaw { res[idx] = imp.PkgPath diff --git a/gnovm/pkg/gnolang/bench_test.go b/gnovm/pkg/gnolang/bench_test.go new file mode 100644 index 00000000000..b638ab66cd0 --- /dev/null +++ b/gnovm/pkg/gnolang/bench_test.go @@ -0,0 +1,31 @@ +package gnolang + +import ( + "testing" +) + +var sink any = nil + +var pkgIDPaths = []string{ + "encoding/json", + "math/bits", + "github.com/gnolang/gno/gnovm/pkg/gnolang", + "a", + " ", + "", + "github.com/gnolang/gno/gnovm/pkg/gnolang/vendor/pkg/github.com/gnolang/vendored", +} + +func BenchmarkPkgIDFromPkgPath(b *testing.B) { + b.ReportAllocs() + 
for i := 0; i < b.N; i++ { + for _, path := range pkgIDPaths { + sink = PkgIDFromPkgPath(path) + } + } + + if sink == nil { + b.Fatal("Benchmark did not run!") + } + sink = nil +} diff --git a/gnovm/pkg/gnolang/fuzz_test.go b/gnovm/pkg/gnolang/fuzz_test.go new file mode 100644 index 00000000000..977c7453b90 --- /dev/null +++ b/gnovm/pkg/gnolang/fuzz_test.go @@ -0,0 +1,89 @@ +package gnolang + +import ( + "os" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/cockroachdb/apd/v3" +) + +func FuzzConvertUntypedBigdecToFloat(f *testing.F) { + // 1. Firstly add seeds. + seeds := []string{ + "-100000", + "100000", + "0", + } + + check := new(apd.Decimal) + for _, seed := range seeds { + if check.UnmarshalText([]byte(seed)) == nil { + f.Add(seed) + } + } + + f.Fuzz(func(t *testing.T, apdStr string) { + switch { + case strings.HasPrefix(apdStr, ".-"): + return + } + + v := new(apd.Decimal) + if err := v.UnmarshalText([]byte(apdStr)); err != nil { + return + } + if _, err := v.Float64(); err != nil { + return + } + + bd := BigdecValue{ + V: v, + } + dst := new(TypedValue) + typ := Float64Type + ConvertUntypedBigdecTo(dst, bd, typ) + }) +} + +func FuzzParseFile(f *testing.F) { + // 1. Add the corpra. + parseFileDir := filepath.Join("testdata", "corpra", "parsefile") + paths, err := filepath.Glob(filepath.Join(parseFileDir, "*.go")) + if err != nil { + f.Fatal(err) + } + + // Also load in files from gno/gnovm/tests/files + pc, curFile, _, _ := runtime.Caller(0) + curFileDir := filepath.Dir(curFile) + gnovmTestFilesDir, err := filepath.Abs(filepath.Join(curFileDir, "..", "..", "tests", "files")) + if err != nil { + _ = pc // To silence the arbitrary golangci linter. 
+ f.Fatal(err) + } + globGnoTestFiles := filepath.Join(gnovmTestFilesDir, "*.gno") + gnoTestFiles, err := filepath.Glob(globGnoTestFiles) + if err != nil { + f.Fatal(err) + } + if len(gnoTestFiles) == 0 { + f.Fatalf("no files found from globbing %q", globGnoTestFiles) + } + paths = append(paths, gnoTestFiles...) + + for _, path := range paths { + blob, err := os.ReadFile(path) + if err != nil { + f.Fatal(err) + } + f.Add(string(blob)) + } + + // 2. Now run the fuzzer. + f.Fuzz(func(t *testing.T, goFileContents string) { + _, _ = ParseFile("a.go", goFileContents) + }) +} diff --git a/gnovm/pkg/gnolang/gno_test.go b/gnovm/pkg/gnolang/gno_test.go index 89458667997..5a8c6faf315 100644 --- a/gnovm/pkg/gnolang/gno_test.go +++ b/gnovm/pkg/gnolang/gno_test.go @@ -36,6 +36,8 @@ func setupMachine(b *testing.B, numValues, numStmts, numExprs, numBlocks, numFra func BenchmarkStringLargeData(b *testing.B) { m := setupMachine(b, 10000, 5000, 5000, 2000, 3000, 1000) + b.ResetTimer() + b.ReportAllocs() for i := 0; i < b.N; i++ { _ = m.String() diff --git a/gnovm/pkg/gnolang/machine.go b/gnovm/pkg/gnolang/machine.go index 4480a89d16f..75d12ac5402 100644 --- a/gnovm/pkg/gnolang/machine.go +++ b/gnovm/pkg/gnolang/machine.go @@ -2117,6 +2117,10 @@ func (m *Machine) Printf(format string, args ...interface{}) { } func (m *Machine) String() string { + if m == nil { + return "Machine:nil" + } + // Calculate some reasonable total length to avoid reallocation // Assuming an average length of 32 characters per string var ( @@ -2131,25 +2135,26 @@ func (m *Machine) String() string { totalLength = vsLength + ssLength + xsLength + bsLength + obsLength + fsLength + exceptionsLength ) - var builder strings.Builder + var sb strings.Builder + builder := &sb // Pointer for use in fmt.Fprintf. 
builder.Grow(totalLength) - builder.WriteString(fmt.Sprintf("Machine:\n PreprocessorMode: %v\n Op: %v\n Values: (len: %d)\n", m.PreprocessorMode, m.Ops[:m.NumOps], m.NumValues)) + fmt.Fprintf(builder, "Machine:\n PreprocessorMode: %v\n Op: %v\n Values: (len: %d)\n", m.PreprocessorMode, m.Ops[:m.NumOps], m.NumValues) for i := m.NumValues - 1; i >= 0; i-- { - builder.WriteString(fmt.Sprintf(" #%d %v\n", i, m.Values[i])) + fmt.Fprintf(builder, " #%d %v\n", i, m.Values[i]) } builder.WriteString(" Exprs:\n") for i := len(m.Exprs) - 1; i >= 0; i-- { - builder.WriteString(fmt.Sprintf(" #%d %v\n", i, m.Exprs[i])) + fmt.Fprintf(builder, " #%d %v\n", i, m.Exprs[i]) } builder.WriteString(" Stmts:\n") for i := len(m.Stmts) - 1; i >= 0; i-- { - builder.WriteString(fmt.Sprintf(" #%d %v\n", i, m.Stmts[i])) + fmt.Fprintf(builder, " #%d %v\n", i, m.Stmts[i]) } builder.WriteString(" Blocks:\n") @@ -2166,17 +2171,17 @@ func (m *Machine) String() string { if pv, ok := b.Source.(*PackageNode); ok { // package blocks have too much, so just // print the pkgpath. 
- builder.WriteString(fmt.Sprintf(" %s(%d) %s\n", gens, gen, pv.PkgPath)) + fmt.Fprintf(builder, " %s(%d) %s\n", gens, gen, pv.PkgPath) } else { bsi := b.StringIndented(" ") - builder.WriteString(fmt.Sprintf(" %s(%d) %s\n", gens, gen, bsi)) + fmt.Fprintf(builder, " %s(%d) %s\n", gens, gen, bsi) if b.Source != nil { sb := b.GetSource(m.Store).GetStaticBlock().GetBlock() - builder.WriteString(fmt.Sprintf(" (s vals) %s(%d) %s\n", gens, gen, sb.StringIndented(" "))) + fmt.Fprintf(builder, " (s vals) %s(%d) %s\n", gens, gen, sb.StringIndented(" ")) sts := b.GetSource(m.Store).GetStaticBlock().Types - builder.WriteString(fmt.Sprintf(" (s typs) %s(%d) %s\n", gens, gen, sts)) + fmt.Fprintf(builder, " (s typs) %s(%d) %s\n", gens, gen, sts) } } @@ -2187,7 +2192,7 @@ func (m *Machine) String() string { case *Block: b = bp case RefValue: - builder.WriteString(fmt.Sprintf(" (block ref %v)\n", bp.ObjectID)) + fmt.Fprintf(builder, " (block ref %v)\n", bp.ObjectID) b = nil default: panic("should not happen") @@ -2206,12 +2211,12 @@ func (m *Machine) String() string { if _, ok := b.Source.(*PackageNode); ok { break // done, skip *PackageNode. 
} else { - builder.WriteString(fmt.Sprintf(" #%d %s\n", i, - b.StringIndented(" "))) + fmt.Fprintf(builder, " #%d %s\n", i, + b.StringIndented(" ")) if b.Source != nil { sb := b.GetSource(m.Store).GetStaticBlock().GetBlock() - builder.WriteString(fmt.Sprintf(" (static) #%d %s\n", i, - sb.StringIndented(" "))) + fmt.Fprintf(builder, " (static) #%d %s\n", i, + sb.StringIndented(" ")) } } } @@ -2219,17 +2224,17 @@ func (m *Machine) String() string { builder.WriteString(" Frames:\n") for i := len(m.Frames) - 1; i >= 0; i-- { - builder.WriteString(fmt.Sprintf(" #%d %s\n", i, m.Frames[i])) + fmt.Fprintf(builder, " #%d %s\n", i, m.Frames[i]) } if m.Realm != nil { - builder.WriteString(fmt.Sprintf(" Realm:\n %s\n", m.Realm.Path)) + fmt.Fprintf(builder, " Realm:\n %s\n", m.Realm.Path) } builder.WriteString(" Exceptions:\n") for _, ex := range m.Exceptions { - builder.WriteString(fmt.Sprintf(" %s\n", ex.Sprint(m))) + fmt.Fprintf(builder, " %s\n", ex.Sprint(m)) } return builder.String() diff --git a/gnovm/pkg/gnolang/machine_test.go b/gnovm/pkg/gnolang/machine_test.go index c3b2118f099..c2ab8ea12c5 100644 --- a/gnovm/pkg/gnolang/machine_test.go +++ b/gnovm/pkg/gnolang/machine_test.go @@ -9,6 +9,7 @@ import ( "github.com/gnolang/gno/tm2/pkg/store/dbadapter" "github.com/gnolang/gno/tm2/pkg/store/iavl" stypes "github.com/gnolang/gno/tm2/pkg/store/types" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" ) @@ -56,3 +57,61 @@ func TestRunMemPackageWithOverrides_revertToOld(t *testing.T) { assert.Equal(t, StringKind, v.T.Kind()) assert.Equal(t, StringValue("1"), v.V) } + +func TestMachineString(t *testing.T) { + cases := []struct { + name string + in *Machine + want string + }{ + { + "nil Machine", + nil, + "Machine:nil", + }, + { + "created with defaults", + NewMachineWithOptions(MachineOptions{}), + "Machine:\n PreprocessorMode: false\n Op: []\n Values: (len: 0)\n Exprs:\n Stmts:\n Blocks:\n Blocks (other):\n Frames:\n Exceptions:\n", + }, + { + "created with 
store and defaults", + func() *Machine { + db := memdb.NewMemDB() + baseStore := dbadapter.StoreConstructor(db, stypes.StoreOptions{}) + iavlStore := iavl.StoreConstructor(db, stypes.StoreOptions{}) + store := NewStore(nil, baseStore, iavlStore) + return NewMachine("std", store) + }(), + "Machine:\n PreprocessorMode: false\n Op: []\n Values: (len: 0)\n Exprs:\n Stmts:\n Blocks:\n Blocks (other):\n Frames:\n Exceptions:\n", + }, + { + "filled in", + func() *Machine { + db := memdb.NewMemDB() + baseStore := dbadapter.StoreConstructor(db, stypes.StoreOptions{}) + iavlStore := iavl.StoreConstructor(db, stypes.StoreOptions{}) + store := NewStore(nil, baseStore, iavlStore) + m := NewMachine("std", store) + m.PushOp(OpHalt) + m.PushExpr(&BasicLitExpr{ + Kind: INT, + Value: "100", + }) + m.Blocks = make([]*Block, 1, 1) + m.PushStmts(S(Call(X("Redecl"), 11))) + return m + }(), + "Machine:\n PreprocessorMode: false\n Op: [OpHalt]\n Values: (len: 0)\n Exprs:\n #0 100\n Stmts:\n #0 Redecl(11)\n Blocks:\n Blocks (other):\n Frames:\n Exceptions:\n", + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got := tt.in.String() + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Fatalf("Mismatch: got - want +\n%s", diff) + } + }) + } +} diff --git a/gnovm/pkg/gnolang/op_bench_test.go b/gnovm/pkg/gnolang/op_bench_test.go new file mode 100644 index 00000000000..5874f980285 --- /dev/null +++ b/gnovm/pkg/gnolang/op_bench_test.go @@ -0,0 +1,70 @@ +package gnolang + +import ( + "testing" + + "github.com/gnolang/gno/tm2/pkg/overflow" +) + +func BenchmarkOpAdd(b *testing.B) { + m := NewMachine("bench", nil) + x := TypedValue{T: IntType} + x.SetInt(4) + y := TypedValue{T: IntType} + y.SetInt(3) + + b.ResetTimer() + + for range b.N { + m.PushOp(OpHalt) + m.PushExpr(&BinaryExpr{}) + m.PushValue(x) + m.PushValue(y) + m.PushOp(OpAdd) + m.Run() + } +} + +//go:noinline +func AddNoOverflow(x, y int) int { return x + y } + +func BenchmarkAddNoOverflow(b *testing.B) { + x, 
y := 4, 3 + c := 0 + for range b.N { + c = AddNoOverflow(x, y) + } + if c != 7 { + b.Error("invalid result") + } +} + +func BenchmarkAddOverflow(b *testing.B) { + x, y := 4, 3 + c := 0 + for range b.N { + c = overflow.Addp(x, y) + } + if c != 7 { + b.Error("invalid result") + } +} + +func TestOpAdd1(t *testing.T) { + m := NewMachine("test", nil) + a := TypedValue{T: IntType} + a.SetInt(4) + b := TypedValue{T: IntType} + b.SetInt(3) + t.Log("a:", a, "b:", b) + + start := m.NumValues + m.PushOp(OpHalt) + m.PushExpr(&BinaryExpr{}) + m.PushValue(a) + m.PushValue(b) + m.PushOp(OpAdd) + m.Run() + res := m.ReapValues(start) + t.Log("res:", res) +} diff --git a/gnovm/pkg/gnolang/op_binary.go b/gnovm/pkg/gnolang/op_binary.go index 6d26fa7ce54..0f66da5e685 100644 --- a/gnovm/pkg/gnolang/op_binary.go +++ b/gnovm/pkg/gnolang/op_binary.go @@ -6,6 +6,7 @@ import ( "math/big" "github.com/cockroachdb/apd/v3" + "github.com/gnolang/gno/tm2/pkg/overflow" ) // ---------------------------------------- @@ -183,7 +184,9 @@ func (m *Machine) doOpAdd() { } // add rv to lv. - addAssign(m.Alloc, lv, rv) + if err := addAssign(m.Alloc, lv, rv); err != nil { + panic(err) + } } func (m *Machine) doOpSub() { @@ -197,7 +200,9 @@ func (m *Machine) doOpSub() { } // sub rv from lv. - subAssign(lv, rv) + if err := subAssign(lv, rv); err != nil { + panic(err) + } } func (m *Machine) doOpBor() { @@ -253,8 +258,7 @@ func (m *Machine) doOpQuo() { } // lv / rv - err := quoAssign(lv, rv) - if err != nil { + if err := quoAssign(lv, rv); err != nil { panic(err) } } @@ -270,8 +274,7 @@ func (m *Machine) doOpRem() { } // lv % rv - err := remAssign(lv, rv) - if err != nil { + if err := remAssign(lv, rv); err != nil { panic(err) } } @@ -683,23 +686,38 @@ func isGeq(lv, rv *TypedValue) bool { } } -// for doOpAdd and doOpAddAssign. -func addAssign(alloc *Allocator, lv, rv *TypedValue) { +// addAssign adds lv to rv and stores the result to lv. +// It returns an exception in case of overflow on signed integers. 
+// The assignement is performed even in case of exception. +func addAssign(alloc *Allocator, lv, rv *TypedValue) *Exception { // set the result in lv. // NOTE this block is replicated in op_assign.go + ok := true switch baseOf(lv.T) { case StringType, UntypedStringType: lv.V = alloc.NewString(lv.GetString() + rv.GetString()) + // Signed integers may overflow, which triggers an exception. case IntType: - lv.SetInt(lv.GetInt() + rv.GetInt()) + var r int + r, ok = overflow.Add(lv.GetInt(), rv.GetInt()) + lv.SetInt(r) case Int8Type: - lv.SetInt8(lv.GetInt8() + rv.GetInt8()) + var r int8 + r, ok = overflow.Add8(lv.GetInt8(), rv.GetInt8()) + lv.SetInt8(r) case Int16Type: - lv.SetInt16(lv.GetInt16() + rv.GetInt16()) + var r int16 + r, ok = overflow.Add16(lv.GetInt16(), rv.GetInt16()) + lv.SetInt16(r) case Int32Type, UntypedRuneType: - lv.SetInt32(lv.GetInt32() + rv.GetInt32()) + var r int32 + r, ok = overflow.Add32(lv.GetInt32(), rv.GetInt32()) + lv.SetInt32(r) case Int64Type: - lv.SetInt64(lv.GetInt64() + rv.GetInt64()) + var r int64 + r, ok = overflow.Add64(lv.GetInt64(), rv.GetInt64()) + lv.SetInt64(r) + // Unsigned integers do not overflow, they just wrap. case UintType: lv.SetUint(lv.GetUint() + rv.GetUint()) case Uint8Type: @@ -739,23 +757,42 @@ func addAssign(alloc *Allocator, lv, rv *TypedValue) { lv.T, )) } + if !ok { + return &Exception{Value: typedString("addition overflow")} + } + return nil } -// for doOpSub and doOpSubAssign. -func subAssign(lv, rv *TypedValue) { +// subAssign subtracts lv to rv and stores the result to lv. +// It returns an exception in case of overflow on signed integers. +// The subtraction is performed even in case of exception. +func subAssign(lv, rv *TypedValue) *Exception { // set the result in lv. // NOTE this block is replicated in op_assign.go + ok := true switch baseOf(lv.T) { + // Signed integers may overflow, which triggers an exception. 
case IntType: - lv.SetInt(lv.GetInt() - rv.GetInt()) + var r int + r, ok = overflow.Sub(lv.GetInt(), rv.GetInt()) + lv.SetInt(r) case Int8Type: - lv.SetInt8(lv.GetInt8() - rv.GetInt8()) + var r int8 + r, ok = overflow.Sub8(lv.GetInt8(), rv.GetInt8()) + lv.SetInt8(r) case Int16Type: - lv.SetInt16(lv.GetInt16() - rv.GetInt16()) + var r int16 + r, ok = overflow.Sub16(lv.GetInt16(), rv.GetInt16()) + lv.SetInt16(r) case Int32Type, UntypedRuneType: - lv.SetInt32(lv.GetInt32() - rv.GetInt32()) + var r int32 + r, ok = overflow.Sub32(lv.GetInt32(), rv.GetInt32()) + lv.SetInt32(r) case Int64Type: - lv.SetInt64(lv.GetInt64() - rv.GetInt64()) + var r int64 + r, ok = overflow.Sub64(lv.GetInt64(), rv.GetInt64()) + lv.SetInt64(r) + // Unsigned integers do not overflow, they just wrap. case UintType: lv.SetUint(lv.GetUint() - rv.GetUint()) case Uint8Type: @@ -795,23 +832,39 @@ func subAssign(lv, rv *TypedValue) { lv.T, )) } + if !ok { + return &Exception{Value: typedString("subtraction overflow")} + } + return nil } // for doOpMul and doOpMulAssign. -func mulAssign(lv, rv *TypedValue) { +func mulAssign(lv, rv *TypedValue) *Exception { // set the result in lv. // NOTE this block is replicated in op_assign.go + ok := true switch baseOf(lv.T) { + // Signed integers may overflow, which triggers a panic. 
case IntType: - lv.SetInt(lv.GetInt() * rv.GetInt()) + var r int + r, ok = overflow.Mul(lv.GetInt(), rv.GetInt()) + lv.SetInt(r) case Int8Type: - lv.SetInt8(lv.GetInt8() * rv.GetInt8()) + var r int8 + r, ok = overflow.Mul8(lv.GetInt8(), rv.GetInt8()) + lv.SetInt8(r) case Int16Type: - lv.SetInt16(lv.GetInt16() * rv.GetInt16()) + var r int16 + r, ok = overflow.Mul16(lv.GetInt16(), rv.GetInt16()) + lv.SetInt16(r) case Int32Type, UntypedRuneType: - lv.SetInt32(lv.GetInt32() * rv.GetInt32()) + var r int32 + r, ok = overflow.Mul32(lv.GetInt32(), rv.GetInt32()) + lv.SetInt32(r) case Int64Type: - lv.SetInt64(lv.GetInt64() * rv.GetInt64()) + var r int64 + r, ok = overflow.Mul64(lv.GetInt64(), rv.GetInt64()) + lv.SetInt64(r) case UintType: lv.SetUint(lv.GetUint() * rv.GetUint()) case Uint8Type: @@ -849,96 +902,105 @@ func mulAssign(lv, rv *TypedValue) { lv.T, )) } + if !ok { + return &Exception{Value: typedString("multiplication overflow")} + } + return nil } // for doOpQuo and doOpQuoAssign. func quoAssign(lv, rv *TypedValue) *Exception { - expt := &Exception{ - Value: typedString("division by zero"), - } - // set the result in lv. // NOTE this block is replicated in op_assign.go + ok := true switch baseOf(lv.T) { + // Signed integers may overflow or cause a division by 0, which triggers a panic. 
case IntType: - if rv.GetInt() == 0 { - return expt - } - lv.SetInt(lv.GetInt() / rv.GetInt()) + var q int + q, _, ok = overflow.Quotient(lv.GetInt(), rv.GetInt()) + lv.SetInt(q) case Int8Type: - if rv.GetInt8() == 0 { - return expt - } - lv.SetInt8(lv.GetInt8() / rv.GetInt8()) + var q int8 + q, _, ok = overflow.Quotient8(lv.GetInt8(), rv.GetInt8()) + lv.SetInt8(q) case Int16Type: - if rv.GetInt16() == 0 { - return expt - } - lv.SetInt16(lv.GetInt16() / rv.GetInt16()) + var q int16 + q, _, ok = overflow.Quotient16(lv.GetInt16(), rv.GetInt16()) + lv.SetInt16(q) case Int32Type, UntypedRuneType: - if rv.GetInt32() == 0 { - return expt - } - lv.SetInt32(lv.GetInt32() / rv.GetInt32()) + var q int32 + q, _, ok = overflow.Quotient32(lv.GetInt32(), rv.GetInt32()) + lv.SetInt32(q) case Int64Type: - if rv.GetInt64() == 0 { - return expt - } - lv.SetInt64(lv.GetInt64() / rv.GetInt64()) + var q int64 + q, _, ok = overflow.Quotient64(lv.GetInt64(), rv.GetInt64()) + lv.SetInt64(q) + // Unsigned integers do not cause overflow, but a division by 0 may still occur. 
case UintType: - if rv.GetUint() == 0 { - return expt + y := rv.GetUint() + ok = y != 0 + if ok { + lv.SetUint(lv.GetUint() / y) } - lv.SetUint(lv.GetUint() / rv.GetUint()) case Uint8Type: - if rv.GetUint8() == 0 { - return expt + y := rv.GetUint8() + ok = y != 0 + if ok { + lv.SetUint8(lv.GetUint8() / y) } - lv.SetUint8(lv.GetUint8() / rv.GetUint8()) case DataByteType: - if rv.GetUint8() == 0 { - return expt + y := rv.GetUint8() + ok = y != 0 + if ok { + lv.SetDataByte(lv.GetDataByte() / y) } - lv.SetDataByte(lv.GetDataByte() / rv.GetUint8()) case Uint16Type: - if rv.GetUint16() == 0 { - return expt + y := rv.GetUint16() + ok = y != 0 + if ok { + lv.SetUint16(lv.GetUint16() / y) } - lv.SetUint16(lv.GetUint16() / rv.GetUint16()) case Uint32Type: - if rv.GetUint32() == 0 { - return expt + y := rv.GetUint32() + ok = y != 0 + if ok { + lv.SetUint32(lv.GetUint32() / y) } - lv.SetUint32(lv.GetUint32() / rv.GetUint32()) case Uint64Type: - if rv.GetUint64() == 0 { - return expt + y := rv.GetUint64() + ok = y != 0 + if ok { + lv.SetUint64(lv.GetUint64() / y) } - lv.SetUint64(lv.GetUint64() / rv.GetUint64()) + // XXX Handling float overflows is more complex. case Float32Type: // NOTE: gno doesn't fuse *+. - if rv.GetFloat32() == 0 { - return expt + y := rv.GetFloat32() + ok = y != 0 + if ok { + lv.SetFloat32(lv.GetFloat32() / y) } - lv.SetFloat32(lv.GetFloat32() / rv.GetFloat32()) // XXX FOR DETERMINISM, PANIC IF NAN. case Float64Type: // NOTE: gno doesn't fuse *+. - if rv.GetFloat64() == 0 { - return expt + y := rv.GetFloat64() + ok = y != 0 + if ok { + lv.SetFloat64(lv.GetFloat64() / y) } - lv.SetFloat64(lv.GetFloat64() / rv.GetFloat64()) // XXX FOR DETERMINISM, PANIC IF NAN. 
case BigintType, UntypedBigintType: if rv.GetBigInt().Sign() == 0 { - return expt + ok = false + break } lb := lv.GetBigInt() lb = big.NewInt(0).Quo(lb, rv.GetBigInt()) lv.V = BigintValue{V: lb} case BigdecType, UntypedBigdecType: if rv.GetBigDec().Cmp(apd.New(0, 0)) == 0 { - return expt + ok = false + break } lb := lv.GetBigDec() rb := rv.GetBigDec() @@ -955,81 +1017,83 @@ func quoAssign(lv, rv *TypedValue) *Exception { )) } + if !ok { + return &Exception{Value: typedString("division by zero or overflow")} + } return nil } // for doOpRem and doOpRemAssign. func remAssign(lv, rv *TypedValue) *Exception { - expt := &Exception{ - Value: typedString("division by zero"), - } - // set the result in lv. // NOTE this block is replicated in op_assign.go + ok := true switch baseOf(lv.T) { + // Signed integers may overflow or cause a division by 0, which triggers a panic. case IntType: - if rv.GetInt() == 0 { - return expt - } - lv.SetInt(lv.GetInt() % rv.GetInt()) + var r int + _, r, ok = overflow.Quotient(lv.GetInt(), rv.GetInt()) + lv.SetInt(r) case Int8Type: - if rv.GetInt8() == 0 { - return expt - } - lv.SetInt8(lv.GetInt8() % rv.GetInt8()) + var r int8 + _, r, ok = overflow.Quotient8(lv.GetInt8(), rv.GetInt8()) + lv.SetInt8(r) case Int16Type: - if rv.GetInt16() == 0 { - return expt - } - lv.SetInt16(lv.GetInt16() % rv.GetInt16()) + var r int16 + _, r, ok = overflow.Quotient16(lv.GetInt16(), rv.GetInt16()) + lv.SetInt16(r) case Int32Type, UntypedRuneType: - if rv.GetInt32() == 0 { - return expt - } - lv.SetInt32(lv.GetInt32() % rv.GetInt32()) + var r int32 + _, r, ok = overflow.Quotient32(lv.GetInt32(), rv.GetInt32()) + lv.SetInt32(r) case Int64Type: - if rv.GetInt64() == 0 { - return expt - } - lv.SetInt64(lv.GetInt64() % rv.GetInt64()) + var r int64 + _, r, ok = overflow.Quotient64(lv.GetInt64(), rv.GetInt64()) + lv.SetInt64(r) + // Unsigned integers do not cause overflow, but a division by 0 may still occur. 
case UintType: - if rv.GetUint() == 0 { - return expt + y := rv.GetUint() + ok = y != 0 + if ok { + lv.SetUint(lv.GetUint() % y) } - lv.SetUint(lv.GetUint() % rv.GetUint()) case Uint8Type: - if rv.GetUint8() == 0 { - return expt + y := rv.GetUint8() + ok = y != 0 + if ok { + lv.SetUint8(lv.GetUint8() % y) } - lv.SetUint8(lv.GetUint8() % rv.GetUint8()) case DataByteType: - if rv.GetUint8() == 0 { - return expt + y := rv.GetUint8() + ok = y != 0 + if ok { + lv.SetDataByte(lv.GetDataByte() % y) } - lv.SetDataByte(lv.GetDataByte() % rv.GetUint8()) case Uint16Type: - if rv.GetUint16() == 0 { - return expt + y := rv.GetUint16() + ok = y != 0 + if ok { + lv.SetUint16(lv.GetUint16() % y) } - lv.SetUint16(lv.GetUint16() % rv.GetUint16()) case Uint32Type: - if rv.GetUint32() == 0 { - return expt + y := rv.GetUint32() + ok = y != 0 + if ok { + lv.SetUint32(lv.GetUint32() % y) } - lv.SetUint32(lv.GetUint32() % rv.GetUint32()) case Uint64Type: - if rv.GetUint64() == 0 { - return expt + y := rv.GetUint64() + ok = y != 0 + if ok { + lv.SetUint64(lv.GetUint64() % y) } - lv.SetUint64(lv.GetUint64() % rv.GetUint64()) case BigintType, UntypedBigintType: - if rv.GetBigInt().Sign() == 0 { - return expt + ok = rv.GetBigInt().Sign() != 0 + if ok { + lb := lv.GetBigInt() + lb = big.NewInt(0).Rem(lb, rv.GetBigInt()) + lv.V = BigintValue{V: lb} } - - lb := lv.GetBigInt() - lb = big.NewInt(0).Rem(lb, rv.GetBigInt()) - lv.V = BigintValue{V: lb} default: panic(fmt.Sprintf( "operators %% and %%= not defined for %s", @@ -1037,6 +1101,9 @@ func remAssign(lv, rv *TypedValue) *Exception { )) } + if !ok { + return &Exception{Value: typedString("division by zero or overflow")} + } return nil } diff --git a/gnovm/pkg/gnolang/op_inc_dec.go b/gnovm/pkg/gnolang/op_inc_dec.go index 7a8a885bcf0..1e68e195596 100644 --- a/gnovm/pkg/gnolang/op_inc_dec.go +++ b/gnovm/pkg/gnolang/op_inc_dec.go @@ -5,6 +5,7 @@ import ( "math/big" "github.com/cockroachdb/apd/v3" + "github.com/gnolang/gno/tm2/pkg/overflow" ) func 
(m *Machine) doOpInc() { @@ -31,16 +32,18 @@ func (m *Machine) doOpInc() { // because it could be a type alias // type num int switch baseOf(lv.T) { + // Signed integers may overflow, which triggers a panic. case IntType: - lv.SetInt(lv.GetInt() + 1) + lv.SetInt(overflow.Addp(lv.GetInt(), 1)) case Int8Type: - lv.SetInt8(lv.GetInt8() + 1) + lv.SetInt8(overflow.Add8p(lv.GetInt8(), 1)) case Int16Type: - lv.SetInt16(lv.GetInt16() + 1) + lv.SetInt16(overflow.Add16p(lv.GetInt16(), 1)) case Int32Type: - lv.SetInt32(lv.GetInt32() + 1) + lv.SetInt32(overflow.Add32p(lv.GetInt32(), 1)) case Int64Type: - lv.SetInt64(lv.GetInt64() + 1) + lv.SetInt64(overflow.Add64p(lv.GetInt64(), 1)) + // Unsigned integers do not overflow, they just wrap. case UintType: lv.SetUint(lv.GetUint() + 1) case Uint8Type: @@ -101,16 +104,18 @@ func (m *Machine) doOpDec() { } } switch baseOf(lv.T) { + // Signed integers may overflow, which triggers a panic. case IntType: - lv.SetInt(lv.GetInt() - 1) + lv.SetInt(overflow.Subp(lv.GetInt(), 1)) case Int8Type: - lv.SetInt8(lv.GetInt8() - 1) + lv.SetInt8(overflow.Sub8p(lv.GetInt8(), 1)) case Int16Type: - lv.SetInt16(lv.GetInt16() - 1) + lv.SetInt16(overflow.Sub16p(lv.GetInt16(), 1)) case Int32Type: - lv.SetInt32(lv.GetInt32() - 1) + lv.SetInt32(overflow.Sub32p(lv.GetInt32(), 1)) case Int64Type: - lv.SetInt64(lv.GetInt64() - 1) + lv.SetInt64(overflow.Sub64p(lv.GetInt64(), 1)) + // Unsigned integers do not overflow, they just wrap. case UintType: lv.SetUint(lv.GetUint() - 1) case Uint8Type: diff --git a/gnovm/pkg/gnolang/realm.go b/gnovm/pkg/gnolang/realm.go index d822eb290eb..04de760037a 100644 --- a/gnovm/pkg/gnolang/realm.go +++ b/gnovm/pkg/gnolang/realm.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strings" + "sync" bm "github.com/gnolang/gno/gnovm/pkg/benchops" ) @@ -71,8 +72,25 @@ func (pid PkgID) Bytes() []byte { return pid.Hashlet[:] } +var ( + pkgIDFromPkgPathCacheMu sync.Mutex // protects the shared cache. 
+ // TODO: later on switch this to an LRU if needed to ensure + // fixed memory caps. For now though it isn't a problem: + // https://github.com/gnolang/gno/pull/3424#issuecomment-2564571785 + pkgIDFromPkgPathCache = make(map[string]*PkgID, 100) +) + func PkgIDFromPkgPath(path string) PkgID { - return PkgID{HashBytes([]byte(path))} + pkgIDFromPkgPathCacheMu.Lock() + defer pkgIDFromPkgPathCacheMu.Unlock() + + pkgID, ok := pkgIDFromPkgPathCache[path] + if !ok { + pkgID = new(PkgID) + *pkgID = PkgID{HashBytes([]byte(path))} + pkgIDFromPkgPathCache[path] = pkgID + } + return *pkgID } // Returns the ObjectID of the PackageValue associated with path. diff --git a/gnovm/pkg/gnolang/testdata/corpra/parsefile/a.go b/gnovm/pkg/gnolang/testdata/corpra/parsefile/a.go new file mode 100644 index 00000000000..ae05a655fd7 --- /dev/null +++ b/gnovm/pkg/gnolang/testdata/corpra/parsefile/a.go @@ -0,0 +1,9 @@ +package main + +import + _ "math/big" +) + +func main() { + println("Foo") +} diff --git a/gnovm/pkg/gnolang/testdata/corpra/parsefile/b.go b/gnovm/pkg/gnolang/testdata/corpra/parsefile/b.go new file mode 100644 index 00000000000..07592ff47ed --- /dev/null +++ b/gnovm/pkg/gnolang/testdata/corpra/parsefile/b.go @@ -0,0 +1,16 @@ +package main + +import "crypto/rand" + +func init() { +} + +func init() { +} + +func init() { +} + +func it() { + _ = rand.Read +} diff --git a/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3013.go b/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3013.go new file mode 100644 index 00000000000..f664f68f1b6 --- /dev/null +++ b/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3013.go @@ -0,0 +1,22 @@ +package main + +import "testing" + +func TestDummy(t *testing.T) { + testTable := []struct { + name string + }{ + { + "one", + }, + { + "two", + }, + } + + for _, testCase := range testTable { + testCase := testCase + + println(testCase.name) + } +} diff --git a/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3014_redefine.go 
b/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3014_redefine.go new file mode 100644 index 00000000000..877b5fafc1d --- /dev/null +++ b/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3014_redefine.go @@ -0,0 +1,10 @@ +package main + +var ss = []int{1, 2, 3} + +func main() { + for _, s := range ss { + s := s + println(s) + } +} diff --git a/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3014_redefine2.go b/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3014_redefine2.go new file mode 100644 index 00000000000..450406a2202 --- /dev/null +++ b/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3014_redefine2.go @@ -0,0 +1,21 @@ +package main + +var testTable = []struct { + name string +}{ + { + "one", + }, + { + "two", + }, +} + +func main() { + + for _, testCase := range testTable { + testCase := testCase + + println(testCase.name) + } +} diff --git a/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3014_redefine3.go b/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3014_redefine3.go new file mode 100644 index 00000000000..74ed729fb28 --- /dev/null +++ b/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3014_redefine3.go @@ -0,0 +1,8 @@ +package main + +func main() { + for i := 0; i < 3; i++ { + i := i + println(i) + } +} diff --git a/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3014_redefine4.go b/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3014_redefine4.go new file mode 100644 index 00000000000..db27dcdc6bf --- /dev/null +++ b/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3014_redefine4.go @@ -0,0 +1,11 @@ +package main + +func main() { + a := 1 + b := 3 + println(a, b) // prints 1 3 + + // Re-declaration of 'a' is allowed because 'c' is a new variable + a, c := 2, 5 + println(a, c) // prints 2 5 +} diff --git a/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3014_redefine5.go b/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3014_redefine5.go new file mode 100644 index 00000000000..97ec50f3330 --- /dev/null +++ 
b/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3014_redefine5.go @@ -0,0 +1,13 @@ +package main + +func main() { + a := 1 + println(a) // prints 1 + + if true { + a := 2 // valid: new scope inside the if statement + println(a) // prints 2 + } + + println(a) // prints 1: outer variable is unchanged +} diff --git a/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3014_redefine6.go b/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3014_redefine6.go new file mode 100644 index 00000000000..49b871be9c0 --- /dev/null +++ b/gnovm/pkg/gnolang/testdata/corpra/parsefile/bug_3014_redefine6.go @@ -0,0 +1,6 @@ +package main + +func main() { + a, b := 1, 2 + a, b := 3, 4 +} diff --git a/gnovm/pkg/gnolang/testdata/fuzz/FuzzConvertUntypedBigdecToFloat/4be24841138e3224 b/gnovm/pkg/gnolang/testdata/fuzz/FuzzConvertUntypedBigdecToFloat/4be24841138e3224 new file mode 100644 index 00000000000..8894efa8d8e --- /dev/null +++ b/gnovm/pkg/gnolang/testdata/fuzz/FuzzConvertUntypedBigdecToFloat/4be24841138e3224 @@ -0,0 +1,2 @@ +go test fuzz v1 +string(".-700000000000000000000000000000000000000") diff --git a/gnovm/pkg/gnolang/testdata/fuzz/FuzzConvertUntypedBigdecToFloat/94196a9456e79dac b/gnovm/pkg/gnolang/testdata/fuzz/FuzzConvertUntypedBigdecToFloat/94196a9456e79dac new file mode 100644 index 00000000000..a317fbab107 --- /dev/null +++ b/gnovm/pkg/gnolang/testdata/fuzz/FuzzConvertUntypedBigdecToFloat/94196a9456e79dac @@ -0,0 +1,2 @@ +go test fuzz v1 +string("200000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000") diff --git a/gnovm/pkg/gnomod/pkg.go b/gnovm/pkg/gnomod/pkg.go index a0831d494b0..85f1d31442d 100644 --- a/gnovm/pkg/gnomod/pkg.go +++ b/gnovm/pkg/gnomod/pkg.go @@ -121,11 +121,13 @@ func ListPkgs(root string) (PkgList, error) { 
pkg = &gnovm.MemPackage{} } - importsRaw, err := packages.Imports(pkg, nil) + importsMap, err := packages.Imports(pkg, nil) if err != nil { // ignore imports on error - importsRaw = nil + importsMap = nil } + importsRaw := importsMap.Merge(packages.FileKindPackageSource, packages.FileKindTest, packages.FileKindXTest) + imports := make([]string, 0, len(importsRaw)) for _, imp := range importsRaw { // remove self and standard libraries from imports diff --git a/gnovm/pkg/integration/testscript.go b/gnovm/pkg/integration/testscript.go new file mode 100644 index 00000000000..5b7c5c81a44 --- /dev/null +++ b/gnovm/pkg/integration/testscript.go @@ -0,0 +1,36 @@ +package integration + +import ( + "os" + "strconv" + "testing" + + "github.com/rogpeppe/go-internal/testscript" +) + +// NewTestingParams setup and initialize base params for testing. +func NewTestingParams(t *testing.T, testdir string) testscript.Params { + t.Helper() + + var params testscript.Params + params.Dir = testdir + + params.UpdateScripts, _ = strconv.ParseBool(os.Getenv("UPDATE_SCRIPTS")) + params.TestWork, _ = strconv.ParseBool(os.Getenv("TESTWORK")) + if deadline, ok := t.Deadline(); ok && params.Deadline.IsZero() { + params.Deadline = deadline + } + + // Store the original setup scripts for potential wrapping + params.Setup = func(env *testscript.Env) error { + // Set the UPDATE_SCRIPTS environment variable + env.Setenv("UPDATE_SCRIPTS", strconv.FormatBool(params.UpdateScripts)) + + // Set the environment variable + env.Setenv("TESTWORK", strconv.FormatBool(params.TestWork)) + + return nil + } + + return params +} diff --git a/gnovm/pkg/integration/coverage.go b/gnovm/pkg/integration/testscript_coverage.go similarity index 100% rename from gnovm/pkg/integration/coverage.go rename to gnovm/pkg/integration/testscript_coverage.go diff --git a/gnovm/pkg/integration/gno.go b/gnovm/pkg/integration/testscript_gno.go similarity index 87% rename from gnovm/pkg/integration/gno.go rename to 
gnovm/pkg/integration/testscript_gno.go index a389b6a9b24..03a3fcd6056 100644 --- a/gnovm/pkg/integration/gno.go +++ b/gnovm/pkg/integration/testscript_gno.go @@ -17,7 +17,7 @@ import ( // If the `gno` binary doesn't exist, it's built using the `go build` command into the specified buildDir. // The function also include the `gno` command into `p.Cmds` to and wrap environment into p.Setup // to correctly set up the environment variables needed for the `gno` command. -func SetupGno(p *testscript.Params, buildDir string) error { +func SetupGno(p *testscript.Params, homeDir, buildDir string) error { // Try to fetch `GNOROOT` from the environment variables gnoroot := gnoenv.RootDir() @@ -62,18 +62,10 @@ func SetupGno(p *testscript.Params, buildDir string) error { // Set the GNOROOT environment variable env.Setenv("GNOROOT", gnoroot) - // Create a temporary home directory because certain commands require access to $HOME/.cache/go-build - home, err := os.MkdirTemp("", "gno") - if err != nil { - return fmt.Errorf("unable to create temporary home directory: %w", err) - } - env.Setenv("HOME", home) + env.Setenv("HOME", homeDir) // Avoids go command printing errors relating to lack of go.mod. 
env.Setenv("GO111MODULE", "off") - // Cleanup home folder - env.Defer(func() { os.RemoveAll(home) }) - return nil } diff --git a/gnovm/pkg/packages/filekind.go b/gnovm/pkg/packages/filekind.go new file mode 100644 index 00000000000..ed2ca84b7d0 --- /dev/null +++ b/gnovm/pkg/packages/filekind.go @@ -0,0 +1,56 @@ +package packages + +import ( + "fmt" + "go/parser" + "go/token" + "strings" +) + +// FileKind represent the category a gno source file falls in, can be one of: +// +// - [FileKindPackageSource] -> A *.gno file that will be included in the gnovm package +// +// - [FileKindTest] -> A *_test.gno file that will be used for testing +// +// - [FileKindXTest] -> A *_test.gno file with a package name ending in _test that will be used for blackbox testing +// +// - [FileKindFiletest] -> A *_filetest.gno file that will be used for snapshot testing +type FileKind uint + +const ( + FileKindUnknown FileKind = iota + FileKindPackageSource + FileKindTest + FileKindXTest + FileKindFiletest +) + +// GetFileKind analyzes a file's name and body to get it's [FileKind], fset is optional +func GetFileKind(filename string, body string, fset *token.FileSet) (FileKind, error) { + if !strings.HasSuffix(filename, ".gno") { + return FileKindUnknown, fmt.Errorf("%s:1:1: not a gno file", filename) + } + + if strings.HasSuffix(filename, "_filetest.gno") { + return FileKindFiletest, nil + } + + if !strings.HasSuffix(filename, "_test.gno") { + return FileKindPackageSource, nil + } + + if fset == nil { + fset = token.NewFileSet() + } + ast, err := parser.ParseFile(fset, filename, body, parser.PackageClauseOnly) + if err != nil { + return FileKindUnknown, err + } + packageName := ast.Name.Name + + if strings.HasSuffix(packageName, "_test") { + return FileKindXTest, nil + } + return FileKindTest, nil +} diff --git a/gnovm/pkg/packages/filekind_test.go b/gnovm/pkg/packages/filekind_test.go new file mode 100644 index 00000000000..bd06b49fb45 --- /dev/null +++ 
b/gnovm/pkg/packages/filekind_test.go @@ -0,0 +1,63 @@ +package packages + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetFileKind(t *testing.T) { + tcs := []struct { + name string + filename string + body string + fileKind FileKind + errorContains string + }{ + { + name: "compiled", + filename: "foo.gno", + fileKind: FileKindPackageSource, + }, + { + name: "test", + filename: "foo_test.gno", + body: "package foo", + fileKind: FileKindTest, + }, + { + name: "xtest", + filename: "foo_test.gno", + body: "package foo_test", + fileKind: FileKindXTest, + }, + { + name: "filetest", + filename: "foo_filetest.gno", + fileKind: FileKindFiletest, + }, + { + name: "err_badpkgclause", + filename: "foo_test.gno", + body: "pakage foo", + errorContains: "foo_test.gno:1:1: expected 'package', found pakage", + }, + { + name: "err_notgnofile", + filename: "foo.gno.bck", + errorContains: `foo.gno.bck:1:1: not a gno file`, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + out, err := GetFileKind(tc.filename, tc.body, nil) + if len(tc.errorContains) != 0 { + require.ErrorContains(t, err, tc.errorContains) + } else { + require.NoError(t, err) + } + require.Equal(t, tc.fileKind, out) + }) + } +} diff --git a/gnovm/pkg/packages/imports.go b/gnovm/pkg/packages/imports.go index 201965bc588..3bc60be6664 100644 --- a/gnovm/pkg/packages/imports.go +++ b/gnovm/pkg/packages/imports.go @@ -14,33 +14,40 @@ import ( // Imports returns the list of gno imports from a [gnovm.MemPackage]. // fset is optional. 
-func Imports(pkg *gnovm.MemPackage, fset *token.FileSet) ([]FileImport, error) { - allImports := make([]FileImport, 0, 16) - seen := make(map[string]struct{}, 16) +func Imports(pkg *gnovm.MemPackage, fset *token.FileSet) (ImportsMap, error) { + res := make(ImportsMap, 16) + seen := make(map[FileKind]map[string]struct{}, 16) + for _, file := range pkg.Files { if !strings.HasSuffix(file.Name, ".gno") { continue } - if strings.HasSuffix(file.Name, "_filetest.gno") { - continue + + fileKind, err := GetFileKind(file.Name, file.Body, fset) + if err != nil { + return nil, err } imports, err := FileImports(file.Name, file.Body, fset) if err != nil { return nil, err } for _, im := range imports { - if _, ok := seen[im.PkgPath]; ok { + if _, ok := seen[fileKind][im.PkgPath]; ok { continue } - allImports = append(allImports, im) - seen[im.PkgPath] = struct{}{} + res[fileKind] = append(res[fileKind], im) + if _, ok := seen[fileKind]; !ok { + seen[fileKind] = make(map[string]struct{}, 16) + } + seen[fileKind][im.PkgPath] = struct{}{} } } - sort.Slice(allImports, func(i, j int) bool { - return allImports[i].PkgPath < allImports[j].PkgPath - }) - return allImports, nil + for _, imports := range res { + sortImports(imports) + } + + return res, nil } // FileImport represents an import @@ -75,3 +82,31 @@ func FileImports(filename string, src string, fset *token.FileSet) ([]FileImport } return res, nil } + +type ImportsMap map[FileKind][]FileImport + +// Merge merges imports, it removes duplicates and sorts the result +func (imap ImportsMap) Merge(kinds ...FileKind) []FileImport { + res := make([]FileImport, 0, 16) + seen := make(map[string]struct{}, 16) + + for _, kind := range kinds { + for _, im := range imap[kind] { + if _, ok := seen[im.PkgPath]; ok { + continue + } + seen[im.PkgPath] = struct{}{} + + res = append(res, im) + } + } + + sortImports(res) + return res +} + +func sortImports(imports []FileImport) { + sort.Slice(imports, func(i, j int) bool { + return 
imports[i].PkgPath < imports[j].PkgPath + }) +} diff --git a/gnovm/pkg/packages/imports_test.go b/gnovm/pkg/packages/imports_test.go index 3750aa9108c..f9f58b967c8 100644 --- a/gnovm/pkg/packages/imports_test.go +++ b/gnovm/pkg/packages/imports_test.go @@ -58,13 +58,26 @@ func TestImports(t *testing.T) { ) `, }, + { + name: "file2_test.gno", + data: ` + package tmp_test + + import ( + "testing" + + "gno.land/p/demo/testpkg" + "gno.land/p/demo/xtestdep" + ) + `, + }, { name: "z_0_filetest.gno", data: ` package main import ( - "gno.land/p/demo/filetestpkg" + "gno.land/p/demo/filetestdep" ) `, }, @@ -95,17 +108,28 @@ func TestImports(t *testing.T) { }, } - // Expected list of imports + // Expected lists of imports // - ignore subdirs // - ignore duplicate - // - ignore *_filetest.gno // - should be sorted - expected := []string{ - "gno.land/p/demo/pkg1", - "gno.land/p/demo/pkg2", - "gno.land/p/demo/testpkg", - "std", - "testing", + expected := map[FileKind][]string{ + FileKindPackageSource: { + "gno.land/p/demo/pkg1", + "gno.land/p/demo/pkg2", + "std", + }, + FileKindTest: { + "gno.land/p/demo/testpkg", + "testing", + }, + FileKindXTest: { + "gno.land/p/demo/testpkg", + "gno.land/p/demo/xtestdep", + "testing", + }, + FileKindFiletest: { + "gno.land/p/demo/filetestdep", + }, } // Create subpkg dir @@ -120,12 +144,19 @@ func TestImports(t *testing.T) { pkg, err := gnolang.ReadMemPackage(tmpDir, "test") require.NoError(t, err) - imports, err := Imports(pkg, nil) + + importsMap, err := Imports(pkg, nil) require.NoError(t, err) - importsStrings := make([]string, len(imports)) - for idx, imp := range imports { - importsStrings[idx] = imp.PkgPath + + // ignore specs + got := map[FileKind][]string{} + for key, vals := range importsMap { + stringVals := make([]string, len(vals)) + for i, val := range vals { + stringVals[i] = val.PkgPath + } + got[key] = stringVals } - require.Equal(t, expected, importsStrings) + require.Equal(t, expected, got) } diff --git 
a/gnovm/pkg/test/imports.go b/gnovm/pkg/test/imports.go index 8b24fdeaa77..a8dd709e501 100644 --- a/gnovm/pkg/test/imports.go +++ b/gnovm/pkg/test/imports.go @@ -242,10 +242,11 @@ func LoadImports(store gno.Store, memPkg *gnovm.MemPackage) (err error) { }() fset := token.NewFileSet() - imports, err := packages.Imports(memPkg, fset) + importsMap, err := packages.Imports(memPkg, fset) if err != nil { return err } + imports := importsMap.Merge(packages.FileKindPackageSource, packages.FileKindTest, packages.FileKindXTest) for _, imp := range imports { if gno.IsRealmPath(imp.PkgPath) { // Don't eagerly load realms. diff --git a/gnovm/pkg/test/test.go b/gnovm/pkg/test/test.go index 80f56e66d2e..d06540761d7 100644 --- a/gnovm/pkg/test/test.go +++ b/gnovm/pkg/test/test.go @@ -315,7 +315,7 @@ func (opts *TestOptions) runTestFiles( // RunFiles doesn't do that currently) // - Wrap here. m = Machine(gs, opts.Output, memPkg.Path) - m.Alloc = alloc + m.Alloc = alloc.Reset() m.SetActivePackage(pv) testingpv := m.Store.GetPackage("testing", false) diff --git a/gnovm/pkg/transpiler/fuzz_test.go b/gnovm/pkg/transpiler/fuzz_test.go new file mode 100644 index 00000000000..d3ab26eea86 --- /dev/null +++ b/gnovm/pkg/transpiler/fuzz_test.go @@ -0,0 +1,76 @@ +package transpiler_test + +import ( + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/gnolang/gno/gnovm/pkg/gnoenv" + "github.com/gnolang/gno/gnovm/pkg/transpiler" +) + +func FuzzTranspiling(f *testing.F) { + if testing.Short() { + f.Skip("Running in -short mode") + } + + // 1. Derive the seeds from our seedGnoFiles. 
+ ffs := os.DirFS(filepath.Join(gnoenv.RootDir(), "examples")) + fs.WalkDir(ffs, ".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + panic(err) + } + if !strings.HasSuffix(path, ".gno") { + return nil + } + file, err := ffs.Open(path) + if err != nil { + panic(err) + } + blob, err := io.ReadAll(file) + file.Close() + if err != nil { + panic(err) + } + f.Add(blob) + return nil + }) + + // 2. Run the fuzzers. + f.Fuzz(func(t *testing.T, gnoSourceCode []byte) { + // 3. Add timings to ensure that if transpiling takes a long time + // to run, that we report this as problematic. + doneCh := make(chan bool, 1) + readyCh := make(chan bool) + go func() { + defer func() { + r := recover() + if r == nil { + return + } + + sr := fmt.Sprintf("%s", r) + if !strings.Contains(sr, "invalid line number ") { + panic(r) + } + }() + close(readyCh) + defer close(doneCh) + _, _ = transpiler.Transpile(string(gnoSourceCode), "gno", "in.gno") + doneCh <- true + }() + + <-readyCh + + select { + case <-time.After(2 * time.Second): + t.Fatalf("took more than 2 seconds to transpile\n\n%s", gnoSourceCode) + case <-doneCh: + } + }) +} diff --git a/gnovm/pkg/transpiler/testdata/fuzz/FuzzTranspiling/26ec2c0a22e242fc b/gnovm/pkg/transpiler/testdata/fuzz/FuzzTranspiling/26ec2c0a22e242fc new file mode 100644 index 00000000000..d82acb3eaa5 --- /dev/null +++ b/gnovm/pkg/transpiler/testdata/fuzz/FuzzTranspiling/26ec2c0a22e242fc @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("package\tA\nimport(\"\"//\"\n\"\"/***/)") diff --git a/gnovm/pkg/transpiler/testdata/fuzz/FuzzTranspiling/54f7287f473abfa5 b/gnovm/pkg/transpiler/testdata/fuzz/FuzzTranspiling/54f7287f473abfa5 new file mode 100644 index 00000000000..a3accb3dd8e --- /dev/null +++ b/gnovm/pkg/transpiler/testdata/fuzz/FuzzTranspiling/54f7287f473abfa5 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("package A\nimport(\"\"//\"\n\"\"/***/)") diff --git a/gnovm/pkg/transpiler/testdata/fuzz/FuzzTranspiling/703a1cacabb84c6d 
b/gnovm/pkg/transpiler/testdata/fuzz/FuzzTranspiling/703a1cacabb84c6d new file mode 100644 index 00000000000..c08c7cfe904 --- /dev/null +++ b/gnovm/pkg/transpiler/testdata/fuzz/FuzzTranspiling/703a1cacabb84c6d @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("package A\ncon\x12\n\xec|b\x80\xddQst(\n/*\n\n\n\n\n\n\nka\n*/A)") diff --git a/gnovm/pkg/transpiler/testdata/fuzz/FuzzTranspiling/bcd60839c81ca478 b/gnovm/pkg/transpiler/testdata/fuzz/FuzzTranspiling/bcd60839c81ca478 new file mode 100644 index 00000000000..df9d3b1b5b9 --- /dev/null +++ b/gnovm/pkg/transpiler/testdata/fuzz/FuzzTranspiling/bcd60839c81ca478 @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("//\"\npackage\tA\nimport(\"\"//\"\n\"\"/***/)") diff --git a/gnovm/pkg/transpiler/testdata/fuzz/FuzzTranspiling/bf4f7515b25bf71b b/gnovm/pkg/transpiler/testdata/fuzz/FuzzTranspiling/bf4f7515b25bf71b new file mode 100644 index 00000000000..cbc2bd9be5b --- /dev/null +++ b/gnovm/pkg/transpiler/testdata/fuzz/FuzzTranspiling/bf4f7515b25bf71b @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("//0000\x170000000000:0\npackage A") diff --git a/gnovm/stdlibs/encoding/base32/base32.gno b/gnovm/stdlibs/encoding/base32/base32.gno new file mode 100644 index 00000000000..eba0a1d539e --- /dev/null +++ b/gnovm/stdlibs/encoding/base32/base32.gno @@ -0,0 +1,590 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package base32 implements base32 encoding as specified by RFC 4648. +package base32 + +import ( + "io" + "strconv" +) + +/* + * Encodings + */ + +// An Encoding is a radix 32 encoding/decoding scheme, defined by a +// 32-character alphabet. The most common is the "base32" encoding +// introduced for SASL GSSAPI and standardized in RFC 4648. +// The alternate "base32hex" encoding is used in DNSSEC. 
+type Encoding struct { + encode [32]byte // mapping of symbol index to symbol byte value + decodeMap [256]uint8 // mapping of symbol byte value to symbol index + padChar rune +} + +const ( + StdPadding rune = '=' // Standard padding character + NoPadding rune = -1 // No padding +) + +const ( + decodeMapInitialize = "" + + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + invalidIndex = '\xff' +) + +// NewEncoding returns a new padded Encoding defined by the given alphabet, +// which must be a 32-byte string that contains unique byte values and +// does not contain the padding character or CR / LF ('\r', '\n'). +// The alphabet is treated as a sequence of byte values +// without any special treatment for multi-byte UTF-8. +// The resulting Encoding uses the default padding character ('='), +// which may be changed or disabled via [Encoding.WithPadding]. 
+func NewEncoding(encoder string) *Encoding { + if len(encoder) != 32 { + panic("encoding alphabet is not 32-bytes long") + } + + e := new(Encoding) + e.padChar = StdPadding + copy(e.encode[:], encoder) + copy(e.decodeMap[:], decodeMapInitialize) + + for i := 0; i < len(encoder); i++ { + // Note: While we document that the alphabet cannot contain + // the padding character, we do not enforce it since we do not know + // if the caller intends to switch the padding from StdPadding later. + switch { + case encoder[i] == '\n' || encoder[i] == '\r': + panic("encoding alphabet contains newline character") + case e.decodeMap[encoder[i]] != invalidIndex: + panic("encoding alphabet includes duplicate symbols") + } + e.decodeMap[encoder[i]] = uint8(i) + } + return e +} + +// StdEncoding is the standard base32 encoding, as defined in RFC 4648. +var StdEncoding = NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567") + +// HexEncoding is the “Extended Hex Alphabet” defined in RFC 4648. +// It is typically used in DNS. +var HexEncoding = NewEncoding("0123456789ABCDEFGHIJKLMNOPQRSTUV") + +// WithPadding creates a new encoding identical to enc except +// with a specified padding character, or NoPadding to disable padding. +// The padding character must not be '\r' or '\n', +// must not be contained in the encoding's alphabet, +// must not be negative, and must be a rune equal or below '\xff'. +// Padding characters above '\x7f' are encoded as their exact byte value +// rather than using the UTF-8 representation of the codepoint. 
+func (enc Encoding) WithPadding(padding rune) *Encoding { + switch { + case padding < NoPadding || padding == '\r' || padding == '\n' || padding > 0xff: + panic("invalid padding") + case padding != NoPadding && enc.decodeMap[byte(padding)] != invalidIndex: + panic("padding contained in alphabet") + } + enc.padChar = padding + return &enc +} + +/* + * Encoder + */ + +// Encode encodes src using the encoding enc, +// writing [Encoding.EncodedLen](len(src)) bytes to dst. +// +// The encoding pads the output to a multiple of 8 bytes, +// so Encode is not appropriate for use on individual blocks +// of a large data stream. Use [NewEncoder] instead. +func (enc *Encoding) Encode(dst, src []byte) { + if len(src) == 0 { + return + } + // enc is a pointer receiver, so the use of enc.encode within the hot + // loop below means a nil check at every operation. Lift that nil check + // outside of the loop to speed up the encoder. + _ = enc.encode + + di, si := 0, 0 + n := (len(src) / 5) * 5 + for si < n { + // Combining two 32 bit loads allows the same code to be used + // for 32 and 64 bit platforms. + hi := uint32(src[si+0])<<24 | uint32(src[si+1])<<16 | uint32(src[si+2])<<8 | uint32(src[si+3]) + lo := hi<<8 | uint32(src[si+4]) + + dst[di+0] = enc.encode[(hi>>27)&0x1F] + dst[di+1] = enc.encode[(hi>>22)&0x1F] + dst[di+2] = enc.encode[(hi>>17)&0x1F] + dst[di+3] = enc.encode[(hi>>12)&0x1F] + dst[di+4] = enc.encode[(hi>>7)&0x1F] + dst[di+5] = enc.encode[(hi>>2)&0x1F] + dst[di+6] = enc.encode[(lo>>5)&0x1F] + dst[di+7] = enc.encode[(lo)&0x1F] + + si += 5 + di += 8 + } + + // Add the remaining small block + remain := len(src) - si + if remain == 0 { + return + } + + // Encode the remaining bytes in reverse order. 
+ val := uint32(0) + switch remain { + case 4: + val |= uint32(src[si+3]) + dst[di+6] = enc.encode[val<<3&0x1F] + dst[di+5] = enc.encode[val>>2&0x1F] + fallthrough + case 3: + val |= uint32(src[si+2]) << 8 + dst[di+4] = enc.encode[val>>7&0x1F] + fallthrough + case 2: + val |= uint32(src[si+1]) << 16 + dst[di+3] = enc.encode[val>>12&0x1F] + dst[di+2] = enc.encode[val>>17&0x1F] + fallthrough + case 1: + val |= uint32(src[si+0]) << 24 + dst[di+1] = enc.encode[val>>22&0x1F] + dst[di+0] = enc.encode[val>>27&0x1F] + } + + // Pad the final quantum + if enc.padChar != NoPadding { + nPad := (remain * 8 / 5) + 1 + for i := nPad; i < 8; i++ { + dst[di+i] = byte(enc.padChar) + } + } +} + +// AppendEncode appends the base32 encoded src to dst +// and returns the extended buffer. +func (enc *Encoding) AppendEncode(dst, src []byte) []byte { + n := enc.EncodedLen(len(src)) + dst = slicesGrow(dst, n) + enc.Encode(dst[len(dst):][:n], src) + return dst[:len(dst)+n] +} + +// XXX: non-generic slices.Grow +func slicesGrow(s []byte, n int) []byte { + if n -= cap(s) - len(s); n > 0 { + s = append(s[:cap(s)], make([]byte, n)...)[:len(s)] + } + return s +} + +// EncodeToString returns the base32 encoding of src. +func (enc *Encoding) EncodeToString(src []byte) string { + buf := make([]byte, enc.EncodedLen(len(src))) + enc.Encode(buf, src) + return string(buf) +} + +type encoder struct { + err error + enc *Encoding + w io.Writer + buf [5]byte // buffered data waiting to be encoded + nbuf int // number of bytes in buf + out [1024]byte // output buffer +} + +func (e *encoder) Write(p []byte) (n int, err error) { + if e.err != nil { + return 0, e.err + } + + // Leading fringe. 
+ if e.nbuf > 0 { + var i int + for i = 0; i < len(p) && e.nbuf < 5; i++ { + e.buf[e.nbuf] = p[i] + e.nbuf++ + } + n += i + p = p[i:] + if e.nbuf < 5 { + return + } + e.enc.Encode(e.out[0:], e.buf[0:]) + if _, e.err = e.w.Write(e.out[0:8]); e.err != nil { + return n, e.err + } + e.nbuf = 0 + } + + // Large interior chunks. + for len(p) >= 5 { + nn := len(e.out) / 8 * 5 + if nn > len(p) { + nn = len(p) + nn -= nn % 5 + } + e.enc.Encode(e.out[0:], p[0:nn]) + if _, e.err = e.w.Write(e.out[0 : nn/5*8]); e.err != nil { + return n, e.err + } + n += nn + p = p[nn:] + } + + // Trailing fringe. + copy(e.buf[:], p) + e.nbuf = len(p) + n += len(p) + return +} + +// Close flushes any pending output from the encoder. +// It is an error to call Write after calling Close. +func (e *encoder) Close() error { + // If there's anything left in the buffer, flush it out + if e.err == nil && e.nbuf > 0 { + e.enc.Encode(e.out[0:], e.buf[0:e.nbuf]) + encodedLen := e.enc.EncodedLen(e.nbuf) + e.nbuf = 0 + _, e.err = e.w.Write(e.out[0:encodedLen]) + } + return e.err +} + +// NewEncoder returns a new base32 stream encoder. Data written to +// the returned writer will be encoded using enc and then written to w. +// Base32 encodings operate in 5-byte blocks; when finished +// writing, the caller must Close the returned encoder to flush any +// partially written blocks. +func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser { + return &encoder{enc: enc, w: w} +} + +// EncodedLen returns the length in bytes of the base32 encoding +// of an input buffer of length n. 
+func (enc *Encoding) EncodedLen(n int) int { + if enc.padChar == NoPadding { + return n/5*8 + (n%5*8+4)/5 + } + return (n + 4) / 5 * 8 +} + +/* + * Decoder + */ + +type CorruptInputError int64 + +func (e CorruptInputError) Error() string { + return "illegal base32 data at input byte " + strconv.FormatInt(int64(e), 10) +} + +// decode is like Decode but returns an additional 'end' value, which +// indicates if end-of-message padding was encountered and thus any +// additional data is an error. This method assumes that src has been +// stripped of all supported whitespace ('\r' and '\n'). +func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { + // Lift the nil check outside of the loop. + _ = enc.decodeMap + + dsti := 0 + olen := len(src) + + for len(src) > 0 && !end { + // Decode quantum using the base32 alphabet + var dbuf [8]byte + dlen := 8 + + for j := 0; j < 8; { + + if len(src) == 0 { + if enc.padChar != NoPadding { + // We have reached the end and are missing padding + return n, false, CorruptInputError(olen - len(src) - j) + } + // We have reached the end and are not expecting any padding + dlen, end = j, true + break + } + in := src[0] + src = src[1:] + if in == byte(enc.padChar) && j >= 2 && len(src) < 8 { + // We've reached the end and there's padding + if len(src)+j < 8-1 { + // not enough padding + return n, false, CorruptInputError(olen) + } + for k := 0; k < 8-1-j; k++ { + if len(src) > k && src[k] != byte(enc.padChar) { + // incorrect padding + return n, false, CorruptInputError(olen - len(src) + k - 1) + } + } + dlen, end = j, true + // 7, 5 and 2 are not valid padding lengths, and so 1, 3 and 6 are not + // valid dlen values. See RFC 4648 Section 6 "Base 32 Encoding" listing + // the five valid padding lengths, and Section 9 "Illustrations and + // Examples" for an illustration for how the 1st, 3rd and 6th base32 + // src bytes do not yield enough information to decode a dst byte. 
+ if dlen == 1 || dlen == 3 || dlen == 6 { + return n, false, CorruptInputError(olen - len(src) - 1) + } + break + } + dbuf[j] = enc.decodeMap[in] + if dbuf[j] == 0xFF { + return n, false, CorruptInputError(olen - len(src) - 1) + } + j++ + } + + // Pack 8x 5-bit source blocks into 5 byte destination + // quantum + switch dlen { + case 8: + dst[dsti+4] = dbuf[6]<<5 | dbuf[7] + n++ + fallthrough + case 7: + dst[dsti+3] = dbuf[4]<<7 | dbuf[5]<<2 | dbuf[6]>>3 + n++ + fallthrough + case 5: + dst[dsti+2] = dbuf[3]<<4 | dbuf[4]>>1 + n++ + fallthrough + case 4: + dst[dsti+1] = dbuf[1]<<6 | dbuf[2]<<1 | dbuf[3]>>4 + n++ + fallthrough + case 2: + dst[dsti+0] = dbuf[0]<<3 | dbuf[1]>>2 + n++ + } + dsti += 5 + } + return n, end, nil +} + +// Decode decodes src using the encoding enc. It writes at most +// [Encoding.DecodedLen](len(src)) bytes to dst and returns the number of bytes +// written. If src contains invalid base32 data, it will return the +// number of bytes successfully written and [CorruptInputError]. +// Newline characters (\r and \n) are ignored. +func (enc *Encoding) Decode(dst, src []byte) (n int, err error) { + buf := make([]byte, len(src)) + l := stripNewlines(buf, src) + n, _, err = enc.decode(dst, buf[:l]) + return +} + +// AppendDecode appends the base32 decoded src to dst +// and returns the extended buffer. +// If the input is malformed, it returns the partially decoded src and an error. +func (enc *Encoding) AppendDecode(dst, src []byte) ([]byte, error) { + // Compute the output size without padding to avoid over allocating. + n := len(src) + for n > 0 && rune(src[n-1]) == enc.padChar { + n-- + } + n = decodedLen(n, NoPadding) + + dst = slicesGrow(dst, n) + n, err := enc.Decode(dst[len(dst):][:n], src) + return dst[:len(dst)+n], err +} + +// DecodeString returns the bytes represented by the base32 string s. 
+func (enc *Encoding) DecodeString(s string) ([]byte, error) { + buf := []byte(s) + l := stripNewlines(buf, buf) + n, _, err := enc.decode(buf, buf[:l]) + return buf[:n], err +} + +type decoder struct { + err error + enc *Encoding + r io.Reader + end bool // saw end of message + buf [1024]byte // leftover input + nbuf int + out []byte // leftover decoded output + outbuf [1024 / 8 * 5]byte +} + +func readEncodedData(r io.Reader, buf []byte, min int, expectsPadding bool) (n int, err error) { + for n < min && err == nil { + var nn int + nn, err = r.Read(buf[n:]) + n += nn + } + // data was read, less than min bytes could be read + if n < min && n > 0 && err == io.EOF { + err = io.ErrUnexpectedEOF + } + // no data was read, the buffer already contains some data + // when padding is disabled this is not an error, as the message can be of + // any length + if expectsPadding && min < 8 && n == 0 && err == io.EOF { + err = io.ErrUnexpectedEOF + } + return +} + +func (d *decoder) Read(p []byte) (n int, err error) { + // Use leftover decoded output from last read. + if len(d.out) > 0 { + n = copy(p, d.out) + d.out = d.out[n:] + if len(d.out) == 0 { + return n, d.err + } + return n, nil + } + + if d.err != nil { + return 0, d.err + } + + // Read a chunk. + nn := (len(p) + 4) / 5 * 8 + if nn < 8 { + nn = 8 + } + if nn > len(d.buf) { + nn = len(d.buf) + } + + // Minimum amount of bytes that needs to be read each cycle + var min int + var expectsPadding bool + if d.enc.padChar == NoPadding { + min = 1 + expectsPadding = false + } else { + min = 8 - d.nbuf + expectsPadding = true + } + + nn, d.err = readEncodedData(d.r, d.buf[d.nbuf:nn], min, expectsPadding) + d.nbuf += nn + if d.nbuf < min { + return 0, d.err + } + if nn > 0 && d.end { + return 0, CorruptInputError(0) + } + + // Decode chunk into p, or d.out and then p if p is too small. 
+ var nr int + if d.enc.padChar == NoPadding { + nr = d.nbuf + } else { + nr = d.nbuf / 8 * 8 + } + nw := d.enc.DecodedLen(d.nbuf) + + if nw > len(p) { + nw, d.end, err = d.enc.decode(d.outbuf[0:], d.buf[0:nr]) + d.out = d.outbuf[0:nw] + n = copy(p, d.out) + d.out = d.out[n:] + } else { + n, d.end, err = d.enc.decode(p, d.buf[0:nr]) + } + d.nbuf -= nr + for i := 0; i < d.nbuf; i++ { + d.buf[i] = d.buf[i+nr] + } + + if err != nil && (d.err == nil || d.err == io.EOF) { + d.err = err + } + + if len(d.out) > 0 { + // We cannot return all the decoded bytes to the caller in this + // invocation of Read, so we return a nil error to ensure that Read + // will be called again. The error stored in d.err, if any, will be + // returned with the last set of decoded bytes. + return n, nil + } + + return n, d.err +} + +type newlineFilteringReader struct { + wrapped io.Reader +} + +// stripNewlines removes newline characters and returns the number +// of non-newline characters copied to dst. +func stripNewlines(dst, src []byte) int { + offset := 0 + for _, b := range src { + if b == '\r' || b == '\n' { + continue + } + dst[offset] = b + offset++ + } + return offset +} + +func (r *newlineFilteringReader) Read(p []byte) (int, error) { + n, err := r.wrapped.Read(p) + for n > 0 { + s := p[0:n] + offset := stripNewlines(s, s) + if err != nil || offset > 0 { + return offset, err + } + // Previous buffer entirely whitespace, read again + n, err = r.wrapped.Read(p) + } + return n, err +} + +// NewDecoder constructs a new base32 stream decoder. +func NewDecoder(enc *Encoding, r io.Reader) io.Reader { + return &decoder{enc: enc, r: &newlineFilteringReader{r}} +} + +// DecodedLen returns the maximum length in bytes of the decoded data +// corresponding to n bytes of base32-encoded data. 
+func (enc *Encoding) DecodedLen(n int) int { + return decodedLen(n, enc.padChar) +} + +func decodedLen(n int, padChar rune) int { + if padChar == NoPadding { + return n/8*5 + n%8*5/8 + } + return n / 8 * 5 +} diff --git a/gnovm/stdlibs/encoding/base32/base32_test.gno b/gnovm/stdlibs/encoding/base32/base32_test.gno new file mode 100644 index 00000000000..fa09c2cd2f3 --- /dev/null +++ b/gnovm/stdlibs/encoding/base32/base32_test.gno @@ -0,0 +1,913 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package base32 + +import ( + "bytes" + "errors" + "io" + "math" + "strconv" + "strings" + "testing" +) + +type testpair struct { + decoded, encoded string +} + +var pairs = []testpair{ + // RFC 4648 examples + {"", ""}, + {"f", "MY======"}, + {"fo", "MZXQ===="}, + {"foo", "MZXW6==="}, + {"foob", "MZXW6YQ="}, + {"fooba", "MZXW6YTB"}, + {"foobar", "MZXW6YTBOI======"}, + + // Wikipedia examples, converted to base32 + {"sure.", "ON2XEZJO"}, + {"sure", "ON2XEZI="}, + {"sur", "ON2XE==="}, + {"su", "ON2Q===="}, + {"leasure.", "NRSWC43VOJSS4==="}, + {"easure.", "MVQXG5LSMUXA===="}, + {"asure.", "MFZXK4TFFY======"}, + {"sure.", "ON2XEZJO"}, +} + +var bigtest = testpair{ + "Twas brillig, and the slithy toves", + "KR3WC4ZAMJZGS3DMNFTSYIDBNZSCA5DIMUQHG3DJORUHSIDUN53GK4Y=", +} + +func testEqual(t *testing.T, msg string, args ...interface{}) bool { + t.Helper() + if args[len(args)-2] != args[len(args)-1] { + t.Errorf(msg, args...) 
+ return false + } + return true +} + +func TestEncode(t *testing.T) { + for _, p := range pairs { + got := StdEncoding.EncodeToString([]byte(p.decoded)) + testEqual(t, "Encode(%q) = %q, want %q", p.decoded, got, p.encoded) + dst := StdEncoding.AppendEncode([]byte("lead"), []byte(p.decoded)) + testEqual(t, `AppendEncode("lead", %q) = %q, want %q`, p.decoded, string(dst), "lead"+p.encoded) + } +} + +func TestEncoder(t *testing.T) { + for _, p := range pairs { + bb := &strings.Builder{} + encoder := NewEncoder(StdEncoding, bb) + encoder.Write([]byte(p.decoded)) + encoder.Close() + testEqual(t, "Encode(%q) = %q, want %q", p.decoded, bb.String(), p.encoded) + } +} + +func TestEncoderBuffering(t *testing.T) { + input := []byte(bigtest.decoded) + for bs := 1; bs <= 12; bs++ { + bb := &strings.Builder{} + encoder := NewEncoder(StdEncoding, bb) + for pos := 0; pos < len(input); pos += bs { + end := pos + bs + if end > len(input) { + end = len(input) + } + n, err := encoder.Write(input[pos:end]) + testEqual(t, "Write(%q) gave error %v, want %v", input[pos:end], err, error(nil)) + testEqual(t, "Write(%q) gave length %v, want %v", input[pos:end], n, end-pos) + } + err := encoder.Close() + testEqual(t, "Close gave error %v, want %v", err, error(nil)) + testEqual(t, "Encoding/%d of %q = %q, want %q", bs, bigtest.decoded, bb.String(), bigtest.encoded) + } +} + +func TestDecoderBufferingWithPadding(t *testing.T) { + for bs := 0; bs <= 12; bs++ { + for _, s := range pairs { + decoder := NewDecoder(StdEncoding, strings.NewReader(s.encoded)) + buf := make([]byte, len(s.decoded)+bs) + + var n int + var err error + n, err = decoder.Read(buf) + + if err != nil && err != io.EOF { + t.Errorf("Read from %q at pos %d = %d, unexpected error %v", s.encoded, len(s.decoded), n, err) + } + testEqual(t, "Decoding/%d of %q = %q, want %q\n", bs, s.encoded, string(buf[:n]), s.decoded) + } + } +} + +func TestDecoderBufferingWithoutPadding(t *testing.T) { + for bs := 0; bs <= 12; bs++ { + for _, s := 
range pairs { + encoded := strings.TrimRight(s.encoded, "=") + decoder := NewDecoder(StdEncoding.WithPadding(NoPadding), strings.NewReader(encoded)) + buf := make([]byte, len(s.decoded)+bs) + + var n int + var err error + n, err = decoder.Read(buf) + + if err != nil && err != io.EOF { + t.Errorf("Read from %q at pos %d = %d, unexpected error %v", encoded, len(s.decoded), n, err) + } + testEqual(t, "Decoding/%d of %q = %q, want %q\n", bs, encoded, string(buf[:n]), s.decoded) + } + } +} + +func TestDecode(t *testing.T) { + for _, p := range pairs { + dbuf := make([]byte, StdEncoding.DecodedLen(len(p.encoded))) + count, end, err := StdEncoding.decode(dbuf, []byte(p.encoded)) + testEqual(t, "Decode(%q) = error %v, want %v", p.encoded, err, error(nil)) + testEqual(t, "Decode(%q) = length %v, want %v", p.encoded, count, len(p.decoded)) + if len(p.encoded) > 0 { + testEqual(t, "Decode(%q) = end %v, want %v", p.encoded, end, (p.encoded[len(p.encoded)-1] == '=')) + } + testEqual(t, "Decode(%q) = %q, want %q", p.encoded, string(dbuf[0:count]), p.decoded) + + dbuf, err = StdEncoding.DecodeString(p.encoded) + testEqual(t, "DecodeString(%q) = error %v, want %v", p.encoded, err, error(nil)) + testEqual(t, "DecodeString(%q) = %q, want %q", p.encoded, string(dbuf), p.decoded) + + dst, err := StdEncoding.AppendDecode([]byte("lead"), []byte(p.encoded)) + testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil)) + testEqual(t, `AppendDecode("lead", %q) = %q, want %q`, p.encoded, string(dst), "lead"+p.decoded) + + dst2, err := StdEncoding.AppendDecode(dst[:0:len(p.decoded)], []byte(p.encoded)) + testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil)) + testEqual(t, `AppendDecode("", %q) = %q, want %q`, p.encoded, string(dst2), p.decoded) + if len(dst) > 0 && len(dst2) > 0 && &dst[0] != &dst2[0] { + t.Errorf("unexpected capacity growth: got %d, want %d", cap(dst2), cap(dst)) + } + } +} + +func TestDecoder(t *testing.T) { + for _, p := 
range pairs { + decoder := NewDecoder(StdEncoding, strings.NewReader(p.encoded)) + dbuf := make([]byte, StdEncoding.DecodedLen(len(p.encoded))) + count, err := decoder.Read(dbuf) + if err != nil && err != io.EOF { + t.Fatal("Read failed", err) + } + testEqual(t, "Read from %q = length %v, want %v", p.encoded, count, len(p.decoded)) + testEqual(t, "Decoding of %q = %q, want %q", p.encoded, string(dbuf[0:count]), p.decoded) + if err != io.EOF { + _, err = decoder.Read(dbuf) + } + testEqual(t, "Read from %q = %v, want %v", p.encoded, err, io.EOF) + } +} + +type badReader struct { + data []byte + errs []error + called int + limit int +} + +// Populates p with data, returns a count of the bytes written and an +// error. The error returned is taken from badReader.errs, with each +// invocation of Read returning the next error in this slice, or io.EOF, +// if all errors from the slice have already been returned. The +// number of bytes returned is determined by the size of the input buffer +// the test passes to decoder.Read and will be a multiple of 8, unless +// badReader.limit is non zero. +func (b *badReader) Read(p []byte) (int, error) { + lim := len(p) + if b.limit != 0 && b.limit < lim { + lim = b.limit + } + if len(b.data) < lim { + lim = len(b.data) + } + for i := range p[:lim] { + p[i] = b.data[i] + } + b.data = b.data[lim:] + err := io.EOF + if b.called < len(b.errs) { + err = b.errs[b.called] + } + b.called++ + return lim, err +} + +// TestIssue20044 tests that decoder.Read behaves correctly when the caller +// supplied reader returns an error. +func TestIssue20044(t *testing.T) { + badErr := errors.New("bad reader error") + testCases := []struct { + r badReader + res string + err error + dbuflen int + }{ + // Check valid input data accompanied by an error is processed and the error is propagated. 
+ {r: badReader{data: []byte("MY======"), errs: []error{badErr}}, + res: "f", err: badErr}, + // Check a read error accompanied by input data consisting of newlines only is propagated. + {r: badReader{data: []byte("\n\n\n\n\n\n\n\n"), errs: []error{badErr, nil}}, + res: "", err: badErr}, + // Reader will be called twice. The first time it will return 8 newline characters. The + // second time valid base32 encoded data and an error. The data should be decoded + // correctly and the error should be propagated. + {r: badReader{data: []byte("\n\n\n\n\n\n\n\nMY======"), errs: []error{nil, badErr}}, + res: "f", err: badErr, dbuflen: 8}, + // Reader returns invalid input data (too short) and an error. Verify the reader + // error is returned. + {r: badReader{data: []byte("MY====="), errs: []error{badErr}}, + res: "", err: badErr}, + // Reader returns invalid input data (too short) but no error. Verify io.ErrUnexpectedEOF + // is returned. + {r: badReader{data: []byte("MY====="), errs: []error{nil}}, + res: "", err: io.ErrUnexpectedEOF}, + // Reader returns invalid input data and an error. Verify the reader and not the + // decoder error is returned. + {r: badReader{data: []byte("Ma======"), errs: []error{badErr}}, + res: "", err: badErr}, + // Reader returns valid data and io.EOF. Check data is decoded and io.EOF is propagated. + {r: badReader{data: []byte("MZXW6YTB"), errs: []error{io.EOF}}, + res: "fooba", err: io.EOF}, + // Check errors are properly reported when decoder.Read is called multiple times. + // decoder.Read will be called 8 times, badReader.Read will be called twice, returning + // valid data both times but an error on the second call. + {r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{nil, badErr}}, + res: "leasure.", err: badErr, dbuflen: 1}, + // Check io.EOF is properly reported when decoder.Read is called multiple times. 
+ // decoder.Read will be called 8 times, badReader.Read will be called twice, returning + // valid data both times but io.EOF on the second call. + {r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{nil, io.EOF}}, + res: "leasure.", err: io.EOF, dbuflen: 1}, + // The following two test cases check that errors are propagated correctly when more than + // 8 bytes are read at a time. + {r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{io.EOF}}, + res: "leasure.", err: io.EOF, dbuflen: 11}, + {r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{badErr}}, + res: "leasure.", err: badErr, dbuflen: 11}, + // Check that errors are correctly propagated when the reader returns valid bytes in + // groups that are not divisible by 8. The first read will return 11 bytes and no + // error. The second will return 7 and an error. The data should be decoded correctly + // and the error should be propagated. + {r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{nil, badErr}, limit: 11}, + res: "leasure.", err: badErr}, + } + + for _, tc := range testCases { + input := tc.r.data + decoder := NewDecoder(StdEncoding, &tc.r) + var dbuflen int + if tc.dbuflen > 0 { + dbuflen = tc.dbuflen + } else { + dbuflen = StdEncoding.DecodedLen(len(input)) + } + dbuf := make([]byte, dbuflen) + var err error + var res []byte + for err == nil { + var n int + n, err = decoder.Read(dbuf) + if n > 0 { + res = append(res, dbuf[:n]...) + } + } + + testEqual(t, "Decoding of %q = %q, want %q", string(input), string(res), tc.res) + testEqual(t, "Decoding of %q err = %v, expected %v", string(input), err, tc.err) + } +} + +// TestDecoderError verifies decode errors are propagated when there are no read +// errors. 
func TestDecoderError(t *testing.T) {
	// "MZXW6YTb": lowercase 'b' is not in the standard base32 alphabet, so the
	// decoder must report CorruptInputError regardless of the reader error.
	for _, readErr := range []error{io.EOF, nil} {
		input := "MZXW6YTb"
		dbuf := make([]byte, StdEncoding.DecodedLen(len(input)))
		br := badReader{data: []byte(input), errs: []error{readErr}}
		decoder := NewDecoder(StdEncoding, &br)
		n, err := decoder.Read(dbuf)
		testEqual(t, "Read after EOF, n = %d, expected %d", n, 0)
		// The decoder's own corruption error must win over the reader's error.
		if _, ok := err.(CorruptInputError); !ok {
			t.Errorf("Corrupt input error expected. Found %T", err)
		}
	}
}

// TestReaderEOF ensures decoder.Read behaves correctly when input data is
// exhausted.
func TestReaderEOF(t *testing.T) {
	// badReader returns all data on the first call; the second call yields
	// either nil or io.EOF — both must surface to the caller as io.EOF.
	for _, readErr := range []error{io.EOF, nil} {
		input := "MZXW6YTB"
		br := badReader{data: []byte(input), errs: []error{nil, readErr}}
		decoder := NewDecoder(StdEncoding, &br)
		dbuf := make([]byte, StdEncoding.DecodedLen(len(input)))
		n, err := decoder.Read(dbuf)
		testEqual(t, "Decoding of %q err = %v, expected %v", input, err, error(nil))
		// Every Read after exhaustion must keep returning (0, io.EOF).
		n, err = decoder.Read(dbuf)
		testEqual(t, "Read after EOF, n = %d, expected %d", n, 0)
		testEqual(t, "Read after EOF, err = %v, expected %v", err, io.EOF)
		n, err = decoder.Read(dbuf)
		testEqual(t, "Read after EOF, n = %d, expected %d", n, 0)
		testEqual(t, "Read after EOF, err = %v, expected %v", err, io.EOF)
	}
}

// TestDecoderBuffering reads bigtest through the decoder in chunks of every
// size from 1 to 12 bytes and checks the reassembled output.
func TestDecoderBuffering(t *testing.T) {
	for bs := 1; bs <= 12; bs++ {
		decoder := NewDecoder(StdEncoding, strings.NewReader(bigtest.encoded))
		buf := make([]byte, len(bigtest.decoded)+12)
		var total int
		var n int
		var err error
		for total = 0; total < len(bigtest.decoded) && err == nil; {
			n, err = decoder.Read(buf[total : total+bs])
			total += n
		}
		if err != nil && err != io.EOF {
			t.Errorf("Read from %q at pos %d = %d, unexpected error %v", bigtest.encoded, total, n, err)
		}
		testEqual(t, "Decoding/%d of %q = %q, want %q", bs, bigtest.encoded, string(buf[0:total]), bigtest.decoded)
	}
}

func TestDecodeCorrupt(t *testing.T) {
	testCases := []struct {
		input  string
		offset int
// -1 means no corruption. + }{ + {"", -1}, + {"!!!!", 0}, + {"x===", 0}, + {"AA=A====", 2}, + {"AAA=AAAA", 3}, + {"MMMMMMMMM", 8}, + {"MMMMMM", 0}, + {"A=", 1}, + {"AA=", 3}, + {"AA==", 4}, + {"AA===", 5}, + {"AAAA=", 5}, + {"AAAA==", 6}, + {"AAAAA=", 6}, + {"AAAAA==", 7}, + {"A=======", 1}, + {"AA======", -1}, + {"AAA=====", 3}, + {"AAAA====", -1}, + {"AAAAA===", -1}, + {"AAAAAA==", 6}, + {"AAAAAAA=", -1}, + {"AAAAAAAA", -1}, + } + for _, tc := range testCases { + dbuf := make([]byte, StdEncoding.DecodedLen(len(tc.input))) + _, err := StdEncoding.Decode(dbuf, []byte(tc.input)) + if tc.offset == -1 { + if err != nil { + t.Error("Decoder wrongly detected corruption in", tc.input) + } + continue + } + switch err := err.(type) { + case CorruptInputError: + testEqual(t, "Corruption in %q at offset %v, want %v", tc.input, int(err), tc.offset) + default: + t.Error("Decoder failed to detect corruption in", tc) + } + } +} + +func TestBig(t *testing.T) { + n := 3*1000 + 1 + raw := make([]byte, n) + const alpha = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + for i := 0; i < n; i++ { + raw[i] = alpha[i%len(alpha)] + } + encoded := new(bytes.Buffer) + w := NewEncoder(StdEncoding, encoded) + nn, err := w.Write(raw) + if nn != n || err != nil { + t.Fatalf("Encoder.Write(raw) = %d, %v want %d, nil", nn, err, n) + } + err = w.Close() + if err != nil { + t.Fatalf("Encoder.Close() = %v want nil", err) + } + decoded, err := io.ReadAll(NewDecoder(StdEncoding, encoded)) + if err != nil { + t.Fatalf("io.ReadAll(NewDecoder(...)): %v", err) + } + + if !bytes.Equal(raw, decoded) { + var i int + for i = 0; i < len(decoded) && i < len(raw); i++ { + if decoded[i] != raw[i] { + break + } + } + t.Errorf("Decode(Encode(%d-byte string)) failed at offset %d", n, i) + } +} + +func testStringEncoding(t *testing.T, expected string, examples []string) { + for _, e := range examples { + buf, err := StdEncoding.DecodeString(e) + if err != nil { + t.Errorf("Decode(%q) failed: %v", 
e, err) + continue + } + if s := string(buf); s != expected { + t.Errorf("Decode(%q) = %q, want %q", e, s, expected) + } + } +} + +func TestNewLineCharacters(t *testing.T) { + // Each of these should decode to the string "sure", without errors. + examples := []string{ + "ON2XEZI=", + "ON2XEZI=\r", + "ON2XEZI=\n", + "ON2XEZI=\r\n", + "ON2XEZ\r\nI=", + "ON2X\rEZ\nI=", + "ON2X\nEZ\rI=", + "ON2XEZ\nI=", + "ON2XEZI\n=", + } + testStringEncoding(t, "sure", examples) + + // Each of these should decode to the string "foobar", without errors. + examples = []string{ + "MZXW6YTBOI======", + "MZXW6YTBOI=\r\n=====", + } + testStringEncoding(t, "foobar", examples) +} + +func TestDecoderIssue4779(t *testing.T) { + encoded := `JRXXEZLNEBUXA43VNUQGI33MN5ZCA43JOQQGC3LFOQWCAY3PNZZWKY3UMV2HK4 +RAMFSGS4DJONUWG2LOM4QGK3DJOQWCA43FMQQGI3YKMVUXK43NN5SCA5DFNVYG64RANFXGG2LENFSH +K3TUEB2XIIDMMFRG64TFEBSXIIDEN5WG64TFEBWWCZ3OMEQGC3DJOF2WCLRAKV2CAZLONFWQUYLEEB +WWS3TJNUQHMZLONFQW2LBAOF2WS4ZANZXXG5DSOVSCAZLYMVZGG2LUMF2GS33OEB2WY3DBNVRW6IDM +MFRG64TJOMQG42LTNEQHK5AKMFWGS4LVNFYCAZLYEBSWCIDDN5WW233EN4QGG33OONSXC5LBOQXCAR +DVNFZSAYLVORSSA2LSOVZGKIDEN5WG64RANFXAU4TFOBZGK2DFNZSGK4TJOQQGS3RAOZXWY5LQORQX +IZJAOZSWY2LUEBSXG43FEBRWS3DMOVWSAZDPNRXXEZJAMV2SAZTVM5UWC5BANZ2WY3DBBJYGC4TJMF +2HK4ROEBCXQY3FOB2GK5LSEBZWS3TUEBXWGY3BMVRWC5BAMN2XA2LEMF2GC5BANZXW4IDQOJXWSZDF +NZ2CYIDTOVXHIIDJNYFGG5LMOBQSA4LVNEQG6ZTGNFRWSYJAMRSXGZLSOVXHIIDNN5WGY2LUEBQW42 +LNEBUWIIDFON2CA3DBMJXXE5LNFY== +====` + encodedShort := strings.ReplaceAll(encoded, "\n", "") + + dec := NewDecoder(StdEncoding, strings.NewReader(encoded)) + res1, err := io.ReadAll(dec) + if err != nil { + t.Errorf("ReadAll failed: %v", err) + } + + dec = NewDecoder(StdEncoding, strings.NewReader(encodedShort)) + var res2 []byte + res2, err = io.ReadAll(dec) + if err != nil { + t.Errorf("ReadAll failed: %v", err) + } + + if !bytes.Equal(res1, res2) { + t.Error("Decoded results not equal") + } +} + +func BenchmarkEncode(b *testing.B) { + data := make([]byte, 8192) + 
buf := make([]byte, StdEncoding.EncodedLen(len(data))) + b.SetBytes(int64(len(data))) + for i := 0; i < b.N; i++ { + StdEncoding.Encode(buf, data) + } +} + +func BenchmarkEncodeToString(b *testing.B) { + data := make([]byte, 8192) + b.SetBytes(int64(len(data))) + for i := 0; i < b.N; i++ { + StdEncoding.EncodeToString(data) + } +} + +func BenchmarkDecode(b *testing.B) { + data := make([]byte, StdEncoding.EncodedLen(8192)) + StdEncoding.Encode(data, make([]byte, 8192)) + buf := make([]byte, 8192) + b.SetBytes(int64(len(data))) + for i := 0; i < b.N; i++ { + StdEncoding.Decode(buf, data) + } +} + +func BenchmarkDecodeString(b *testing.B) { + data := StdEncoding.EncodeToString(make([]byte, 8192)) + b.SetBytes(int64(len(data))) + for i := 0; i < b.N; i++ { + StdEncoding.DecodeString(data) + } +} + +func TestWithCustomPadding(t *testing.T) { + for _, testcase := range pairs { + defaultPadding := StdEncoding.EncodeToString([]byte(testcase.decoded)) + customPadding := StdEncoding.WithPadding('@').EncodeToString([]byte(testcase.decoded)) + expected := strings.ReplaceAll(defaultPadding, "=", "@") + + if expected != customPadding { + t.Errorf("Expected custom %s, got %s", expected, customPadding) + } + if testcase.encoded != defaultPadding { + t.Errorf("Expected %s, got %s", testcase.encoded, defaultPadding) + } + } +} + +func TestWithoutPadding(t *testing.T) { + for _, testcase := range pairs { + defaultPadding := StdEncoding.EncodeToString([]byte(testcase.decoded)) + customPadding := StdEncoding.WithPadding(NoPadding).EncodeToString([]byte(testcase.decoded)) + expected := strings.TrimRight(defaultPadding, "=") + + if expected != customPadding { + t.Errorf("Expected custom %s, got %s", expected, customPadding) + } + if testcase.encoded != defaultPadding { + t.Errorf("Expected %s, got %s", testcase.encoded, defaultPadding) + } + } +} + +func TestDecodeWithPadding(t *testing.T) { + encodings := []*Encoding{ + StdEncoding, + StdEncoding.WithPadding('-'), + 
StdEncoding.WithPadding(NoPadding), + } + + for i, enc := range encodings { + for _, pair := range pairs { + + input := pair.decoded + encoded := enc.EncodeToString([]byte(input)) + + decoded, err := enc.DecodeString(encoded) + if err != nil { + t.Errorf("DecodeString Error for encoding %d (%q): %v", i, input, err) + } + + if input != string(decoded) { + t.Errorf("Unexpected result for encoding %d: got %q; want %q", i, decoded, input) + } + } + } +} + +func TestDecodeWithWrongPadding(t *testing.T) { + encoded := StdEncoding.EncodeToString([]byte("foobar")) + + _, err := StdEncoding.WithPadding('-').DecodeString(encoded) + if err == nil { + t.Error("expected error") + } + + _, err = StdEncoding.WithPadding(NoPadding).DecodeString(encoded) + if err == nil { + t.Error("expected error") + } +} + +func TestBufferedDecodingSameError(t *testing.T) { + testcases := []struct { + prefix string + chunkCombinations [][]string + expected error + }{ + // NBSWY3DPO5XXE3DE == helloworld + // Test with "ZZ" as extra input + {"helloworld", [][]string{ + {"NBSW", "Y3DP", "O5XX", "E3DE", "ZZ"}, + {"NBSWY3DPO5XXE3DE", "ZZ"}, + {"NBSWY3DPO5XXE3DEZZ"}, + {"NBS", "WY3", "DPO", "5XX", "E3D", "EZZ"}, + {"NBSWY3DPO5XXE3", "DEZZ"}, + }, io.ErrUnexpectedEOF}, + + // Test with "ZZY" as extra input + {"helloworld", [][]string{ + {"NBSW", "Y3DP", "O5XX", "E3DE", "ZZY"}, + {"NBSWY3DPO5XXE3DE", "ZZY"}, + {"NBSWY3DPO5XXE3DEZZY"}, + {"NBS", "WY3", "DPO", "5XX", "E3D", "EZZY"}, + {"NBSWY3DPO5XXE3", "DEZZY"}, + }, io.ErrUnexpectedEOF}, + + // Normal case, this is valid input + {"helloworld", [][]string{ + {"NBSW", "Y3DP", "O5XX", "E3DE"}, + {"NBSWY3DPO5XXE3DE"}, + {"NBS", "WY3", "DPO", "5XX", "E3D", "E"}, + {"NBSWY3DPO5XXE3", "DE"}, + }, nil}, + + // MZXW6YTB = fooba + {"fooba", [][]string{ + {"MZXW6YTBZZ"}, + {"MZXW6YTBZ", "Z"}, + {"MZXW6YTB", "ZZ"}, + {"MZXW6YT", "BZZ"}, + {"MZXW6Y", "TBZZ"}, + {"MZXW6Y", "TB", "ZZ"}, + {"MZXW6", "YTBZZ"}, + {"MZXW6", "YTB", "ZZ"}, + {"MZXW6", "YT", "BZZ"}, + }, 
io.ErrUnexpectedEOF}, + + // Normal case, this is valid input + {"fooba", [][]string{ + {"MZXW6YTB"}, + {"MZXW6YT", "B"}, + {"MZXW6Y", "TB"}, + {"MZXW6", "YTB"}, + {"MZXW6", "YT", "B"}, + {"MZXW", "6YTB"}, + {"MZXW", "6Y", "TB"}, + }, nil}, + } + + for _, testcase := range testcases { + for _, chunks := range testcase.chunkCombinations { + r := &chunkReader{chunks: chunks} + + decoder := NewDecoder(StdEncoding, r) + _, err := io.ReadAll(decoder) + + if err != testcase.expected { + t.Errorf("Expected %v, got %v; case %s %+v", testcase.expected, err, testcase.prefix, chunks) + } + } + } +} + +// XXX: +// chunkReader exists as an alternative to the original behaviour of these +// tests in go, which used io.Pipe and a separate goroutine. +type chunkReader struct { + chunks []string + nChunk int + read int +} + +func (c *chunkReader) Read(b []byte) (int, error) { + if c.nChunk >= len(c.chunks) { + return 0, io.EOF + } + chunk := c.chunks[c.nChunk] + read := copy(b, chunk[c.read:]) + c.read += read + if c.read == len(chunk) { + c.nChunk++ + c.read = 0 + } + return read, nil +} + +func TestBufferedDecodingPadding(t *testing.T) { + testcases := []struct { + chunks []string + expectedError string + }{ + {[]string{ + "I4======", + "==", + }, "unexpected EOF"}, + + {[]string{ + "I4======N4======", + }, "illegal base32 data at input byte 2"}, + + {[]string{ + "I4======", + "N4======", + }, "illegal base32 data at input byte 0"}, + + {[]string{ + "I4======", + "========", + }, "illegal base32 data at input byte 0"}, + + {[]string{ + "I4I4I4I4", + "I4======", + "I4======", + }, "illegal base32 data at input byte 0"}, + } + + for _, testcase := range testcases { + r := &chunkReader{chunks: testcase.chunks} + + decoder := NewDecoder(StdEncoding, r) + _, err := io.ReadAll(decoder) + + if err == nil && len(testcase.expectedError) != 0 { + t.Errorf("case %q: got nil error, want %v", testcase.chunks, testcase.expectedError) + } else if err.Error() != testcase.expectedError { + 
t.Errorf("case %q: got %v, want %v", testcase.chunks, err, testcase.expectedError) + } + } +} + +func TestEncodedLen(t *testing.T) { + var rawStdEncoding = StdEncoding.WithPadding(NoPadding) + type test struct { + enc *Encoding + n int + want int64 + } + tests := []test{ + {StdEncoding, 0, 0}, + {StdEncoding, 1, 8}, + {StdEncoding, 2, 8}, + {StdEncoding, 3, 8}, + {StdEncoding, 4, 8}, + {StdEncoding, 5, 8}, + {StdEncoding, 6, 16}, + {StdEncoding, 10, 16}, + {StdEncoding, 11, 24}, + {rawStdEncoding, 0, 0}, + {rawStdEncoding, 1, 2}, + {rawStdEncoding, 2, 4}, + {rawStdEncoding, 3, 5}, + {rawStdEncoding, 4, 7}, + {rawStdEncoding, 5, 8}, + {rawStdEncoding, 6, 10}, + {rawStdEncoding, 7, 12}, + {rawStdEncoding, 10, 16}, + {rawStdEncoding, 11, 18}, + } + // check overflow + switch strconv.IntSize { + case 32: + tests = append(tests, test{rawStdEncoding, (math.MaxInt-4)/8 + 1, 429496730}) + tests = append(tests, test{rawStdEncoding, math.MaxInt/8*5 + 4, math.MaxInt}) + case 64: + tests = append(tests, test{rawStdEncoding, (math.MaxInt-4)/8 + 1, 1844674407370955162}) + tests = append(tests, test{rawStdEncoding, math.MaxInt/8*5 + 4, math.MaxInt}) + } + for _, tt := range tests { + if got := tt.enc.EncodedLen(tt.n); int64(got) != tt.want { + t.Errorf("EncodedLen(%d): got %d, want %d", tt.n, got, tt.want) + } + } +} + +func TestDecodedLen(t *testing.T) { + var rawStdEncoding = StdEncoding.WithPadding(NoPadding) + type test struct { + enc *Encoding + n int + want int64 + } + tests := []test{ + {StdEncoding, 0, 0}, + {StdEncoding, 8, 5}, + {StdEncoding, 16, 10}, + {StdEncoding, 24, 15}, + {rawStdEncoding, 0, 0}, + {rawStdEncoding, 2, 1}, + {rawStdEncoding, 4, 2}, + {rawStdEncoding, 5, 3}, + {rawStdEncoding, 7, 4}, + {rawStdEncoding, 8, 5}, + {rawStdEncoding, 10, 6}, + {rawStdEncoding, 12, 7}, + {rawStdEncoding, 16, 10}, + {rawStdEncoding, 18, 11}, + } + // check overflow + switch strconv.IntSize { + case 32: + tests = append(tests, test{rawStdEncoding, math.MaxInt/5 + 1, 
268435456}) + tests = append(tests, test{rawStdEncoding, math.MaxInt, 1342177279}) + case 64: + tests = append(tests, test{rawStdEncoding, math.MaxInt/5 + 1, 1152921504606846976}) + tests = append(tests, test{rawStdEncoding, math.MaxInt, 5764607523034234879}) + } + for _, tt := range tests { + if got := tt.enc.DecodedLen(tt.n); int64(got) != tt.want { + t.Errorf("DecodedLen(%d): got %d, want %d", tt.n, got, tt.want) + } + } +} + +func TestWithoutPaddingClose(t *testing.T) { + encodings := []*Encoding{ + StdEncoding, + StdEncoding.WithPadding(NoPadding), + } + + for _, encoding := range encodings { + for _, testpair := range pairs { + + var buf strings.Builder + encoder := NewEncoder(encoding, &buf) + encoder.Write([]byte(testpair.decoded)) + encoder.Close() + + expected := testpair.encoded + if encoding.padChar == NoPadding { + expected = strings.ReplaceAll(expected, "=", "") + } + + res := buf.String() + + if res != expected { + t.Errorf("Expected %s got %s; padChar=%d", expected, res, encoding.padChar) + } + } + } +} + +func TestDecodeReadAll(t *testing.T) { + encodings := []*Encoding{ + StdEncoding, + StdEncoding.WithPadding(NoPadding), + } + + for _, pair := range pairs { + for encIndex, encoding := range encodings { + encoded := pair.encoded + if encoding.padChar == NoPadding { + encoded = strings.ReplaceAll(encoded, "=", "") + } + + decReader, err := io.ReadAll(NewDecoder(encoding, strings.NewReader(encoded))) + if err != nil { + t.Errorf("NewDecoder error: %v", err) + } + + if pair.decoded != string(decReader) { + t.Errorf("Expected %s got %s; Encoding %d", pair.decoded, decReader, encIndex) + } + } + } +} + +func TestDecodeSmallBuffer(t *testing.T) { + encodings := []*Encoding{ + StdEncoding, + StdEncoding.WithPadding(NoPadding), + } + + for bufferSize := 1; bufferSize < 200; bufferSize++ { + for _, pair := range pairs { + for encIndex, encoding := range encodings { + encoded := pair.encoded + if encoding.padChar == NoPadding { + encoded = 
strings.ReplaceAll(encoded, "=", "") + } + + decoder := NewDecoder(encoding, strings.NewReader(encoded)) + + var allRead []byte + + for { + buf := make([]byte, bufferSize) + n, err := decoder.Read(buf) + allRead = append(allRead, buf[0:n]...) + if err == io.EOF { + break + } + if err != nil { + t.Error(err) + } + } + + if pair.decoded != string(allRead) { + t.Errorf("Expected %s got %s; Encoding %d; bufferSize %d", pair.decoded, allRead, encIndex, bufferSize) + } + } + } + } +} diff --git a/gnovm/stdlibs/encoding/base32/example_test.gno b/gnovm/stdlibs/encoding/base32/example_test.gno new file mode 100644 index 00000000000..251624f0bd8 --- /dev/null +++ b/gnovm/stdlibs/encoding/base32/example_test.gno @@ -0,0 +1,68 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Keep in sync with ../base64/example_test.go. + +package base32_test + +import ( + "encoding/base32" + "fmt" + "os" +) + +func ExampleEncoding_EncodeToString() { + data := []byte("any + old & data") + str := base32.StdEncoding.EncodeToString(data) + fmt.Println(str) + // Output: + // MFXHSIBLEBXWYZBAEYQGIYLUME====== +} + +func ExampleEncoding_Encode() { + data := []byte("Hello, world!") + dst := make([]byte, base32.StdEncoding.EncodedLen(len(data))) + base32.StdEncoding.Encode(dst, data) + fmt.Println(string(dst)) + // Output: + // JBSWY3DPFQQHO33SNRSCC=== +} + +func ExampleEncoding_DecodeString() { + str := "ONXW2ZJAMRQXIYJAO5UXI2BAAAQGC3TEEDX3XPY=" + data, err := base32.StdEncoding.DecodeString(str) + if err != nil { + fmt.Println("error:", err) + return + } + fmt.Printf("%q\n", data) + // Output: + // "some data with \x00 and \ufeff" +} + +func ExampleEncoding_Decode() { + str := "JBSWY3DPFQQHO33SNRSCC===" + dst := make([]byte, base32.StdEncoding.DecodedLen(len(str))) + n, err := base32.StdEncoding.Decode(dst, []byte(str)) + if err != nil { + fmt.Println("decode error:", err) + return 
+ } + dst = dst[:n] + fmt.Printf("%q\n", dst) + // Output: + // "Hello, world!" +} + +func ExampleNewEncoder() { + input := []byte("foo\x00bar") + encoder := base32.NewEncoder(base32.StdEncoding, os.Stdout) + encoder.Write(input) + // Must close the encoder when finished to flush any partial blocks. + // If you comment out the following line, the last partial block "r" + // won't be encoded. + encoder.Close() + // Output: + // MZXW6ADCMFZA==== +} diff --git a/gnovm/stdlibs/encoding/binary/varint.gno b/gnovm/stdlibs/encoding/binary/varint.gno new file mode 100644 index 00000000000..7b14fb2b631 --- /dev/null +++ b/gnovm/stdlibs/encoding/binary/varint.gno @@ -0,0 +1,166 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package binary + +// This file implements "varint" encoding of 64-bit integers. +// The encoding is: +// - unsigned integers are serialized 7 bits at a time, starting with the +// least significant bits +// - the most significant bit (msb) in each output byte indicates if there +// is a continuation byte (msb = 1) +// - signed integers are mapped to unsigned integers using "zig-zag" +// encoding: Positive values x are written as 2*x + 0, negative values +// are written as 2*(^x) + 1; that is, negative numbers are complemented +// and whether to complement is encoded in bit 0. +// +// Design note: +// At most 10 bytes are needed for 64-bit values. The encoding could +// be more dense: a full 64-bit value needs an extra byte just to hold bit 63. +// Instead, the msb of the previous byte could be used to hold bit 63 since we +// know there can't be more than 64 bits. This is a trivial improvement and +// would reduce the maximum encoding length to 9 bytes. However, it breaks the +// invariant that the msb is always the "continuation bit" and thus makes the +// format incompatible with a varint encoding for larger numbers (say 128-bit). 
+ +import ( + "errors" + "io" +) + +// MaxVarintLenN is the maximum length of a varint-encoded N-bit integer. +const ( + MaxVarintLen16 = 3 + MaxVarintLen32 = 5 + MaxVarintLen64 = 10 +) + +// AppendUvarint appends the varint-encoded form of x, +// as generated by PutUvarint, to buf and returns the extended buffer. +func AppendUvarint(buf []byte, x uint64) []byte { + for x >= 0x80 { + buf = append(buf, byte(x)|0x80) + x >>= 7 + } + return append(buf, byte(x)) +} + +// PutUvarint encodes a uint64 into buf and returns the number of bytes written. +// If the buffer is too small, PutUvarint will panic. +func PutUvarint(buf []byte, x uint64) int { + i := 0 + for x >= 0x80 { + buf[i] = byte(x) | 0x80 + x >>= 7 + i++ + } + buf[i] = byte(x) + return i + 1 +} + +// Uvarint decodes a uint64 from buf and returns that value and the +// number of bytes read (> 0). If an error occurred, the value is 0 +// and the number of bytes n is <= 0 meaning: +// +// n == 0: buf too small +// n < 0: value larger than 64 bits (overflow) +// and -n is the number of bytes read +func Uvarint(buf []byte) (uint64, int) { + var x uint64 + var s uint + for i, b := range buf { + if i == MaxVarintLen64 { + // Catch byte reads past MaxVarintLen64. + // See issue https://golang.org/issues/41185 + return 0, -(i + 1) // overflow + } + if b < 0x80 { + if i == MaxVarintLen64-1 && b > 1 { + return 0, -(i + 1) // overflow + } + return x | uint64(b)< 0). If an error occurred, the value is 0 +// and the number of bytes n is <= 0 with the following meaning: +// +// n == 0: buf too small +// n < 0: value larger than 64 bits (overflow) +// and -n is the number of bytes read +func Varint(buf []byte) (int64, int) { + ux, n := Uvarint(buf) // ok to continue in presence of error + x := int64(ux >> 1) + if ux&1 != 0 { + x = ^x + } + return x, n +} + +var errOverflow = errors.New("binary: varint overflows a 64-bit integer") + +// ReadUvarint reads an encoded unsigned integer from r and returns it as a uint64. 
+// The error is EOF only if no bytes were read. +// If an EOF happens after reading some but not all the bytes, +// ReadUvarint returns io.ErrUnexpectedEOF. +func ReadUvarint(r io.ByteReader) (uint64, error) { + var x uint64 + var s uint + for i := 0; i < MaxVarintLen64; i++ { + b, err := r.ReadByte() + if err != nil { + if i > 0 && err == io.EOF { + err = io.ErrUnexpectedEOF + } + return x, err + } + if b < 0x80 { + if i == MaxVarintLen64-1 && b > 1 { + return x, errOverflow + } + return x | uint64(b)<> 1) + if ux&1 != 0 { + x = ^x + } + return x, err +} diff --git a/gnovm/stdlibs/encoding/binary/varint_test.gno b/gnovm/stdlibs/encoding/binary/varint_test.gno new file mode 100644 index 00000000000..274ac74d0dc --- /dev/null +++ b/gnovm/stdlibs/encoding/binary/varint_test.gno @@ -0,0 +1,245 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package binary + +import ( + "bytes" + "io" + "math" + "testing" +) + +func testConstant(t *testing.T, w uint, max int) { + buf := make([]byte, MaxVarintLen64) + n := PutUvarint(buf, 1< 0 { + wantErr = io.ErrUnexpectedEOF + } + if x != 0 || err != wantErr { + t.Errorf("ReadUvarint(%v): got x = %d, err = %s", buf, x, err) + } + } +} + +// Ensure that we catch overflows of bytes going past MaxVarintLen64. 
+// See issue https://golang.org/issues/41185 +func TestBufferTooBigWithOverflow(t *testing.T) { + tests := []struct { + in []byte + name string + wantN int + wantValue uint64 + }{ + { + name: "invalid: 1000 bytes", + in: func() []byte { + b := make([]byte, 1000) + for i := range b { + b[i] = 0xff + } + b[999] = 0 + return b + }(), + wantN: -11, + wantValue: 0, + }, + { + name: "valid: math.MaxUint64-40", + in: []byte{0xd7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01}, + wantValue: math.MaxUint64 - 40, + wantN: 10, + }, + { + name: "invalid: with more than MaxVarintLen64 bytes", + in: []byte{0xd7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01}, + wantN: -11, + wantValue: 0, + }, + { + name: "invalid: 10th byte", + in: []byte{0xd7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, + wantN: -10, + wantValue: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + value, n := Uvarint(tt.in) + if g, w := n, tt.wantN; g != w { + t.Errorf("bytes returned=%d, want=%d", g, w) + } + if g, w := value, tt.wantValue; g != w { + t.Errorf("value=%d, want=%d", g, w) + } + }) + } +} + +func testOverflow(t *testing.T, buf []byte, x0 uint64, n0 int, err0 error) { + x, n := Uvarint(buf) + if x != 0 || n != n0 { + t.Errorf("Uvarint(% X): got x = %d, n = %d; want 0, %d", buf, x, n, n0) + } + + r := bytes.NewReader(buf) + ln := r.Len() + x, err := ReadUvarint(r) + if x != x0 || err != err0 { + t.Errorf("ReadUvarint(%v): got x = %d, err = %s; want %d, %s", buf, x, err, x0, err0) + } + if read := ln - r.Len(); read > MaxVarintLen64 { + t.Errorf("ReadUvarint(%v): read more than MaxVarintLen64 bytes, got %d", buf, read) + } +} + +func TestOverflow(t *testing.T) { + testOverflow(t, []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x2}, 0, -10, errOverflow) + testOverflow(t, []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x1, 0, 0}, 0, -11, errOverflow) + testOverflow(t, []byte{0xFF, 0xFF, 0xFF, 
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, 1<<64-1, -11, errOverflow) // 11 bytes, should overflow +} + +func TestNonCanonicalZero(t *testing.T) { + buf := []byte{0x80, 0x80, 0x80, 0} + x, n := Uvarint(buf) + if x != 0 || n != 4 { + t.Errorf("Uvarint(%v): got x = %d, n = %d; want 0, 4", buf, x, n) + } +} + +func BenchmarkPutUvarint32(b *testing.B) { + buf := make([]byte, MaxVarintLen32) + b.SetBytes(4) + for i := 0; i < b.N; i++ { + for j := uint(0); j < MaxVarintLen32; j++ { + PutUvarint(buf, 1<<(j*7)) + } + } +} + +func BenchmarkPutUvarint64(b *testing.B) { + buf := make([]byte, MaxVarintLen64) + b.SetBytes(8) + for i := 0; i < b.N; i++ { + for j := uint(0); j < MaxVarintLen64; j++ { + PutUvarint(buf, 1<<(j*7)) + } + } +} diff --git a/gnovm/stdlibs/encoding/csv/reader.gno b/gnovm/stdlibs/encoding/csv/reader.gno new file mode 100644 index 00000000000..889b7ff5588 --- /dev/null +++ b/gnovm/stdlibs/encoding/csv/reader.gno @@ -0,0 +1,467 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package csv reads and writes comma-separated values (CSV) files. +// There are many kinds of CSV files; this package supports the format +// described in RFC 4180. +// +// A csv file contains zero or more records of one or more fields per record. +// Each record is separated by the newline character. The final record may +// optionally be followed by a newline character. +// +// field1,field2,field3 +// +// White space is considered part of a field. +// +// Carriage returns before newline characters are silently removed. +// +// Blank lines are ignored. A line with only whitespace characters (excluding +// the ending newline character) is not considered a blank line. +// +// Fields which start and stop with the quote character " are called +// quoted-fields. The beginning and ending quote are not part of the +// field. 
+// +// The source: +// +// normal string,"quoted-field" +// +// results in the fields +// +// {`normal string`, `quoted-field`} +// +// Within a quoted-field a quote character followed by a second quote +// character is considered a single quote. +// +// "the ""word"" is true","a ""quoted-field""" +// +// results in +// +// {`the "word" is true`, `a "quoted-field"`} +// +// Newlines and commas may be included in a quoted-field +// +// "Multi-line +// field","comma is ," +// +// results in +// +// {`Multi-line +// field`, `comma is ,`} +package csv + +import ( + "bufio" + "bytes" + "errors" + "io" + "strconv" + "unicode" + "unicode/utf8" +) + +// A ParseError is returned for parsing errors. +// Line and column numbers are 1-indexed. +type ParseError struct { + StartLine int // Line where the record starts + Line int // Line where the error occurred + Column int // Column (1-based byte index) where the error occurred + Err error // The actual error +} + +func (e *ParseError) Error() string { + if e.Err == ErrFieldCount { + return "record on line " + strconv.Itoa(e.Line) + ": " + e.Err.Error() + } + if e.StartLine != e.Line { + return "record on line " + strconv.Itoa(e.StartLine) + ": parse error on line " + strconv.Itoa(e.Line) + + ", column " + strconv.Itoa(e.Column) + ": " + e.Err.Error() + } + return "parse error on line " + strconv.Itoa(e.Line) + ", column " + strconv.Itoa(e.Column) + ": " + e.Err.Error() +} + +func (e *ParseError) Unwrap() error { return e.Err } + +// These are the errors that can be returned in [ParseError.Err]. +var ( + ErrBareQuote = errors.New("bare \" in non-quoted-field") + ErrQuote = errors.New("extraneous or missing \" in quoted-field") + ErrFieldCount = errors.New("wrong number of fields") + + // Deprecated: ErrTrailingComma is no longer used. 
+ ErrTrailingComma = errors.New("extra delimiter at end of line") +) + +var ErrInvalidDelim = errors.New("csv: invalid field or comment delimiter") + +func validDelim(r rune) bool { + return r != 0 && r != '"' && r != '\r' && r != '\n' && utf8.ValidRune(r) && r != utf8.RuneError +} + +// A Reader reads records from a CSV-encoded file. +// +// As returned by [NewReader], a Reader expects input conforming to RFC 4180. +// The exported fields can be changed to customize the details before the +// first call to [Reader.Read] or [Reader.ReadAll]. +// +// The Reader converts all \r\n sequences in its input to plain \n, +// including in multiline field values, so that the returned data does +// not depend on which line-ending convention an input file uses. +type Reader struct { + // Comma is the field delimiter. + // It is set to comma (',') by NewReader. + // Comma must be a valid rune and must not be \r, \n, + // or the Unicode replacement character (0xFFFD). + Comma rune + + // Comment, if not 0, is the comment character. Lines beginning with the + // Comment character without preceding whitespace are ignored. + // With leading whitespace the Comment character becomes part of the + // field, even if TrimLeadingSpace is true. + // Comment must be a valid rune and must not be \r, \n, + // or the Unicode replacement character (0xFFFD). + // It must also not be equal to Comma. + Comment rune + + // FieldsPerRecord is the number of expected fields per record. + // If FieldsPerRecord is positive, Read requires each record to + // have the given number of fields. If FieldsPerRecord is 0, Read sets it to + // the number of fields in the first record, so that future records must + // have the same field count. If FieldsPerRecord is negative, no check is + // made and records may have a variable number of fields. + FieldsPerRecord int + + // If LazyQuotes is true, a quote may appear in an unquoted field and a + // non-doubled quote may appear in a quoted field. 
+ LazyQuotes bool + + // If TrimLeadingSpace is true, leading white space in a field is ignored. + // This is done even if the field delimiter, Comma, is white space. + TrimLeadingSpace bool + + // ReuseRecord controls whether calls to Read may return a slice sharing + // the backing array of the previous call's returned slice for performance. + // By default, each call to Read returns newly allocated memory owned by the caller. + ReuseRecord bool + + // Deprecated: TrailingComma is no longer used. + TrailingComma bool + + r *bufio.Reader + + // numLine is the current line being read in the CSV file. + numLine int + + // offset is the input stream byte offset of the current reader position. + offset int64 + + // rawBuffer is a line buffer only used by the readLine method. + rawBuffer []byte + + // recordBuffer holds the unescaped fields, one after another. + // The fields can be accessed by using the indexes in fieldIndexes. + // E.g., For the row `a,"b","c""d",e`, recordBuffer will contain `abc"de` + // and fieldIndexes will contain the indexes [1, 2, 5, 6]. + recordBuffer []byte + + // fieldIndexes is an index of fields inside recordBuffer. + // The i'th field ends at offset fieldIndexes[i] in recordBuffer. + fieldIndexes []int + + // fieldPositions is an index of field positions for the + // last record returned by Read. + fieldPositions []position + + // lastRecord is a record cache and only used when ReuseRecord == true. + lastRecord []string +} + +// NewReader returns a new Reader that reads from r. +func NewReader(r io.Reader) *Reader { + return &Reader{ + Comma: ',', + r: bufio.NewReader(r), + } +} + +// Read reads one record (a slice of fields) from r. +// If the record has an unexpected number of fields, +// Read returns the record along with the error [ErrFieldCount]. +// If the record contains a field that cannot be parsed, +// Read returns a partial record along with the parse error. +// The partial record contains all fields read before the error. 
+// If there is no data left to be read, Read returns nil, [io.EOF]. +// If [Reader.ReuseRecord] is true, the returned slice may be shared +// between multiple calls to Read. +func (r *Reader) Read() (record []string, err error) { + if r.ReuseRecord { + record, err = r.readRecord(r.lastRecord) + r.lastRecord = record + } else { + record, err = r.readRecord(nil) + } + return record, err +} + +// FieldPos returns the line and column corresponding to +// the start of the field with the given index in the slice most recently +// returned by [Reader.Read]. Numbering of lines and columns starts at 1; +// columns are counted in bytes, not runes. +// +// If this is called with an out-of-bounds index, it panics. +func (r *Reader) FieldPos(field int) (line, column int) { + if field < 0 || field >= len(r.fieldPositions) { + panic("out of range index passed to FieldPos") + } + p := &r.fieldPositions[field] + return p.line, p.col +} + +// InputOffset returns the input stream byte offset of the current reader +// position. The offset gives the location of the end of the most recently +// read row and the beginning of the next row. +func (r *Reader) InputOffset() int64 { + return r.offset +} + +// pos holds the position of a field in the current line. +type position struct { + line, col int +} + +// ReadAll reads all the remaining records from r. +// Each record is a slice of fields. +// A successful call returns err == nil, not err == [io.EOF]. Because ReadAll is +// defined to read until EOF, it does not treat end of file as an error to be +// reported. +func (r *Reader) ReadAll() (records [][]string, err error) { + for { + record, err := r.readRecord(nil) + if err == io.EOF { + return records, nil + } + if err != nil { + return nil, err + } + records = append(records, record) + } +} + +// readLine reads the next line (with the trailing endline). +// If EOF is hit without a trailing endline, it will be omitted. +// If some bytes were read, then the error is never [io.EOF]. 
+// The result is only valid until the next call to readLine. +func (r *Reader) readLine() ([]byte, error) { + line, err := r.r.ReadSlice('\n') + if err == bufio.ErrBufferFull { + r.rawBuffer = append(r.rawBuffer[:0], line...) + for err == bufio.ErrBufferFull { + line, err = r.r.ReadSlice('\n') + r.rawBuffer = append(r.rawBuffer, line...) + } + line = r.rawBuffer + } + readSize := len(line) + if readSize > 0 && err == io.EOF { + err = nil + // For backwards compatibility, drop trailing \r before EOF. + if line[readSize-1] == '\r' { + line = line[:readSize-1] + } + } + r.numLine++ + r.offset += int64(readSize) + // Normalize \r\n to \n on all input lines. + if n := len(line); n >= 2 && line[n-2] == '\r' && line[n-1] == '\n' { + line[n-2] = '\n' + line = line[:n-1] + } + return line, err +} + +// lengthNL reports the number of bytes for the trailing \n. +func lengthNL(b []byte) int { + if len(b) > 0 && b[len(b)-1] == '\n' { + return 1 + } + return 0 +} + +// nextRune returns the next rune in b or utf8.RuneError. +func nextRune(b []byte) rune { + r, _ := utf8.DecodeRune(b) + return r +} + +func (r *Reader) readRecord(dst []string) ([]string, error) { + if r.Comma == r.Comment || !validDelim(r.Comma) || (r.Comment != 0 && !validDelim(r.Comment)) { + return nil, ErrInvalidDelim + } + + // Read line (automatically skipping past empty lines and any comments). + var line []byte + var errRead error + for errRead == nil { + line, errRead = r.readLine() + if r.Comment != 0 && nextRune(line) == r.Comment { + line = nil + continue // Skip comment lines + } + if errRead == nil && len(line) == lengthNL(line) { + line = nil + continue // Skip empty lines + } + break + } + if errRead == io.EOF { + return nil, errRead + } + + // Parse each field in the record. 
+ var err error + const quoteLen = len(`"`) + commaLen := utf8.RuneLen(r.Comma) + recLine := r.numLine // Starting line for record + r.recordBuffer = r.recordBuffer[:0] + r.fieldIndexes = r.fieldIndexes[:0] + r.fieldPositions = r.fieldPositions[:0] + pos := position{line: r.numLine, col: 1} +parseField: + for { + if r.TrimLeadingSpace { + i := bytes.IndexFunc(line, func(r rune) bool { + return !unicode.IsSpace(r) + }) + if i < 0 { + i = len(line) + pos.col -= lengthNL(line) + } + line = line[i:] + pos.col += i + } + if len(line) == 0 || line[0] != '"' { + // Non-quoted string field + i := bytes.IndexRune(line, r.Comma) + field := line + if i >= 0 { + field = field[:i] + } else { + field = field[:len(field)-lengthNL(field)] + } + // Check to make sure a quote does not appear in field. + if !r.LazyQuotes { + if j := bytes.IndexByte(field, '"'); j >= 0 { + col := pos.col + j + err = &ParseError{StartLine: recLine, Line: r.numLine, Column: col, Err: ErrBareQuote} + break parseField + } + } + r.recordBuffer = append(r.recordBuffer, field...) + r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer)) + r.fieldPositions = append(r.fieldPositions, pos) + if i >= 0 { + line = line[i+commaLen:] + pos.col += i + commaLen + continue parseField + } + break parseField + } else { + // Quoted string field + fieldPos := pos + line = line[quoteLen:] + pos.col += quoteLen + for { + i := bytes.IndexByte(line, '"') + if i >= 0 { + // Hit next quote. + r.recordBuffer = append(r.recordBuffer, line[:i]...) + line = line[i+quoteLen:] + pos.col += i + quoteLen + switch rn := nextRune(line); { + case rn == '"': + // `""` sequence (append quote). + r.recordBuffer = append(r.recordBuffer, '"') + line = line[quoteLen:] + pos.col += quoteLen + case rn == r.Comma: + // `",` sequence (end of field). 
+ line = line[commaLen:] + pos.col += commaLen + r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer)) + r.fieldPositions = append(r.fieldPositions, fieldPos) + continue parseField + case lengthNL(line) == len(line): + // `"\n` sequence (end of line). + r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer)) + r.fieldPositions = append(r.fieldPositions, fieldPos) + break parseField + case r.LazyQuotes: + // `"` sequence (bare quote). + r.recordBuffer = append(r.recordBuffer, '"') + default: + // `"*` sequence (invalid non-escaped quote). + err = &ParseError{StartLine: recLine, Line: r.numLine, Column: pos.col - quoteLen, Err: ErrQuote} + break parseField + } + } else if len(line) > 0 { + // Hit end of line (copy all data so far). + r.recordBuffer = append(r.recordBuffer, line...) + if errRead != nil { + break parseField + } + pos.col += len(line) + line, errRead = r.readLine() + if len(line) > 0 { + pos.line++ + pos.col = 1 + } + if errRead == io.EOF { + errRead = nil + } + } else { + // Abrupt end of file (EOF or error). + if !r.LazyQuotes && errRead == nil { + err = &ParseError{StartLine: recLine, Line: pos.line, Column: pos.col, Err: ErrQuote} + break parseField + } + r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer)) + r.fieldPositions = append(r.fieldPositions, fieldPos) + break parseField + } + } + } + } + if err == nil { + err = errRead + } + + // Create a single string and create slices out of it. + // This pins the memory of the fields together, but allocates once. + str := string(r.recordBuffer) // Convert to string once to batch allocations + dst = dst[:0] + if cap(dst) < len(r.fieldIndexes) { + dst = make([]string, len(r.fieldIndexes)) + } + dst = dst[:len(r.fieldIndexes)] + var preIdx int + for i, idx := range r.fieldIndexes { + dst[i] = str[preIdx:idx] + preIdx = idx + } + + // Check or update the expected fields per record. 
+ if r.FieldsPerRecord > 0 { + if len(dst) != r.FieldsPerRecord && err == nil { + err = &ParseError{ + StartLine: recLine, + Line: recLine, + Column: 1, + Err: ErrFieldCount, + } + } + } else if r.FieldsPerRecord == 0 { + r.FieldsPerRecord = len(dst) + } + return dst, err +} diff --git a/gnovm/stdlibs/encoding/csv/reader_test.gno b/gnovm/stdlibs/encoding/csv/reader_test.gno new file mode 100644 index 00000000000..fa4ecdce3ed --- /dev/null +++ b/gnovm/stdlibs/encoding/csv/reader_test.gno @@ -0,0 +1,706 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package csv + +import ( + "fmt" + "io" + + // "reflect" + "strings" + "testing" + "unicode/utf8" +) + +type readTest struct { + Name string + Input string + Output [][]string + Positions [][][2]int + Errors []error + + // These fields are copied into the Reader + Comma rune + Comment rune + UseFieldsPerRecord bool // false (default) means FieldsPerRecord is -1 + FieldsPerRecord int + LazyQuotes bool + TrimLeadingSpace bool + ReuseRecord bool +} + +// In these tests, the §, ¶ and ∑ characters in readTest.Input are used to denote +// the start of a field, a record boundary and the position of an error respectively. +// They are removed before parsing and are used to verify the position +// information reported by FieldPos. 
+ +var readTests = []readTest{{ + Name: "Simple", + Input: "§a,§b,§c\n", + Output: [][]string{{"a", "b", "c"}}, +}, { + Name: "CRLF", + Input: "§a,§b\r\n¶§c,§d\r\n", + Output: [][]string{{"a", "b"}, {"c", "d"}}, +}, { + Name: "BareCR", + Input: "§a,§b\rc,§d\r\n", + Output: [][]string{{"a", "b\rc", "d"}}, +}, { + Name: "RFC4180test", + Input: `§#field1,§field2,§field3 +¶§"aaa",§"bb +b",§"ccc" +¶§"a,a",§"b""bb",§"ccc" +¶§zzz,§yyy,§xxx +`, + Output: [][]string{ + {"#field1", "field2", "field3"}, + {"aaa", "bb\nb", "ccc"}, + {"a,a", `b"bb`, "ccc"}, + {"zzz", "yyy", "xxx"}, + }, + UseFieldsPerRecord: true, + FieldsPerRecord: 0, +}, { + Name: "NoEOLTest", + Input: "§a,§b,§c", + Output: [][]string{{"a", "b", "c"}}, +}, { + Name: "Semicolon", + Input: "§a;§b;§c\n", + Output: [][]string{{"a", "b", "c"}}, + Comma: ';', +}, { + Name: "MultiLine", + Input: `§"two +line",§"one line",§"three +line +field"`, + Output: [][]string{{"two\nline", "one line", "three\nline\nfield"}}, +}, { + Name: "BlankLine", + Input: "§a,§b,§c\n\n¶§d,§e,§f\n\n", + Output: [][]string{ + {"a", "b", "c"}, + {"d", "e", "f"}, + }, +}, { + Name: "BlankLineFieldCount", + Input: "§a,§b,§c\n\n¶§d,§e,§f\n\n", + Output: [][]string{ + {"a", "b", "c"}, + {"d", "e", "f"}, + }, + UseFieldsPerRecord: true, + FieldsPerRecord: 0, +}, { + Name: "TrimSpace", + Input: " §a, §b, §c\n", + Output: [][]string{{"a", "b", "c"}}, + TrimLeadingSpace: true, +}, { + Name: "LeadingSpace", + Input: "§ a,§ b,§ c\n", + Output: [][]string{{" a", " b", " c"}}, +}, { + Name: "Comment", + Input: "#1,2,3\n§a,§b,§c\n#comment", + Output: [][]string{{"a", "b", "c"}}, + Comment: '#', +}, { + Name: "NoComment", + Input: "§#1,§2,§3\n¶§a,§b,§c", + Output: [][]string{{"#1", "2", "3"}, {"a", "b", "c"}}, +}, { + Name: "LazyQuotes", + Input: `§a "word",§"1"2",§a",§"b`, + Output: [][]string{{`a "word"`, `1"2`, `a"`, `b`}}, + LazyQuotes: true, +}, { + Name: "BareQuotes", + Input: `§a "word",§"1"2",§a"`, + Output: [][]string{{`a "word"`, `1"2`, `a"`}}, 
+ LazyQuotes: true, +}, { + Name: "BareDoubleQuotes", + Input: `§a""b,§c`, + Output: [][]string{{`a""b`, `c`}}, + LazyQuotes: true, +}, { + Name: "BadDoubleQuotes", + Input: `§a∑""b,c`, + Errors: []error{&ParseError{Err: ErrBareQuote}}, +}, { + Name: "TrimQuote", + Input: ` §"a",§" b",§c`, + Output: [][]string{{"a", " b", "c"}}, + TrimLeadingSpace: true, +}, { + Name: "BadBareQuote", + Input: `§a ∑"word","b"`, + Errors: []error{&ParseError{Err: ErrBareQuote}}, +}, { + Name: "BadTrailingQuote", + Input: `§"a word",b∑"`, + Errors: []error{&ParseError{Err: ErrBareQuote}}, +}, { + Name: "ExtraneousQuote", + Input: `§"a ∑"word","b"`, + Errors: []error{&ParseError{Err: ErrQuote}}, +}, { + Name: "BadFieldCount", + Input: "§a,§b,§c\n¶∑§d,§e", + Errors: []error{nil, &ParseError{Err: ErrFieldCount}}, + Output: [][]string{{"a", "b", "c"}, {"d", "e"}}, + UseFieldsPerRecord: true, + FieldsPerRecord: 0, +}, { + Name: "BadFieldCountMultiple", + Input: "§a,§b,§c\n¶∑§d,§e\n¶∑§f", + Errors: []error{nil, &ParseError{Err: ErrFieldCount}, &ParseError{Err: ErrFieldCount}}, + Output: [][]string{{"a", "b", "c"}, {"d", "e"}, {"f"}}, + UseFieldsPerRecord: true, + FieldsPerRecord: 0, +}, { + Name: "BadFieldCount1", + Input: `§∑a,§b,§c`, + Errors: []error{&ParseError{Err: ErrFieldCount}}, + Output: [][]string{{"a", "b", "c"}}, + UseFieldsPerRecord: true, + FieldsPerRecord: 2, +}, { + Name: "FieldCount", + Input: "§a,§b,§c\n¶§d,§e", + Output: [][]string{{"a", "b", "c"}, {"d", "e"}}, +}, { + Name: "TrailingCommaEOF", + Input: "§a,§b,§c,§", + Output: [][]string{{"a", "b", "c", ""}}, +}, { + Name: "TrailingCommaEOL", + Input: "§a,§b,§c,§\n", + Output: [][]string{{"a", "b", "c", ""}}, +}, { + Name: "TrailingCommaSpaceEOF", + Input: "§a,§b,§c, §", + Output: [][]string{{"a", "b", "c", ""}}, + TrimLeadingSpace: true, +}, { + Name: "TrailingCommaSpaceEOL", + Input: "§a,§b,§c, §\n", + Output: [][]string{{"a", "b", "c", ""}}, + TrimLeadingSpace: true, +}, { + Name: "TrailingCommaLine3", + Input: 
"§a,§b,§c\n¶§d,§e,§f\n¶§g,§hi,§", + Output: [][]string{{"a", "b", "c"}, {"d", "e", "f"}, {"g", "hi", ""}}, + TrimLeadingSpace: true, +}, { + Name: "NotTrailingComma3", + Input: "§a,§b,§c,§ \n", + Output: [][]string{{"a", "b", "c", " "}}, +}, { + Name: "CommaFieldTest", + Input: `§x,§y,§z,§w +¶§x,§y,§z,§ +¶§x,§y,§,§ +¶§x,§,§,§ +¶§,§,§,§ +¶§"x",§"y",§"z",§"w" +¶§"x",§"y",§"z",§"" +¶§"x",§"y",§"",§"" +¶§"x",§"",§"",§"" +¶§"",§"",§"",§"" +`, + Output: [][]string{ + {"x", "y", "z", "w"}, + {"x", "y", "z", ""}, + {"x", "y", "", ""}, + {"x", "", "", ""}, + {"", "", "", ""}, + {"x", "y", "z", "w"}, + {"x", "y", "z", ""}, + {"x", "y", "", ""}, + {"x", "", "", ""}, + {"", "", "", ""}, + }, +}, { + Name: "TrailingCommaIneffective1", + Input: "§a,§b,§\n¶§c,§d,§e", + Output: [][]string{ + {"a", "b", ""}, + {"c", "d", "e"}, + }, + TrimLeadingSpace: true, +}, { + Name: "ReadAllReuseRecord", + Input: "§a,§b\n¶§c,§d", + Output: [][]string{ + {"a", "b"}, + {"c", "d"}, + }, + ReuseRecord: true, +}, { + Name: "StartLine1", // Issue 19019 + Input: "§a,\"b\nc∑\"d,e", + Errors: []error{&ParseError{Err: ErrQuote}}, +}, { + Name: "StartLine2", + Input: "§a,§b\n¶§\"d\n\n,e∑", + Errors: []error{nil, &ParseError{Err: ErrQuote}}, + Output: [][]string{{"a", "b"}}, +}, { + Name: "CRLFInQuotedField", // Issue 21201 + Input: "§A,§\"Hello\r\nHi\",§B\r\n", + Output: [][]string{ + {"A", "Hello\nHi", "B"}, + }, +}, { + Name: "BinaryBlobField", // Issue 19410 + Input: "§x09\x41\xb4\x1c,§aktau", + Output: [][]string{{"x09A\xb4\x1c", "aktau"}}, +}, { + Name: "TrailingCR", + Input: "§field1,§field2\r", + Output: [][]string{{"field1", "field2"}}, +}, { + Name: "QuotedTrailingCR", + Input: "§\"field\"\r", + Output: [][]string{{"field"}}, +}, { + Name: "QuotedTrailingCRCR", + Input: "§\"field∑\"\r\r", + Errors: []error{&ParseError{Err: ErrQuote}}, +}, { + Name: "FieldCR", + Input: "§field\rfield\r", + Output: [][]string{{"field\rfield"}}, +}, { + Name: "FieldCRCR", + Input: "§field\r\rfield\r\r", + Output: 
[][]string{{"field\r\rfield\r"}}, +}, { + Name: "FieldCRCRLF", + Input: "§field\r\r\n¶§field\r\r\n", + Output: [][]string{{"field\r"}, {"field\r"}}, +}, { + Name: "FieldCRCRLFCR", + Input: "§field\r\r\n¶§\rfield\r\r\n\r", + Output: [][]string{{"field\r"}, {"\rfield\r"}}, +}, { + Name: "FieldCRCRLFCRCR", + Input: "§field\r\r\n¶§\r\rfield\r\r\n¶§\r\r", + Output: [][]string{{"field\r"}, {"\r\rfield\r"}, {"\r"}}, +}, { + Name: "MultiFieldCRCRLFCRCR", + Input: "§field1,§field2\r\r\n¶§\r\rfield1,§field2\r\r\n¶§\r\r,§", + Output: [][]string{ + {"field1", "field2\r"}, + {"\r\rfield1", "field2\r"}, + {"\r\r", ""}, + }, +}, { + Name: "NonASCIICommaAndComment", + Input: "§a£§b,c£ \t§d,e\n€ comment\n", + Output: [][]string{{"a", "b,c", "d,e"}}, + TrimLeadingSpace: true, + Comma: '£', + Comment: '€', +}, { + Name: "NonASCIICommaAndCommentWithQuotes", + Input: "§a€§\" b,\"€§ c\nλ comment\n", + Output: [][]string{{"a", " b,", " c"}}, + Comma: '€', + Comment: 'λ', +}, { + // λ and θ start with the same byte. + // This tests that the parser doesn't confuse such characters. + Name: "NonASCIICommaConfusion", + Input: "§\"abθcd\"λ§efθgh", + Output: [][]string{{"abθcd", "efθgh"}}, + Comma: 'λ', + Comment: '€', +}, { + Name: "NonASCIICommentConfusion", + Input: "§λ\n¶§λ\nθ\n¶§λ\n", + Output: [][]string{{"λ"}, {"λ"}, {"λ"}}, + Comment: 'θ', +}, { + Name: "QuotedFieldMultipleLF", + Input: "§\"\n\n\n\n\"", + Output: [][]string{{"\n\n\n\n"}}, +}, { + Name: "MultipleCRLF", + Input: "\r\n\r\n\r\n\r\n", +}, { + // The implementation may read each line in several chunks if it doesn't fit entirely + // in the read buffer, so we should test the code to handle that condition. 
+ Name: "HugeLines", + Input: strings.Repeat("#ignore\n", 10000) + "§" + strings.Repeat("@", 5000) + ",§" + strings.Repeat("*", 5000), + Output: [][]string{{strings.Repeat("@", 5000), strings.Repeat("*", 5000)}}, + Comment: '#', +}, { + Name: "QuoteWithTrailingCRLF", + Input: "§\"foo∑\"bar\"\r\n", + Errors: []error{&ParseError{Err: ErrQuote}}, +}, { + Name: "LazyQuoteWithTrailingCRLF", + Input: "§\"foo\"bar\"\r\n", + Output: [][]string{{`foo"bar`}}, + LazyQuotes: true, +}, { + Name: "DoubleQuoteWithTrailingCRLF", + Input: "§\"foo\"\"bar\"\r\n", + Output: [][]string{{`foo"bar`}}, +}, { + Name: "EvenQuotes", + Input: `§""""""""`, + Output: [][]string{{`"""`}}, +}, { + Name: "OddQuotes", + Input: `§"""""""∑`, + Errors: []error{&ParseError{Err: ErrQuote}}, +}, { + Name: "LazyOddQuotes", + Input: `§"""""""`, + Output: [][]string{{`"""`}}, + LazyQuotes: true, +}, { + Name: "BadComma1", + Comma: '\n', + Errors: []error{ErrInvalidDelim}, +}, { + Name: "BadComma2", + Comma: '\r', + Errors: []error{ErrInvalidDelim}, +}, { + Name: "BadComma3", + Comma: '"', + Errors: []error{ErrInvalidDelim}, +}, { + Name: "BadComma4", + Comma: utf8.RuneError, + Errors: []error{ErrInvalidDelim}, +}, { + Name: "BadComment1", + Comment: '\n', + Errors: []error{ErrInvalidDelim}, +}, { + Name: "BadComment2", + Comment: '\r', + Errors: []error{ErrInvalidDelim}, +}, { + Name: "BadComment3", + Comment: utf8.RuneError, + Errors: []error{ErrInvalidDelim}, +}, { + Name: "BadCommaComment", + Comma: 'X', + Comment: 'X', + Errors: []error{ErrInvalidDelim}, +}} + +func TestRead(t *testing.T) { + newReader := func(tt readTest) (*Reader, [][][2]int, map[int][2]int, string) { + positions, errPositions, input := makePositions(tt.Input) + r := NewReader(strings.NewReader(input)) + + if tt.Comma != 0 { + r.Comma = tt.Comma + } + r.Comment = tt.Comment + if tt.UseFieldsPerRecord { + r.FieldsPerRecord = tt.FieldsPerRecord + } else { + r.FieldsPerRecord = -1 + } + r.LazyQuotes = tt.LazyQuotes + r.TrimLeadingSpace = 
tt.TrimLeadingSpace + r.ReuseRecord = tt.ReuseRecord + return r, positions, errPositions, input + } + + for _, tt := range readTests { + t.Run(tt.Name, func(t *testing.T) { + r, positions, errPositions, input := newReader(tt) + out, err := r.ReadAll() + if wantErr := firstError(tt.Errors, positions, errPositions); wantErr != nil { + if !deepEqual(err, wantErr) { + t.Fatalf("ReadAll() error mismatch:\ngot %v (%#v)\nwant %v (%#v)", err, err, wantErr, wantErr) + } + if out != nil { + t.Fatalf("ReadAll() output:\ngot %q\nwant nil", out) + } + } else { + if err != nil { + t.Fatalf("unexpected Readall() error: %v", err) + } + if !deepEqual(out, tt.Output) { + t.Fatalf("ReadAll() output:\ngot %q\nwant %q", out, tt.Output) + } + } + + // Check input offset after call ReadAll() + inputByteSize := len(input) + inputOffset := r.InputOffset() + if err == nil && int64(inputByteSize) != inputOffset { + t.Errorf("wrong input offset after call ReadAll():\ngot: %d\nwant: %d\ninput: %s", inputOffset, inputByteSize, input) + } + + // Check field and error positions. + r, _, _, _ = newReader(tt) + for recNum := 0; ; recNum++ { + rec, err := r.Read() + var wantErr error + if recNum < len(tt.Errors) && tt.Errors[recNum] != nil { + wantErr = errorWithPosition(tt.Errors[recNum], recNum, positions, errPositions) + } else if recNum >= len(tt.Output) { + wantErr = io.EOF + } + if !deepEqual(err, wantErr) { + t.Fatalf("Read() error at record %d:\ngot %v (%#v)\nwant %v (%#v)", recNum, err, err, wantErr, wantErr) + } + // ErrFieldCount is explicitly non-fatal. 
+ if err != nil && !isErr(err, ErrFieldCount) { + if recNum < len(tt.Output) { + t.Fatalf("need more records; got %d want %d", recNum, len(tt.Output)) + } + break + } + if got, want := rec, tt.Output[recNum]; !deepEqual(got, want) { + t.Errorf("Read vs ReadAll mismatch;\ngot %q\nwant %q", got, want) + } + pos := positions[recNum] + if len(pos) != len(rec) { + t.Fatalf("mismatched position length at record %d", recNum) + } + for i := range rec { + line, col := r.FieldPos(i) + if got, want := [2]int{line, col}, pos[i]; got != want { + t.Errorf("position mismatch at record %d, field %d;\ngot %v\nwant %v", recNum, i, got, want) + } + } + } + }) + } +} + +// XXX: substitute for errors.Is +func isErr(err, match error) bool { + if err == match { + return true + } + if w, ok := err.(interface{ Unwrap() error }); ok { + return isErr(w.Unwrap(), match) + } + return false +} + +// XXX: substitute for reflect.DeepEqual +func deepEqual(v, x interface{}) bool { + if v == nil { + return x == nil + } + // dumb deep equal, handling only a few cases + switch v := v.(type) { + case error: + return v.Error() == x.(error).Error() + case string: + return v == x.(string) + case []string: + x := x.([]string) + if len(v) != len(x) { + return false + } + for i := range v { + if !deepEqual(v[i], x[i]) { + return false + } + } + return true + case [][]string: + x := x.([][]string) + if len(v) != len(x) { + return false + } + for i := range v { + if !deepEqual(v[i], x[i]) { + return false + } + } + return true + default: + panic("not handled " + fmt.Sprintf("%T %T", v, x)) + } +} + +// firstError returns the first non-nil error in errs, +// with the position adjusted according to the error's +// index inside positions. 
+func firstError(errs []error, positions [][][2]int, errPositions map[int][2]int) error { + for i, err := range errs { + if err != nil { + return errorWithPosition(err, i, positions, errPositions) + } + } + return nil +} + +func errorWithPosition(err error, recNum int, positions [][][2]int, errPositions map[int][2]int) error { + parseErr, ok := err.(*ParseError) + if !ok { + return err + } + if recNum >= len(positions) { + panic(fmt.Errorf("no positions found for error at record %d", recNum)) + } + errPos, ok := errPositions[recNum] + if !ok { + panic(fmt.Errorf("no error position found for error at record %d", recNum)) + } + parseErr1 := *parseErr + parseErr1.StartLine = positions[recNum][0][0] + parseErr1.Line = errPos[0] + parseErr1.Column = errPos[1] + return &parseErr1 +} + +// makePositions returns the expected field positions of all +// the fields in text, the positions of any errors, and the text with the position markers +// removed. +// +// The start of each field is marked with a § symbol; +// CSV lines are separated by ¶ symbols; +// Error positions are marked with ∑ symbols. +func makePositions(text string) ([][][2]int, map[int][2]int, string) { + buf := make([]byte, 0, len(text)) + var positions [][][2]int + errPositions := make(map[int][2]int) + line, col := 1, 1 + recNum := 0 + + for len(text) > 0 { + r, size := utf8.DecodeRuneInString(text) + switch r { + case '\n': + line++ + col = 1 + buf = append(buf, '\n') + case '§': + if len(positions) == 0 { + positions = append(positions, [][2]int{}) + } + positions[len(positions)-1] = append(positions[len(positions)-1], [2]int{line, col}) + case '¶': + positions = append(positions, [][2]int{}) + recNum++ + case '∑': + errPositions[recNum] = [2]int{line, col} + default: + buf = append(buf, text[:size]...) + col += size + } + text = text[size:] + } + return positions, errPositions, string(buf) +} + +// nTimes is an io.Reader which yields the string s n times. 
+type nTimes struct { + s string + n int + off int +} + +func (r *nTimes) Read(p []byte) (n int, err error) { + for { + if r.n <= 0 || r.s == "" { + return n, io.EOF + } + n0 := copy(p, r.s[r.off:]) + p = p[n0:] + n += n0 + r.off += n0 + if r.off == len(r.s) { + r.off = 0 + r.n-- + } + if len(p) == 0 { + return + } + } +} + +// benchmarkRead measures reading the provided CSV rows data. +// initReader, if non-nil, modifies the Reader before it's used. +func benchmarkRead(b *testing.B, initReader func(*Reader), rows string) { + b.ReportAllocs() + r := NewReader(&nTimes{s: rows, n: b.N}) + if initReader != nil { + initReader(r) + } + for { + _, err := r.Read() + if err == io.EOF { + break + } + if err != nil { + b.Fatal(err) + } + } +} + +const benchmarkCSVData = `x,y,z,w +x,y,z, +x,y,, +x,,, +,,, +"x","y","z","w" +"x","y","z","" +"x","y","","" +"x","","","" +"","","","" +` + +func BenchmarkRead(b *testing.B) { + benchmarkRead(b, nil, benchmarkCSVData) +} + +func BenchmarkReadWithFieldsPerRecord(b *testing.B) { + benchmarkRead(b, func(r *Reader) { r.FieldsPerRecord = 4 }, benchmarkCSVData) +} + +func BenchmarkReadWithoutFieldsPerRecord(b *testing.B) { + benchmarkRead(b, func(r *Reader) { r.FieldsPerRecord = -1 }, benchmarkCSVData) +} + +func BenchmarkReadLargeFields(b *testing.B) { + benchmarkRead(b, nil, strings.Repeat(`xxxxxxxxxxxxxxxx,yyyyyyyyyyyyyyyy,zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz,wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww,vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv +xxxxxxxxxxxxxxxxxxxxxxxx,yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy,zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz,wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww,vvvv 
+,,zzzz,wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww,vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv +xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx,yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy,zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz,wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww,vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv +`, 3)) +} + +func BenchmarkReadReuseRecord(b *testing.B) { + benchmarkRead(b, func(r *Reader) { r.ReuseRecord = true }, benchmarkCSVData) +} + +func BenchmarkReadReuseRecordWithFieldsPerRecord(b *testing.B) { + benchmarkRead(b, func(r *Reader) { r.ReuseRecord = true; r.FieldsPerRecord = 4 }, benchmarkCSVData) +} + +func BenchmarkReadReuseRecordWithoutFieldsPerRecord(b *testing.B) { + benchmarkRead(b, func(r *Reader) { r.ReuseRecord = true; r.FieldsPerRecord = -1 }, benchmarkCSVData) +} + +func BenchmarkReadReuseRecordLargeFields(b *testing.B) { + benchmarkRead(b, func(r *Reader) { r.ReuseRecord = true }, strings.Repeat(`xxxxxxxxxxxxxxxx,yyyyyyyyyyyyyyyy,zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz,wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww,vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv +xxxxxxxxxxxxxxxxxxxxxxxx,yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy,zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz,wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww,vvvv +,,zzzz,wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww,vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv 
+xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx,yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy,zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz,wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww,vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv +`, 3)) +} diff --git a/gnovm/stdlibs/encoding/csv/writer.gno b/gnovm/stdlibs/encoding/csv/writer.gno new file mode 100644 index 00000000000..f874765bd4e --- /dev/null +++ b/gnovm/stdlibs/encoding/csv/writer.gno @@ -0,0 +1,184 @@ +package csv + +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +import ( + "bufio" + "io" + "strings" + "unicode" + "unicode/utf8" +) + +// A Writer writes records using CSV encoding. +// +// As returned by [NewWriter], a Writer writes records terminated by a +// newline and uses ',' as the field delimiter. The exported fields can be +// changed to customize the details before +// the first call to [Writer.Write] or [Writer.WriteAll]. +// +// [Writer.Comma] is the field delimiter. +// +// If [Writer.UseCRLF] is true, +// the Writer ends each output line with \r\n instead of \n. +// +// The writes of individual records are buffered. +// After all data has been written, the client should call the +// [Writer.Flush] method to guarantee all data has been forwarded to +// the underlying [io.Writer]. Any errors that occurred should +// be checked by calling the [Writer.Error] method. +type Writer struct { + Comma rune // Field delimiter (set to ',' by NewWriter) + UseCRLF bool // True to use \r\n as the line terminator + w *bufio.Writer +} + +// NewWriter returns a new Writer that writes to w. +func NewWriter(w io.Writer) *Writer { + return &Writer{ + Comma: ',', + w: bufio.NewWriter(w), + } +} + +// Write writes a single CSV record to w along with any necessary quoting. +// A record is a slice of strings with each string being one field. 
+// Writes are buffered, so [Writer.Flush] must eventually be called to ensure +// that the record is written to the underlying [io.Writer]. +func (w *Writer) Write(record []string) error { + if !validDelim(w.Comma) { + return ErrInvalidDelim + } + + for n, field := range record { + if n > 0 { + if _, err := w.w.WriteRune(w.Comma); err != nil { + return err + } + } + + // If we don't have to have a quoted field then just + // write out the field and continue to the next field. + if !w.fieldNeedsQuotes(field) { + if _, err := w.w.WriteString(field); err != nil { + return err + } + continue + } + + if err := w.w.WriteByte('"'); err != nil { + return err + } + for len(field) > 0 { + // Search for special characters. + i := strings.IndexAny(field, "\"\r\n") + if i < 0 { + i = len(field) + } + + // Copy verbatim everything before the special character. + if _, err := w.w.WriteString(field[:i]); err != nil { + return err + } + field = field[i:] + + // Encode the special character. + if len(field) > 0 { + var err error + switch field[0] { + case '"': + _, err = w.w.WriteString(`""`) + case '\r': + if !w.UseCRLF { + err = w.w.WriteByte('\r') + } + case '\n': + if w.UseCRLF { + _, err = w.w.WriteString("\r\n") + } else { + err = w.w.WriteByte('\n') + } + } + field = field[1:] + if err != nil { + return err + } + } + } + if err := w.w.WriteByte('"'); err != nil { + return err + } + } + var err error + if w.UseCRLF { + _, err = w.w.WriteString("\r\n") + } else { + err = w.w.WriteByte('\n') + } + return err +} + +// Flush writes any buffered data to the underlying [io.Writer]. +// To check if an error occurred during Flush, call [Writer.Error]. +func (w *Writer) Flush() { + w.w.Flush() +} + +// Error reports any error that has occurred during +// a previous [Writer.Write] or [Writer.Flush]. 
+func (w *Writer) Error() error { + _, err := w.w.Write(nil) + return err +} + +// WriteAll writes multiple CSV records to w using [Writer.Write] and +// then calls [Writer.Flush], returning any error from the Flush. +func (w *Writer) WriteAll(records [][]string) error { + for _, record := range records { + err := w.Write(record) + if err != nil { + return err + } + } + return w.w.Flush() +} + +// fieldNeedsQuotes reports whether our field must be enclosed in quotes. +// Fields with a Comma, fields with a quote or newline, and +// fields which start with a space must be enclosed in quotes. +// We used to quote empty strings, but we do not anymore (as of Go 1.4). +// The two representations should be equivalent, but Postgres distinguishes +// quoted vs non-quoted empty string during database imports, and it has +// an option to force the quoted behavior for non-quoted CSV but it has +// no option to force the non-quoted behavior for quoted CSV, making +// CSV with quoted empty strings strictly less useful. +// Not quoting the empty string also makes this package match the behavior +// of Microsoft Excel and Google Drive. +// For Postgres, quote the data terminating string `\.`. +func (w *Writer) fieldNeedsQuotes(field string) bool { + if field == "" { + return false + } + + if field == `\.` { + return true + } + + if w.Comma < utf8.RuneSelf { + for i := 0; i < len(field); i++ { + c := field[i] + if c == '\n' || c == '\r' || c == '"' || c == byte(w.Comma) { + return true + } + } + } else { + if strings.ContainsRune(field, w.Comma) || strings.ContainsAny(field, "\"\r\n") { + return true + } + } + + r1, _ := utf8.DecodeRuneInString(field) + return unicode.IsSpace(r1) +} diff --git a/gnovm/stdlibs/encoding/csv/writer_test.gno b/gnovm/stdlibs/encoding/csv/writer_test.gno new file mode 100644 index 00000000000..1407f3c670a --- /dev/null +++ b/gnovm/stdlibs/encoding/csv/writer_test.gno @@ -0,0 +1,134 @@ +package csv + +// Copyright 2011 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +import ( + "bytes" + "encoding/csv" + "errors" + "strings" + "testing" +) + +func TestWriteCSV(t *testing.T) { + records := [][]string{ + {"first_name", "last_name", "username"}, + {"Rob", "Pike", "rob"}, + {"Ken", "Thompson", "ken"}, + {"Robert", "Griesemer", "gri"}, + } + + var buf bytes.Buffer + w := csv.NewWriter(&buf) + err := w.WriteAll(records) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + expected := "first_name,last_name,username\nRob,Pike,rob\nKen,Thompson,ken\nRobert,Griesemer,gri\n" + if buf.String() != expected { + t.Errorf("Unexpected output:\ngot %q\nwant %q", buf.String(), expected) + } +} + +var writeTests = []struct { + Input [][]string + Output string + Error error + UseCRLF bool + Comma rune +}{ + {Input: [][]string{{"abc"}}, Output: "abc\n"}, + {Input: [][]string{{"abc"}}, Output: "abc\r\n", UseCRLF: true}, + {Input: [][]string{{`"abc"`}}, Output: `"""abc"""` + "\n"}, + {Input: [][]string{{`a"b`}}, Output: `"a""b"` + "\n"}, + {Input: [][]string{{`"a"b"`}}, Output: `"""a""b"""` + "\n"}, + {Input: [][]string{{" abc"}}, Output: `" abc"` + "\n"}, + {Input: [][]string{{"abc,def"}}, Output: `"abc,def"` + "\n"}, + {Input: [][]string{{"abc", "def"}}, Output: "abc,def\n"}, + {Input: [][]string{{"abc"}, {"def"}}, Output: "abc\ndef\n"}, + {Input: [][]string{{"abc\ndef"}}, Output: "\"abc\ndef\"\n"}, + {Input: [][]string{{"abc\ndef"}}, Output: "\"abc\r\ndef\"\r\n", UseCRLF: true}, + {Input: [][]string{{"abc\rdef"}}, Output: "\"abcdef\"\r\n", UseCRLF: true}, + {Input: [][]string{{"abc\rdef"}}, Output: "\"abc\rdef\"\n", UseCRLF: false}, + {Input: [][]string{{""}}, Output: "\n"}, + {Input: [][]string{{"", ""}}, Output: ",\n"}, + {Input: [][]string{{"", "", ""}}, Output: ",,\n"}, + {Input: [][]string{{"", "", "a"}}, Output: ",,a\n"}, + {Input: [][]string{{"", "a", ""}}, Output: ",a,\n"}, + {Input: [][]string{{"", "a", "a"}}, Output: 
",a,a\n"}, + {Input: [][]string{{"a", "", ""}}, Output: "a,,\n"}, + {Input: [][]string{{"a", "", "a"}}, Output: "a,,a\n"}, + {Input: [][]string{{"a", "a", ""}}, Output: "a,a,\n"}, + {Input: [][]string{{"a", "a", "a"}}, Output: "a,a,a\n"}, + {Input: [][]string{{`\.`}}, Output: "\"\\.\"\n"}, + {Input: [][]string{{"x09\x41\xb4\x1c", "aktau"}}, Output: "x09\x41\xb4\x1c,aktau\n"}, + {Input: [][]string{{",x09\x41\xb4\x1c", "aktau"}}, Output: "\",x09\x41\xb4\x1c\",aktau\n"}, + {Input: [][]string{{"a", "a", ""}}, Output: "a|a|\n", Comma: '|'}, + {Input: [][]string{{",", ",", ""}}, Output: ",|,|\n", Comma: '|'}, + {Input: [][]string{{"foo"}}, Comma: '"', Error: ErrInvalidDelim}, +} + +func TestWrite(t *testing.T) { + for n, tt := range writeTests { + b := &strings.Builder{} + f := csv.NewWriter(b) + f.UseCRLF = tt.UseCRLF + if tt.Comma != 0 { + f.Comma = tt.Comma + } + err := f.WriteAll(tt.Input) + if err != tt.Error { + t.Errorf("Unexpected error:\ngot %v\nwant %v", err, tt.Error) + } + out := b.String() + if out != tt.Output { + t.Errorf("#%d: out=%q want %q", n, out, tt.Output) + } + } +} + +type errorWriter struct{} + +func (e errorWriter) Write(b []byte) (int, error) { + return 0, errors.New("Test") +} + +func TestError(t *testing.T) { + b := &bytes.Buffer{} + f := csv.NewWriter(b) + f.Write([]string{"abc"}) + f.Flush() + err := f.Error() + if err != nil { + t.Errorf("Unexpected error: %s\n", err) + } + + f = csv.NewWriter(errorWriter{}) + f.Write([]string{"abc"}) + f.Flush() + err = f.Error() + + if err == nil { + t.Error("Error should not be nil") + } +} + +var benchmarkWriteData = [][]string{ + {"abc", "def", "12356", "1234567890987654311234432141542132"}, + {"abc", "def", "12356", "1234567890987654311234432141542132"}, + {"abc", "def", "12356", "1234567890987654311234432141542132"}, +} + +func BenchmarkWrite(b *testing.B) { + for i := 0; i < b.N; i++ { + w := csv.NewWriter(&bytes.Buffer{}) + err := w.WriteAll(benchmarkWriteData) + if err != nil { + b.Fatal(err) + } 
+ w.Flush() + } +} diff --git a/gnovm/stdlibs/generated.go b/gnovm/stdlibs/generated.go index 67b492a34b2..d5ab052028f 100644 --- a/gnovm/stdlibs/generated.go +++ b/gnovm/stdlibs/generated.go @@ -988,12 +988,13 @@ var initOrder = [...]string{ "crypto/ed25519", "crypto/sha256", "encoding", + "encoding/base32", "encoding/base64", + "encoding/csv", "encoding/hex", "hash", "hash/adler32", "html", - "math/overflow", "math/rand", "path", "sort", diff --git a/gnovm/stdlibs/math/const_test.gno b/gnovm/stdlibs/math/const_test.gno index b892a12898b..fbe59d61878 100644 --- a/gnovm/stdlibs/math/const_test.gno +++ b/gnovm/stdlibs/math/const_test.gno @@ -31,19 +31,76 @@ func TestMaxUint(t *testing.T) { } func TestMaxInt(t *testing.T) { - if v := int(math.MaxInt); v+1 != math.MinInt { - t.Errorf("MaxInt should wrap around to MinInt: %d", v+1) + defer func() { + if r := recover(); r != nil { + if r != "addition overflow" { + panic(r) + } + } + }() + v := int(math.MaxInt) + if v+1 == math.MinInt { + t.Errorf("int should overflow") } - if v := int8(math.MaxInt8); v+1 != math.MinInt8 { - t.Errorf("MaxInt8 should wrap around to MinInt8: %d", v+1) + t.Errorf("expected panic did not occur") +} + +func TestMaxInt8(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if r != "addition overflow" { + panic(r) + } + } + }() + v := int8(math.MaxInt8) + if v+1 == math.MinInt8 { + t.Errorf("int8 should overflow") } - if v := int16(math.MaxInt16); v+1 != math.MinInt16 { - t.Errorf("MaxInt16 should wrap around to MinInt16: %d", v+1) + t.Errorf("expected panic did not occur") +} + +func TestMaxInt16(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if r != "addition overflow" { + panic(r) + } + } + }() + v := int16(math.MaxInt16) + if v+1 == math.MinInt16 { + t.Errorf("int16 should overflow") } - if v := int32(math.MaxInt32); v+1 != math.MinInt32 { - t.Errorf("MaxInt32 should wrap around to MinInt32: %d", v+1) + t.Errorf("expected panic did not occur") +} + +func 
TestMaxInt32(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if r != "addition overflow" { + panic(r) + } + } + }() + v := int32(math.MaxInt32) + if v+1 == math.MinInt32 { + t.Errorf("int32 should overflow") } - if v := int64(math.MaxInt64); v+1 != math.MinInt64 { - t.Errorf("MaxInt64 should wrap around to MinInt64: %d", v+1) + t.Errorf("expected panic did not occur") +} + +func TestMaxInt64(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if r != "addition overflow" { + panic(r) + } + } + }() + v := int64(math.MaxInt64) + if v+1 == math.MinInt64 { + t.Errorf("int64 should overflow") } + t.Errorf("expected panic did not occur") } diff --git a/gnovm/stdlibs/math/overflow/overflow.gno b/gnovm/stdlibs/math/overflow/overflow.gno deleted file mode 100644 index 0bc2e03a522..00000000000 --- a/gnovm/stdlibs/math/overflow/overflow.gno +++ /dev/null @@ -1,501 +0,0 @@ -// This is modified from https://github.com/JohnCGriffin/overflow (MIT). -// NOTE: there was a bug with the original Quotient* functions, and -// testing method. These have been fixed here, and tests ported to -// tests/files/maths_int*.go respectively. -// Note: moved over from p/demo/maths. - -/* -Package overflow offers overflow-checked integer arithmetic operations -for int, int32, and int64. Each of the operations returns a -result,bool combination. This was prompted by the need to know when -to flow into higher precision types from the math.big library. - -For instance, assuing a 64 bit machine: - -10 + 20 -> 30 -int(math.MaxInt64) + 1 -> -9223372036854775808 - -whereas - -overflow.Add(10,20) -> (30, true) -overflow.Add(math.MaxInt64,1) -> (0, false) - -Add, Sub, Mul, Div are for int. Add64, Add32, etc. are specifically sized. - -If anybody wishes an unsigned version, submit a pull request for code -and new tests. 
-*/ -package overflow - -import "math" - -//go:generate ./overflow_template.sh - -func _is64Bit() bool { - maxU32 := uint(math.MaxUint32) - return ((maxU32 << 1) >> 1) == maxU32 -} - -/********** PARTIAL TEST COVERAGE FROM HERE DOWN ************* - -The only way that I could see to do this is a combination of -my normal 64 bit system and a GopherJS running on Node. My -understanding is that its ints are 32 bit. - -So, FEEL FREE to carefully review the code visually. - -*************************************************************/ - -// Unspecified size, i.e. normal signed int - -// Add sums two ints, returning the result and a boolean status. -func Add(a, b int) (int, bool) { - if _is64Bit() { - r64, ok := Add64(int64(a), int64(b)) - return int(r64), ok - } - r32, ok := Add32(int32(a), int32(b)) - return int(r32), ok -} - -// Sub returns the difference of two ints and a boolean status. -func Sub(a, b int) (int, bool) { - if _is64Bit() { - r64, ok := Sub64(int64(a), int64(b)) - return int(r64), ok - } - r32, ok := Sub32(int32(a), int32(b)) - return int(r32), ok -} - -// Mul returns the product of two ints and a boolean status. 
-func Mul(a, b int) (int, bool) { - if _is64Bit() { - r64, ok := Mul64(int64(a), int64(b)) - return int(r64), ok - } - r32, ok := Mul32(int32(a), int32(b)) - return int(r32), ok -} - -// Div returns the quotient of two ints and a boolean status -func Div(a, b int) (int, bool) { - if _is64Bit() { - r64, ok := Div64(int64(a), int64(b)) - return int(r64), ok - } - r32, ok := Div32(int32(a), int32(b)) - return int(r32), ok -} - -// Quo returns the quotient, remainder and status of two ints -func Quo(a, b int) (int, int, bool) { - if _is64Bit() { - q64, r64, ok := Quo64(int64(a), int64(b)) - return int(q64), int(r64), ok - } - q32, r32, ok := Quo32(int32(a), int32(b)) - return int(q32), int(r32), ok -} - -/************* Panic versions for int ****************/ - -// Addp returns the sum of two ints, panicking on overflow -func Addp(a, b int) int { - r, ok := Add(a, b) - if !ok { - panic("addition overflow") - } - return r -} - -// Subp returns the difference of two ints, panicking on overflow. -func Subp(a, b int) int { - r, ok := Sub(a, b) - if !ok { - panic("subtraction overflow") - } - return r -} - -// Mulp returns the product of two ints, panicking on overflow. -func Mulp(a, b int) int { - r, ok := Mul(a, b) - if !ok { - panic("multiplication overflow") - } - return r -} - -// Divp returns the quotient of two ints, panicking on overflow. 
-func Divp(a, b int) int { - r, ok := Div(a, b) - if !ok { - panic("division failure") - } - return r -} - -//---------------------------------------- -// This is generated code, created by overflow_template.sh executed -// by "go generate" - -// Add8 performs + operation on two int8 operands -// returning a result and status -func Add8(a, b int8) (int8, bool) { - c := a + b - if (c > a) == (b > 0) { - return c, true - } - return c, false -} - -// Add8p is the unchecked panicking version of Add8 -func Add8p(a, b int8) int8 { - r, ok := Add8(a, b) - if !ok { - panic("addition overflow") - } - return r -} - -// Sub8 performs - operation on two int8 operands -// returning a result and status -func Sub8(a, b int8) (int8, bool) { - c := a - b - if (c < a) == (b > 0) { - return c, true - } - return c, false -} - -// Sub8p is the unchecked panicking version of Sub8 -func Sub8p(a, b int8) int8 { - r, ok := Sub8(a, b) - if !ok { - panic("subtraction overflow") - } - return r -} - -// Mul8 performs * operation on two int8 operands -// returning a result and status -func Mul8(a, b int8) (int8, bool) { - if a == 0 || b == 0 { - return 0, true - } - c := a * b - if (c < 0) == ((a < 0) != (b < 0)) { - if c/b == a { - return c, true - } - } - return c, false -} - -// Mul8p is the unchecked panicking version of Mul8 -func Mul8p(a, b int8) int8 { - r, ok := Mul8(a, b) - if !ok { - panic("multiplication overflow") - } - return r -} - -// Div8 performs / operation on two int8 operands -// returning a result and status -func Div8(a, b int8) (int8, bool) { - q, _, ok := Quo8(a, b) - return q, ok -} - -// Div8p is the unchecked panicking version of Div8 -func Div8p(a, b int8) int8 { - r, ok := Div8(a, b) - if !ok { - panic("division failure") - } - return r -} - -// Quo8 performs + operation on two int8 operands -// returning a quotient, a remainder and status -func Quo8(a, b int8) (int8, int8, bool) { - if b == 0 { - return 0, 0, false - } else if b == -1 && a == int8(math.MinInt8) { - 
return 0, 0, false - } - c := a / b - return c, a % b, true -} - -// Add16 performs + operation on two int16 operands -// returning a result and status -func Add16(a, b int16) (int16, bool) { - c := a + b - if (c > a) == (b > 0) { - return c, true - } - return c, false -} - -// Add16p is the unchecked panicking version of Add16 -func Add16p(a, b int16) int16 { - r, ok := Add16(a, b) - if !ok { - panic("addition overflow") - } - return r -} - -// Sub16 performs - operation on two int16 operands -// returning a result and status -func Sub16(a, b int16) (int16, bool) { - c := a - b - if (c < a) == (b > 0) { - return c, true - } - return c, false -} - -// Sub16p is the unchecked panicking version of Sub16 -func Sub16p(a, b int16) int16 { - r, ok := Sub16(a, b) - if !ok { - panic("subtraction overflow") - } - return r -} - -// Mul16 performs * operation on two int16 operands -// returning a result and status -func Mul16(a, b int16) (int16, bool) { - if a == 0 || b == 0 { - return 0, true - } - c := a * b - if (c < 0) == ((a < 0) != (b < 0)) { - if c/b == a { - return c, true - } - } - return c, false -} - -// Mul16p is the unchecked panicking version of Mul16 -func Mul16p(a, b int16) int16 { - r, ok := Mul16(a, b) - if !ok { - panic("multiplication overflow") - } - return r -} - -// Div16 performs / operation on two int16 operands -// returning a result and status -func Div16(a, b int16) (int16, bool) { - q, _, ok := Quo16(a, b) - return q, ok -} - -// Div16p is the unchecked panicking version of Div16 -func Div16p(a, b int16) int16 { - r, ok := Div16(a, b) - if !ok { - panic("division failure") - } - return r -} - -// Quo16 performs + operation on two int16 operands -// returning a quotient, a remainder and status -func Quo16(a, b int16) (int16, int16, bool) { - if b == 0 { - return 0, 0, false - } else if b == -1 && a == int16(math.MinInt16) { - return 0, 0, false - } - c := a / b - return c, a % b, true -} - -// Add32 performs + operation on two int32 operands -// 
returning a result and status -func Add32(a, b int32) (int32, bool) { - c := a + b - if (c > a) == (b > 0) { - return c, true - } - return c, false -} - -// Add32p is the unchecked panicking version of Add32 -func Add32p(a, b int32) int32 { - r, ok := Add32(a, b) - if !ok { - panic("addition overflow") - } - return r -} - -// Sub32 performs - operation on two int32 operands -// returning a result and status -func Sub32(a, b int32) (int32, bool) { - c := a - b - if (c < a) == (b > 0) { - return c, true - } - return c, false -} - -// Sub32p is the unchecked panicking version of Sub32 -func Sub32p(a, b int32) int32 { - r, ok := Sub32(a, b) - if !ok { - panic("subtraction overflow") - } - return r -} - -// Mul32 performs * operation on two int32 operands -// returning a result and status -func Mul32(a, b int32) (int32, bool) { - if a == 0 || b == 0 { - return 0, true - } - c := a * b - if (c < 0) == ((a < 0) != (b < 0)) { - if c/b == a { - return c, true - } - } - return c, false -} - -// Mul32p is the unchecked panicking version of Mul32 -func Mul32p(a, b int32) int32 { - r, ok := Mul32(a, b) - if !ok { - panic("multiplication overflow") - } - return r -} - -// Div32 performs / operation on two int32 operands -// returning a result and status -func Div32(a, b int32) (int32, bool) { - q, _, ok := Quo32(a, b) - return q, ok -} - -// Div32p is the unchecked panicking version of Div32 -func Div32p(a, b int32) int32 { - r, ok := Div32(a, b) - if !ok { - panic("division failure") - } - return r -} - -// Quo32 performs + operation on two int32 operands -// returning a quotient, a remainder and status -func Quo32(a, b int32) (int32, int32, bool) { - if b == 0 { - return 0, 0, false - } else if b == -1 && a == int32(math.MinInt32) { - return 0, 0, false - } - c := a / b - return c, a % b, true -} - -// Add64 performs + operation on two int64 operands -// returning a result and status -func Add64(a, b int64) (int64, bool) { - c := a + b - if (c > a) == (b > 0) { - return c, 
true - } - return c, false -} - -// Add64p is the unchecked panicking version of Add64 -func Add64p(a, b int64) int64 { - r, ok := Add64(a, b) - if !ok { - panic("addition overflow") - } - return r -} - -// Sub64 performs - operation on two int64 operands -// returning a result and status -func Sub64(a, b int64) (int64, bool) { - c := a - b - if (c < a) == (b > 0) { - return c, true - } - return c, false -} - -// Sub64p is the unchecked panicking version of Sub64 -func Sub64p(a, b int64) int64 { - r, ok := Sub64(a, b) - if !ok { - panic("subtraction overflow") - } - return r -} - -// Mul64 performs * operation on two int64 operands -// returning a result and status -func Mul64(a, b int64) (int64, bool) { - if a == 0 || b == 0 { - return 0, true - } - c := a * b - if (c < 0) == ((a < 0) != (b < 0)) { - if c/b == a { - return c, true - } - } - return c, false -} - -// Mul64p is the unchecked panicking version of Mul64 -func Mul64p(a, b int64) int64 { - r, ok := Mul64(a, b) - if !ok { - panic("multiplication overflow") - } - return r -} - -// Div64 performs / operation on two int64 operands -// returning a result and status -func Div64(a, b int64) (int64, bool) { - q, _, ok := Quo64(a, b) - return q, ok -} - -// Div64p is the unchecked panicking version of Div64 -func Div64p(a, b int64) int64 { - r, ok := Div64(a, b) - if !ok { - panic("division failure") - } - return r -} - -// Quo64 performs + operation on two int64 operands -// returning a quotient, a remainder and status -func Quo64(a, b int64) (int64, int64, bool) { - if b == 0 { - return 0, 0, false - } else if b == -1 && a == math.MinInt64 { - return 0, 0, false - } - c := a / b - return c, a % b, true -} diff --git a/gnovm/stdlibs/math/overflow/overflow_test.gno b/gnovm/stdlibs/math/overflow/overflow_test.gno deleted file mode 100644 index b7881aec480..00000000000 --- a/gnovm/stdlibs/math/overflow/overflow_test.gno +++ /dev/null @@ -1,200 +0,0 @@ -package overflow - -import ( - "math" - "testing" -) - -// 
sample all possibilities of 8 bit numbers -// by checking against 64 bit numbers - -func TestAlgorithms(t *testing.T) { - errors := 0 - - for a64 := int64(math.MinInt8); a64 <= int64(math.MaxInt8); a64++ { - for b64 := int64(math.MinInt8); b64 <= int64(math.MaxInt8) && errors < 10; b64++ { - - a8 := int8(a64) - b8 := int8(b64) - - if int64(a8) != a64 || int64(b8) != b64 { - t.Fatal("LOGIC FAILURE IN TEST") - } - - // ADDITION - { - r64 := a64 + b64 - - // now the verification - result, ok := Add8(a8, b8) - if ok && int64(result) != r64 { - t.Errorf("failed to fail on %v + %v = %v instead of %v\n", - a8, b8, result, r64) - errors++ - } - if !ok && int64(result) == r64 { - t.Fail() - errors++ - } - } - - // SUBTRACTION - { - r64 := a64 - b64 - - // now the verification - result, ok := Sub8(a8, b8) - if ok && int64(result) != r64 { - t.Errorf("failed to fail on %v - %v = %v instead of %v\n", - a8, b8, result, r64) - } - if !ok && int64(result) == r64 { - t.Fail() - errors++ - } - } - - // MULTIPLICATION - { - r64 := a64 * b64 - - // now the verification - result, ok := Mul8(a8, b8) - if ok && int64(result) != r64 { - t.Errorf("failed to fail on %v * %v = %v instead of %v\n", - a8, b8, result, r64) - errors++ - } - if !ok && int64(result) == r64 { - t.Fail() - errors++ - } - } - - // DIVISION - if b8 != 0 { - r64 := a64 / b64 - rem64 := a64 % b64 - - // now the verification - result, rem, ok := Quo8(a8, b8) - if ok && int64(result) != r64 { - t.Errorf("failed to fail on %v / %v = %v instead of %v\n", - a8, b8, result, r64) - errors++ - } - if ok && int64(rem) != rem64 { - t.Errorf("failed to fail on %v %% %v = %v instead of %v\n", - a8, b8, rem, rem64) - errors++ - } - } - } - } -} - -func TestQuotient(t *testing.T) { - q, r, ok := Quo(100, 3) - if r != 1 || q != 33 || !ok { - t.Errorf("expected 100/3 => 33, r=1") - } - if _, _, ok = Quo(1, 0); ok { - t.Error("unexpected lack of failure") - } -} - -func TestLong(t *testing.T) { - if testing.Short() { - t.Skip() - } - - 
ctr := int64(0) - - for a64 := int64(math.MinInt16); a64 <= int64(math.MaxInt16); a64++ { - for b64 := int64(math.MinInt16); b64 <= int64(math.MaxInt16); b64++ { - a16 := int16(a64) - b16 := int16(b64) - if int64(a16) != a64 || int64(b16) != b64 { - panic("LOGIC FAILURE IN TEST") - } - ctr++ - - // ADDITION - { - r64 := a64 + b64 - - // now the verification - result, ok := Add16(a16, b16) - if int64(math.MinInt16) <= r64 && r64 <= int64(math.MaxInt16) { - if !ok || int64(result) != r64 { - println("add", a16, b16, result, r64) - panic("incorrect result for non-overflow") - } - } else { - if ok { - println("add", a16, b16, result, r64) - panic("incorrect ok result") - } - } - } - - // SUBTRACTION - { - r64 := a64 - b64 - - // now the verification - result, ok := Sub16(a16, b16) - if int64(math.MinInt16) <= r64 && r64 <= int64(math.MaxInt16) { - if !ok || int64(result) != r64 { - println("sub", a16, b16, result, r64) - panic("incorrect result for non-overflow") - } - } else { - if ok { - println("sub", a16, b16, result, r64) - panic("incorrect ok result") - } - } - } - - // MULTIPLICATION - { - r64 := a64 * b64 - - // now the verification - result, ok := Mul16(a16, b16) - if int64(math.MinInt16) <= r64 && r64 <= int64(math.MaxInt16) { - if !ok || int64(result) != r64 { - println("mul", a16, b16, result, r64) - panic("incorrect result for non-overflow") - } - } else { - if ok { - println("mul", a16, b16, result, r64) - panic("incorrect ok result") - } - } - } - - // DIVISION - if b16 != 0 { - r64 := a64 / b64 - - // now the verification - result, _, ok := Quo16(a16, b16) - if int64(math.MinInt16) <= r64 && r64 <= int64(math.MaxInt16) { - if !ok || int64(result) != r64 { - println("quo", a16, b16, result, r64) - panic("incorrect result for non-overflow") - } - } else { - if ok { - println("quo", a16, b16, result, r64) - panic("incorrect ok result") - } - } - } - } - } - println("done", ctr) -} diff --git a/gnovm/stdlibs/std/coins.gno b/gnovm/stdlibs/std/coins.gno index 
47e88e238d2..679674e443e 100644 --- a/gnovm/stdlibs/std/coins.gno +++ b/gnovm/stdlibs/std/coins.gno @@ -1,9 +1,6 @@ package std -import ( - "math/overflow" - "strconv" -) +import "strconv" // NOTE: this is selectively copied over from tm2/pkgs/std/coin.go @@ -56,13 +53,7 @@ func (c Coin) IsEqual(other Coin) bool { // An invalid result panics. func (c Coin) Add(other Coin) Coin { mustMatchDenominations(c.Denom, other.Denom) - - sum, ok := overflow.Add64(c.Amount, other.Amount) - if !ok { - panic("coin add overflow/underflow: " + strconv.Itoa(int(c.Amount)) + " +/- " + strconv.Itoa(int(other.Amount))) - } - - c.Amount = sum + c.Amount += other.Amount return c } @@ -72,13 +63,7 @@ func (c Coin) Add(other Coin) Coin { // An invalid result panics. func (c Coin) Sub(other Coin) Coin { mustMatchDenominations(c.Denom, other.Denom) - - dff, ok := overflow.Sub64(c.Amount, other.Amount) - if !ok { - panic("coin sub overflow/underflow: " + strconv.Itoa(int(c.Amount)) + " +/- " + strconv.Itoa(int(other.Amount))) - } - c.Amount = dff - + c.Amount -= other.Amount return c } @@ -113,10 +98,7 @@ func NewCoins(coins ...Coin) Coins { for _, coin := range coins { if currentAmount, exists := coinMap[coin.Denom]; exists { - var ok bool - if coinMap[coin.Denom], ok = overflow.Add64(currentAmount, coin.Amount); !ok { - panic("coin sub overflow/underflow: " + strconv.Itoa(int(currentAmount)) + " +/- " + strconv.Itoa(int(coin.Amount))) - } + coinMap[coin.Denom] = currentAmount + coin.Amount } else { coinMap[coin.Denom] = coin.Amount } diff --git a/gnovm/tests/files/overflow0.gno b/gnovm/tests/files/overflow0.gno new file mode 100644 index 00000000000..1313f064322 --- /dev/null +++ b/gnovm/tests/files/overflow0.gno @@ -0,0 +1,10 @@ +package main + +func main() { + var a, b, c int8 = -1<<7, -1, 0 + c = a / b // overflow: -128 instead of 128 + println(c) +} + +// Error: +// division by zero or overflow diff --git a/gnovm/tests/files/overflow1.gno b/gnovm/tests/files/overflow1.gno new file mode 
100644 index 00000000000..a416e9a3498 --- /dev/null +++ b/gnovm/tests/files/overflow1.gno @@ -0,0 +1,10 @@ +package main + +func main() { + var a, b, c int16 = -1<<15, -1, 0 + c = a / b // overflow: -32768 instead of 32768 + println(c) +} + +// Error: +// division by zero or overflow diff --git a/gnovm/tests/files/overflow2.gno b/gnovm/tests/files/overflow2.gno new file mode 100644 index 00000000000..353729bcdf2 --- /dev/null +++ b/gnovm/tests/files/overflow2.gno @@ -0,0 +1,10 @@ +package main + +func main() { + var a, b, c int32 = -1<<31, -1, 0 + c = a / b // overflow: -2147483648 instead of 2147483648 + println(c) +} + +// Error: +// division by zero or overflow diff --git a/gnovm/tests/files/overflow3.gno b/gnovm/tests/files/overflow3.gno new file mode 100644 index 00000000000..a09c59dfb03 --- /dev/null +++ b/gnovm/tests/files/overflow3.gno @@ -0,0 +1,10 @@ +package main + +func main() { + var a, b, c int64 = -1<<63, -1, 0 + c = a / b // overflow: -9223372036854775808 instead of 9223372036854775808 + println(c) +} + +// Error: +// division by zero or overflow diff --git a/gnovm/tests/files/overflow4.gno b/gnovm/tests/files/overflow4.gno new file mode 100644 index 00000000000..26b05567b07 --- /dev/null +++ b/gnovm/tests/files/overflow4.gno @@ -0,0 +1,10 @@ +package main + +func main() { + var a, b, c int = -1<<63, -1, 0 + c = a / b // overflow: -9223372036854775808 instead of 9223372036854775808 + println(c) +} + +// Error: +// division by zero or overflow diff --git a/gnovm/tests/files/overflow5.gno b/gnovm/tests/files/overflow5.gno new file mode 100644 index 00000000000..ef7f976eb24 --- /dev/null +++ b/gnovm/tests/files/overflow5.gno @@ -0,0 +1,10 @@ +package main + +func main() { + var a, b, c int = -5, 7, 0 + c = a % b // 0 quotient triggers a false negative in gnolang/overflow + println(c) +} + +// Output: +// -5 diff --git a/gnovm/tests/files/recover14.gno b/gnovm/tests/files/recover14.gno index 30a34ab291a..3c96404fcbe 100644 --- 
a/gnovm/tests/files/recover14.gno +++ b/gnovm/tests/files/recover14.gno @@ -12,4 +12,4 @@ func main() { } // Output: -// recover: division by zero +// recover: division by zero or overflow diff --git a/go.mod b/go.mod index bf7c2df94b6..cd038e2ae65 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/davecgh/go-spew v1.1.1 github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 github.com/fortytw2/leaktest v1.3.0 + github.com/google/go-cmp v0.6.0 github.com/google/gofuzz v1.2.0 github.com/gorilla/websocket v1.5.3 github.com/libp2p/go-buffer-pool v0.1.0 @@ -23,9 +24,11 @@ require ( github.com/rogpeppe/go-internal v1.12.0 github.com/rs/cors v1.11.1 github.com/rs/xid v1.6.0 + github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b github.com/stretchr/testify v1.9.0 github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 github.com/yuin/goldmark v1.7.2 + github.com/yuin/goldmark-highlighting/v2 v2.0.0-20230729083705-37449abec8cc go.etcd.io/bbolt v1.3.11 go.opentelemetry.io/otel v1.29.0 go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.29.0 diff --git a/go.sum b/go.sum index 917270fd5a6..9c4d20dbad6 100644 --- a/go.sum +++ b/go.sum @@ -3,8 +3,10 @@ dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/aead/siphash v1.0.1/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII= github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE= github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= +github.com/alecthomas/chroma/v2 v2.2.0/go.mod h1:vf4zrexSH54oEjJ7EdB65tGNHmH3pGZmVkgTP5RHvAs= github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46aU4V9E= github.com/alecthomas/chroma/v2 v2.14.0/go.mod h1:QolEbTfmUHIMVpBqxeDnNBj2uoeI4EbYP4i6n68SG4I= +github.com/alecthomas/repr v0.0.0-20220113201626-b1b626ac65ae/go.mod h1:2kn6fqh/zIyPLmm3ugklbEi5hg5wS435eygvNfaDQL8= github.com/alecthomas/repr v0.4.0 
h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ= @@ -51,6 +53,8 @@ github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1/go.mod h1:hyedUtir6IdtD/7lIxGeC github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 h1:rpfIENRNNilwHwZeG5+P150SMrnNEcHYvcCuK6dPZSg= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= github.com/decred/dcrd/lru v1.0.0/go.mod h1:mxKOwFd7lFjN2GZYsiz/ecgqR6kkYAl+0pz0tEMk218= +github.com/dlclark/regexp2 v1.4.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc= +github.com/dlclark/regexp2 v1.7.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= @@ -131,6 +135,8 @@ github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b h1:oV47z+jotrLVvhiLRNzACVe7/qZ8DcRlMlDucR/FARo= +github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b/go.mod h1:JprPCeMgYyLKJoAy9nxpVScm7NwFSwpibdrUKm4kcw0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -142,8 +148,11 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT 
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 h1:epCh84lMvA70Z7CTTCmYQn2CKbY8j86K7/FAIr141uY= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= +github.com/yuin/goldmark v1.4.15/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.7.2 h1:NjGd7lO7zrUn/A7eKwn5PEOt4ONYGqpxSEeZuduvgxc= github.com/yuin/goldmark v1.7.2/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= +github.com/yuin/goldmark-highlighting/v2 v2.0.0-20230729083705-37449abec8cc h1:+IAOyRda+RLrxa1WC7umKOZRsGq4QrFFMYApOeHzQwQ= +github.com/yuin/goldmark-highlighting/v2 v2.0.0-20230729083705-37449abec8cc/go.mod h1:ovIvrum6DQJA4QsJSovrkC4saKHQVs7TvcaeO8AIl5I= github.com/zondax/hid v0.9.2 h1:WCJFnEDMiqGF64nlZz28E9qLVZ0KSJ7xpc5DLEyma2U= github.com/zondax/hid v0.9.2/go.mod h1:l5wttcP0jwtdLjqjMMWFVEE7d1zO0jvSPA9OPZxWpEM= github.com/zondax/ledger-go v0.14.3 h1:wEpJt2CEcBJ428md/5MgSLsXLBos98sBOyxNmCjfUCw= diff --git a/misc/autocounterd/go.mod b/misc/autocounterd/go.mod index 30a6f23b458..82e8b0081ce 100644 --- a/misc/autocounterd/go.mod +++ b/misc/autocounterd/go.mod @@ -26,6 +26,7 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rs/xid v1.6.0 // indirect + github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b // indirect github.com/stretchr/testify v1.9.0 // indirect github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 // indirect github.com/zondax/hid v0.9.2 // indirect @@ -43,6 +44,7 @@ require ( golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 // indirect golang.org/x/mod v0.20.0 // indirect golang.org/x/net v0.28.0 // indirect + golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.24.0 // indirect golang.org/x/term v0.23.0 // indirect golang.org/x/text v0.17.0 // indirect diff --git 
a/misc/autocounterd/go.sum b/misc/autocounterd/go.sum index 28959bf214e..bd88dd5d08c 100644 --- a/misc/autocounterd/go.sum +++ b/misc/autocounterd/go.sum @@ -120,6 +120,8 @@ github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b h1:oV47z+jotrLVvhiLRNzACVe7/qZ8DcRlMlDucR/FARo= +github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b/go.mod h1:JprPCeMgYyLKJoAy9nxpVScm7NwFSwpibdrUKm4kcw0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -155,10 +157,6 @@ go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeX go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= -go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= -go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= -go.uber.org/zap/exp v0.3.0 h1:6JYzdifzYkGmTdRR59oYH+Ng7k49H9qVpWwNSsGJj3U= -go.uber.org/zap/exp v0.3.0/go.mod h1:5I384qq7XGxYyByIhHm6jg5CHkGY0nsTfbDLgDDlgJQ= golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -176,6 
+174,8 @@ golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc/go.mod h1:/O7V0waA8r7cgGh81R golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/misc/docs-linter/jsx.go b/misc/docs-linter/jsx.go index d0307680a0c..eb041a78386 100644 --- a/misc/docs-linter/jsx.go +++ b/misc/docs-linter/jsx.go @@ -50,7 +50,7 @@ func lintJSX(filepathToJSX map[string][]string) (string, error) { found = true } - output.WriteString(fmt.Sprintf(">>> %s (found in file: %s)\n", tag, filePath)) + fmt.Fprintf(&output, ">>> %s (found in file: %s)\n", tag, filePath) } } diff --git a/misc/docs-linter/links.go b/misc/docs-linter/links.go index 744917d8dfb..e34d35d9f58 100644 --- a/misc/docs-linter/links.go +++ b/misc/docs-linter/links.go @@ -80,7 +80,7 @@ func lintLocalLinks(filepathToLinks map[string][]string, docsPath string) (strin found = true } - output.WriteString(fmt.Sprintf(">>> %s (found in file: %s)\n", link, filePath)) + fmt.Fprintf(&output, ">>> %s (found in file: %s)\n", link, filePath) } } } diff --git a/misc/docs-linter/main.go b/misc/docs-linter/main.go index 97d80316108..5d7cdf37982 100644 --- a/misc/docs-linter/main.go +++ b/misc/docs-linter/main.go @@ -61,8 +61,8 @@ func execLint(cfg *cfg, ctx context.Context) (string, error) { } // Main buffer to write to the end user after linting - var output 
bytes.Buffer - output.WriteString(fmt.Sprintf("Linting %s...\n", absPath)) + var output bytes.Buffer + fmt.Fprintf(&output, "Linting %s...\n", absPath) // Find docs files to lint mdFiles, err := findFilePaths(cfg.docsPath) diff --git a/misc/docs-linter/urls.go b/misc/docs-linter/urls.go index 093e624d81e..098d0a05524 100644 --- a/misc/docs-linter/urls.go +++ b/misc/docs-linter/urls.go @@ -66,7 +66,7 @@ func lintURLs(filepathToURLs map[string][]string, ctx context.Context) (string, found = true } - output.WriteString(fmt.Sprintf(">>> %s (found in file: %s)\n", url, filePath)) + fmt.Fprintf(&output, ">>> %s (found in file: %s)\n", url, filePath) lock.Unlock() } diff --git a/misc/genstd/util.go b/misc/genstd/util.go index 025fe4b673e..13e90836f36 100644 --- a/misc/genstd/util.go +++ b/misc/genstd/util.go @@ -70,7 +70,8 @@ func findDirs() (gitRoot string, relPath string, err error) { } p := wd for { - if s, e := os.Stat(filepath.Join(p, ".git")); e == nil && s.IsDir() { + // .git is normally a directory, or a file in case of a git worktree. 
+ if _, e := os.Stat(filepath.Join(p, ".git")); e == nil { // make relPath relative to the git root rp := strings.TrimPrefix(wd, p+string(filepath.Separator)) // normalize separator to / diff --git a/misc/loop/go.mod b/misc/loop/go.mod index af7783e57bb..fc2c5daac59 100644 --- a/misc/loop/go.mod +++ b/misc/loop/go.mod @@ -52,6 +52,7 @@ require ( github.com/prometheus/procfs v0.11.1 // indirect github.com/rs/cors v1.11.1 // indirect github.com/rs/xid v1.6.0 // indirect + github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b // indirect github.com/stretchr/testify v1.9.0 // indirect github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 // indirect go.etcd.io/bbolt v1.3.11 // indirect @@ -68,6 +69,7 @@ require ( golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 // indirect golang.org/x/mod v0.20.0 // indirect golang.org/x/net v0.28.0 // indirect + golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.24.0 // indirect golang.org/x/term v0.23.0 // indirect golang.org/x/text v0.17.0 // indirect diff --git a/misc/loop/go.sum b/misc/loop/go.sum index 0d235f2cfb1..1ed786fb82d 100644 --- a/misc/loop/go.sum +++ b/misc/loop/go.sum @@ -162,6 +162,8 @@ github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b h1:oV47z+jotrLVvhiLRNzACVe7/qZ8DcRlMlDucR/FARo= +github.com/sig-0/insertion-queue v0.0.0-20241004125609-6b3ca841346b/go.mod h1:JprPCeMgYyLKJoAy9nxpVScm7NwFSwpibdrUKm4kcw0= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 
diff --git a/tm2/pkg/bft/blockchain/pool.go b/tm2/pkg/bft/blockchain/pool.go index b610a0c0e7a..5fdd23af910 100644 --- a/tm2/pkg/bft/blockchain/pool.go +++ b/tm2/pkg/bft/blockchain/pool.go @@ -12,7 +12,7 @@ import ( "github.com/gnolang/gno/tm2/pkg/bft/types" "github.com/gnolang/gno/tm2/pkg/flow" "github.com/gnolang/gno/tm2/pkg/log" - "github.com/gnolang/gno/tm2/pkg/p2p" + p2pTypes "github.com/gnolang/gno/tm2/pkg/p2p/types" "github.com/gnolang/gno/tm2/pkg/service" ) @@ -69,7 +69,7 @@ type BlockPool struct { requesters map[int64]*bpRequester height int64 // the lowest key in requesters. // peers - peers map[p2p.ID]*bpPeer + peers map[p2pTypes.ID]*bpPeer maxPeerHeight int64 // the biggest reported height // atomic @@ -83,7 +83,7 @@ type BlockPool struct { // requests and errors will be sent to requestsCh and errorsCh accordingly. func NewBlockPool(start int64, requestsCh chan<- BlockRequest, errorsCh chan<- peerError) *BlockPool { bp := &BlockPool{ - peers: make(map[p2p.ID]*bpPeer), + peers: make(map[p2pTypes.ID]*bpPeer), requesters: make(map[int64]*bpRequester), height: start, @@ -226,13 +226,13 @@ func (pool *BlockPool) PopRequest() { // RedoRequest invalidates the block at pool.height, // Remove the peer and redo request from others. // Returns the ID of the removed peer. -func (pool *BlockPool) RedoRequest(height int64) p2p.ID { +func (pool *BlockPool) RedoRequest(height int64) p2pTypes.ID { pool.mtx.Lock() defer pool.mtx.Unlock() request := pool.requesters[height] peerID := request.getPeerID() - if peerID != p2p.ID("") { + if peerID != p2pTypes.ID("") { // RemovePeer will redo all requesters associated with this peer. pool.removePeer(peerID) } @@ -241,7 +241,7 @@ func (pool *BlockPool) RedoRequest(height int64) p2p.ID { // AddBlock validates that the block comes from the peer it was expected from and calls the requester to store it. // TODO: ensure that blocks come in order for each peer. 
-func (pool *BlockPool) AddBlock(peerID p2p.ID, block *types.Block, blockSize int) { +func (pool *BlockPool) AddBlock(peerID p2pTypes.ID, block *types.Block, blockSize int) { pool.mtx.Lock() defer pool.mtx.Unlock() @@ -278,7 +278,7 @@ func (pool *BlockPool) MaxPeerHeight() int64 { } // SetPeerHeight sets the peer's alleged blockchain height. -func (pool *BlockPool) SetPeerHeight(peerID p2p.ID, height int64) { +func (pool *BlockPool) SetPeerHeight(peerID p2pTypes.ID, height int64) { pool.mtx.Lock() defer pool.mtx.Unlock() @@ -298,14 +298,14 @@ func (pool *BlockPool) SetPeerHeight(peerID p2p.ID, height int64) { // RemovePeer removes the peer with peerID from the pool. If there's no peer // with peerID, function is a no-op. -func (pool *BlockPool) RemovePeer(peerID p2p.ID) { +func (pool *BlockPool) RemovePeer(peerID p2pTypes.ID) { pool.mtx.Lock() defer pool.mtx.Unlock() pool.removePeer(peerID) } -func (pool *BlockPool) removePeer(peerID p2p.ID) { +func (pool *BlockPool) removePeer(peerID p2pTypes.ID) { for _, requester := range pool.requesters { if requester.getPeerID() == peerID { requester.redo(peerID) @@ -386,14 +386,14 @@ func (pool *BlockPool) requestersLen() int64 { return int64(len(pool.requesters)) } -func (pool *BlockPool) sendRequest(height int64, peerID p2p.ID) { +func (pool *BlockPool) sendRequest(height int64, peerID p2pTypes.ID) { if !pool.IsRunning() { return } pool.requestsCh <- BlockRequest{height, peerID} } -func (pool *BlockPool) sendError(err error, peerID p2p.ID) { +func (pool *BlockPool) sendError(err error, peerID p2pTypes.ID) { if !pool.IsRunning() { return } @@ -424,7 +424,7 @@ func (pool *BlockPool) debug() string { type bpPeer struct { pool *BlockPool - id p2p.ID + id p2pTypes.ID recvMonitor *flow.Monitor height int64 @@ -435,7 +435,7 @@ type bpPeer struct { logger *slog.Logger } -func newBPPeer(pool *BlockPool, peerID p2p.ID, height int64) *bpPeer { +func newBPPeer(pool *BlockPool, peerID p2pTypes.ID, height int64) *bpPeer { peer := 
&bpPeer{ pool: pool, id: peerID, @@ -499,10 +499,10 @@ type bpRequester struct { pool *BlockPool height int64 gotBlockCh chan struct{} - redoCh chan p2p.ID // redo may send multitime, add peerId to identify repeat + redoCh chan p2pTypes.ID // redo may send multitime, add peerId to identify repeat mtx sync.Mutex - peerID p2p.ID + peerID p2pTypes.ID block *types.Block } @@ -511,7 +511,7 @@ func newBPRequester(pool *BlockPool, height int64) *bpRequester { pool: pool, height: height, gotBlockCh: make(chan struct{}, 1), - redoCh: make(chan p2p.ID, 1), + redoCh: make(chan p2pTypes.ID, 1), peerID: "", block: nil, @@ -526,7 +526,7 @@ func (bpr *bpRequester) OnStart() error { } // Returns true if the peer matches and block doesn't already exist. -func (bpr *bpRequester) setBlock(block *types.Block, peerID p2p.ID) bool { +func (bpr *bpRequester) setBlock(block *types.Block, peerID p2pTypes.ID) bool { bpr.mtx.Lock() if bpr.block != nil || bpr.peerID != peerID { bpr.mtx.Unlock() @@ -548,7 +548,7 @@ func (bpr *bpRequester) getBlock() *types.Block { return bpr.block } -func (bpr *bpRequester) getPeerID() p2p.ID { +func (bpr *bpRequester) getPeerID() p2pTypes.ID { bpr.mtx.Lock() defer bpr.mtx.Unlock() return bpr.peerID @@ -570,7 +570,7 @@ func (bpr *bpRequester) reset() { // Tells bpRequester to pick another peer and try again. // NOTE: Nonblocking, and does nothing if another redo // was already requested. 
-func (bpr *bpRequester) redo(peerID p2p.ID) { +func (bpr *bpRequester) redo(peerID p2pTypes.ID) { select { case bpr.redoCh <- peerID: default: @@ -631,5 +631,5 @@ OUTER_LOOP: // delivering the block type BlockRequest struct { Height int64 - PeerID p2p.ID + PeerID p2pTypes.ID } diff --git a/tm2/pkg/bft/blockchain/pool_test.go b/tm2/pkg/bft/blockchain/pool_test.go index a4d5636d5e3..ee58d672e75 100644 --- a/tm2/pkg/bft/blockchain/pool_test.go +++ b/tm2/pkg/bft/blockchain/pool_test.go @@ -5,12 +5,12 @@ import ( "testing" "time" + p2pTypes "github.com/gnolang/gno/tm2/pkg/p2p/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/gnolang/gno/tm2/pkg/bft/types" "github.com/gnolang/gno/tm2/pkg/log" - "github.com/gnolang/gno/tm2/pkg/p2p" "github.com/gnolang/gno/tm2/pkg/random" ) @@ -19,7 +19,7 @@ func init() { } type testPeer struct { - id p2p.ID + id p2pTypes.ID height int64 inputChan chan inputData // make sure each peer's data is sequential } @@ -47,7 +47,7 @@ func (p testPeer) simulateInput(input inputData) { // input.t.Logf("Added block from peer %v (height: %v)", input.request.PeerID, input.request.Height) } -type testPeers map[p2p.ID]testPeer +type testPeers map[p2pTypes.ID]testPeer func (ps testPeers) start() { for _, v := range ps { @@ -64,7 +64,7 @@ func (ps testPeers) stop() { func makePeers(numPeers int, minHeight, maxHeight int64) testPeers { peers := make(testPeers, numPeers) for i := 0; i < numPeers; i++ { - peerID := p2p.ID(random.RandStr(12)) + peerID := p2pTypes.ID(random.RandStr(12)) height := minHeight + random.RandInt63n(maxHeight-minHeight) peers[peerID] = testPeer{peerID, height, make(chan inputData, 10)} } @@ -172,7 +172,7 @@ func TestBlockPoolTimeout(t *testing.T) { // Pull from channels counter := 0 - timedOut := map[p2p.ID]struct{}{} + timedOut := map[p2pTypes.ID]struct{}{} for { select { case err := <-errorsCh: @@ -195,7 +195,7 @@ func TestBlockPoolRemovePeer(t *testing.T) { peers := make(testPeers, 10) 
for i := 0; i < 10; i++ { - peerID := p2p.ID(fmt.Sprintf("%d", i+1)) + peerID := p2pTypes.ID(fmt.Sprintf("%d", i+1)) height := int64(i + 1) peers[peerID] = testPeer{peerID, height, make(chan inputData)} } @@ -215,10 +215,10 @@ func TestBlockPoolRemovePeer(t *testing.T) { assert.EqualValues(t, 10, pool.MaxPeerHeight()) // remove not-existing peer - assert.NotPanics(t, func() { pool.RemovePeer(p2p.ID("Superman")) }) + assert.NotPanics(t, func() { pool.RemovePeer(p2pTypes.ID("Superman")) }) // remove peer with biggest height - pool.RemovePeer(p2p.ID("10")) + pool.RemovePeer(p2pTypes.ID("10")) assert.EqualValues(t, 9, pool.MaxPeerHeight()) // remove all peers diff --git a/tm2/pkg/bft/blockchain/reactor.go b/tm2/pkg/bft/blockchain/reactor.go index 09e1225b717..417e96ad383 100644 --- a/tm2/pkg/bft/blockchain/reactor.go +++ b/tm2/pkg/bft/blockchain/reactor.go @@ -12,6 +12,7 @@ import ( "github.com/gnolang/gno/tm2/pkg/bft/store" "github.com/gnolang/gno/tm2/pkg/bft/types" "github.com/gnolang/gno/tm2/pkg/p2p" + p2pTypes "github.com/gnolang/gno/tm2/pkg/p2p/types" ) const ( @@ -37,15 +38,14 @@ const ( bcBlockResponseMessageFieldKeySize ) -type consensusReactor interface { - // for when we switch from blockchain reactor and fast sync to - // the consensus machine - SwitchToConsensus(sm.State, int) -} +// SwitchToConsensusFn is a callback method that is meant to +// stop the syncing process as soon as the latest known height is reached, +// and start the consensus process for the validator node +type SwitchToConsensusFn func(sm.State, int) type peerError struct { err error - peerID p2p.ID + peerID p2pTypes.ID } func (e peerError) Error() string { @@ -66,11 +66,17 @@ type BlockchainReactor struct { requestsCh <-chan BlockRequest errorsCh <-chan peerError + + switchToConsensusFn SwitchToConsensusFn } // NewBlockchainReactor returns new reactor instance. 
-func NewBlockchainReactor(state sm.State, blockExec *sm.BlockExecutor, store *store.BlockStore, +func NewBlockchainReactor( + state sm.State, + blockExec *sm.BlockExecutor, + store *store.BlockStore, fastSync bool, + switchToConsensusFn SwitchToConsensusFn, ) *BlockchainReactor { if state.LastBlockHeight != store.Height() { panic(fmt.Sprintf("state (%v) and store (%v) height mismatch", state.LastBlockHeight, @@ -89,13 +95,14 @@ func NewBlockchainReactor(state sm.State, blockExec *sm.BlockExecutor, store *st ) bcR := &BlockchainReactor{ - initialState: state, - blockExec: blockExec, - store: store, - pool: pool, - fastSync: fastSync, - requestsCh: requestsCh, - errorsCh: errorsCh, + initialState: state, + blockExec: blockExec, + store: store, + pool: pool, + fastSync: fastSync, + requestsCh: requestsCh, + errorsCh: errorsCh, + switchToConsensusFn: switchToConsensusFn, } bcR.BaseReactor = *p2p.NewBaseReactor("BlockchainReactor", bcR) return bcR @@ -138,7 +145,7 @@ func (bcR *BlockchainReactor) GetChannels() []*p2p.ChannelDescriptor { } // AddPeer implements Reactor by sending our state to peer. -func (bcR *BlockchainReactor) AddPeer(peer p2p.Peer) { +func (bcR *BlockchainReactor) AddPeer(peer p2p.PeerConn) { msgBytes := amino.MustMarshalAny(&bcStatusResponseMessage{bcR.store.Height()}) peer.Send(BlockchainChannel, msgBytes) // it's OK if send fails. will try later in poolRoutine @@ -148,7 +155,7 @@ func (bcR *BlockchainReactor) AddPeer(peer p2p.Peer) { } // RemovePeer implements Reactor by removing peer from the pool. -func (bcR *BlockchainReactor) RemovePeer(peer p2p.Peer, reason interface{}) { +func (bcR *BlockchainReactor) RemovePeer(peer p2p.PeerConn, reason interface{}) { bcR.pool.RemovePeer(peer.ID()) } @@ -157,7 +164,7 @@ func (bcR *BlockchainReactor) RemovePeer(peer p2p.Peer, reason interface{}) { // According to the Tendermint spec, if all nodes are honest, // no node should be requesting for a block that's non-existent. 
func (bcR *BlockchainReactor) respondToPeer(msg *bcBlockRequestMessage, - src p2p.Peer, + src p2p.PeerConn, ) (queued bool) { block := bcR.store.LoadBlock(msg.Height) if block != nil { @@ -172,7 +179,7 @@ func (bcR *BlockchainReactor) respondToPeer(msg *bcBlockRequestMessage, } // Receive implements Reactor by handling 4 types of messages (look below). -func (bcR *BlockchainReactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { +func (bcR *BlockchainReactor) Receive(chID byte, src p2p.PeerConn, msgBytes []byte) { msg, err := decodeMsg(msgBytes) if err != nil { bcR.Logger.Error("Error decoding message", "src", src, "chId", chID, "msg", msg, "err", err, "bytes", msgBytes) @@ -257,16 +264,13 @@ FOR_LOOP: select { case <-switchToConsensusTicker.C: height, numPending, lenRequesters := bcR.pool.GetStatus() - outbound, inbound, _ := bcR.Switch.NumPeers() - bcR.Logger.Debug("Consensus ticker", "numPending", numPending, "total", lenRequesters, - "outbound", outbound, "inbound", inbound) + + bcR.Logger.Debug("Consensus ticker", "numPending", numPending, "total", lenRequesters) if bcR.pool.IsCaughtUp() { bcR.Logger.Info("Time to switch to consensus reactor!", "height", height) bcR.pool.Stop() - conR, ok := bcR.Switch.Reactor("CONSENSUS").(consensusReactor) - if ok { - conR.SwitchToConsensus(state, blocksSynced) - } + + if bcR.switchToConsensusFn != nil { bcR.switchToConsensusFn(state, blocksSynced) } // else { // should only happen during testing // } diff --git a/tm2/pkg/bft/blockchain/reactor_test.go b/tm2/pkg/bft/blockchain/reactor_test.go index a40dbc6376b..1bc2df59055 100644 --- a/tm2/pkg/bft/blockchain/reactor_test.go +++ b/tm2/pkg/bft/blockchain/reactor_test.go @@ -1,14 +1,13 @@ package blockchain import ( + "context" "log/slog" "os" "sort" "testing" "time" - "github.com/stretchr/testify/assert" - abci "github.com/gnolang/gno/tm2/pkg/bft/abci/types" "github.com/gnolang/gno/tm2/pkg/bft/appconn" cfg "github.com/gnolang/gno/tm2/pkg/bft/config" @@ -20,9 +19,12 @@ import ( tmtime
"github.com/gnolang/gno/tm2/pkg/bft/types/time" "github.com/gnolang/gno/tm2/pkg/db/memdb" "github.com/gnolang/gno/tm2/pkg/errors" + p2pTesting "github.com/gnolang/gno/tm2/pkg/internal/p2p" "github.com/gnolang/gno/tm2/pkg/log" "github.com/gnolang/gno/tm2/pkg/p2p" + p2pTypes "github.com/gnolang/gno/tm2/pkg/p2p/types" "github.com/gnolang/gno/tm2/pkg/testutils" + "github.com/stretchr/testify/assert" ) var config *cfg.Config @@ -110,7 +112,7 @@ func newBlockchainReactor(logger *slog.Logger, genDoc *types.GenesisDoc, privVal blockStore.SaveBlock(thisBlock, thisParts, lastCommit) } - bcReactor := NewBlockchainReactor(state.Copy(), blockExec, blockStore, fastSync) + bcReactor := NewBlockchainReactor(state.Copy(), blockExec, blockStore, fastSync, nil) bcReactor.SetLogger(logger.With("module", "blockchain")) return BlockchainReactorPair{bcReactor, proxyApp} @@ -125,15 +127,35 @@ func TestNoBlockResponse(t *testing.T) { maxBlockHeight := int64(65) - reactorPairs := make([]BlockchainReactorPair, 2) + var ( + reactorPairs = make([]BlockchainReactorPair, 2) + options = make(map[int][]p2p.SwitchOption) + ) - reactorPairs[0] = newBlockchainReactor(log.NewTestingLogger(t), genDoc, privVals, maxBlockHeight) - reactorPairs[1] = newBlockchainReactor(log.NewTestingLogger(t), genDoc, privVals, 0) + for i := range reactorPairs { + height := int64(0) + if i == 0 { + height = maxBlockHeight + } - p2p.MakeConnectedSwitches(config.P2P, 2, func(i int, s *p2p.Switch) *p2p.Switch { - s.AddReactor("BLOCKCHAIN", reactorPairs[i].reactor) - return s - }, p2p.Connect2Switches) + reactorPairs[i] = newBlockchainReactor(log.NewTestingLogger(t), genDoc, privVals, height) + + options[i] = []p2p.SwitchOption{ + p2p.WithReactor("BLOCKCHAIN", reactorPairs[i].reactor), + } + } + + ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second) + defer cancelFn() + + testingCfg := p2pTesting.TestingConfig{ + Count: 2, + P2PCfg: config.P2P, + SwitchOptions: options, + Channels: 
[]byte{BlockchainChannel}, + } + + p2pTesting.MakeConnectedPeers(t, ctx, testingCfg) defer func() { for _, r := range reactorPairs { @@ -194,17 +216,35 @@ func TestFlappyBadBlockStopsPeer(t *testing.T) { otherChain.app.Stop() }() - reactorPairs := make([]BlockchainReactorPair, 4) + var ( + reactorPairs = make([]BlockchainReactorPair, 4) + options = make(map[int][]p2p.SwitchOption) + ) + + for i := range reactorPairs { + height := int64(0) + if i == 0 { + height = maxBlockHeight + } + + reactorPairs[i] = newBlockchainReactor(log.NewNoopLogger(), genDoc, privVals, height) - reactorPairs[0] = newBlockchainReactor(log.NewNoopLogger(), genDoc, privVals, maxBlockHeight) - reactorPairs[1] = newBlockchainReactor(log.NewNoopLogger(), genDoc, privVals, 0) - reactorPairs[2] = newBlockchainReactor(log.NewNoopLogger(), genDoc, privVals, 0) - reactorPairs[3] = newBlockchainReactor(log.NewNoopLogger(), genDoc, privVals, 0) + options[i] = []p2p.SwitchOption{ + p2p.WithReactor("BLOCKCHAIN", reactorPairs[i].reactor), + } + } - switches := p2p.MakeConnectedSwitches(config.P2P, 4, func(i int, s *p2p.Switch) *p2p.Switch { - s.AddReactor("BLOCKCHAIN", reactorPairs[i].reactor) - return s - }, p2p.Connect2Switches) + ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second) + defer cancelFn() + + testingCfg := p2pTesting.TestingConfig{ + Count: 4, + P2PCfg: config.P2P, + SwitchOptions: options, + Channels: []byte{BlockchainChannel}, + } + + switches, transports := p2pTesting.MakeConnectedPeers(t, ctx, testingCfg) defer func() { for _, r := range reactorPairs { @@ -222,7 +262,7 @@ func TestFlappyBadBlockStopsPeer(t *testing.T) { } // at this time, reactors[0-3] is the newest - assert.Equal(t, 3, reactorPairs[1].reactor.Switch.Peers().Size()) + assert.Equal(t, 3, len(reactorPairs[1].reactor.Switch.Peers().List())) // mark reactorPairs[3] is an invalid peer reactorPairs[3].reactor.store = otherChain.reactor.store @@ -230,24 +270,41 @@ func TestFlappyBadBlockStopsPeer(t 
*testing.T) { lastReactorPair := newBlockchainReactor(log.NewNoopLogger(), genDoc, privVals, 0) reactorPairs = append(reactorPairs, lastReactorPair) - switches = append(switches, p2p.MakeConnectedSwitches(config.P2P, 1, func(i int, s *p2p.Switch) *p2p.Switch { - s.AddReactor("BLOCKCHAIN", reactorPairs[len(reactorPairs)-1].reactor) - return s - }, p2p.Connect2Switches)...) + persistentPeers := make([]*p2pTypes.NetAddress, 0, len(transports)) - for i := 0; i < len(reactorPairs)-1; i++ { - p2p.Connect2Switches(switches, i, len(reactorPairs)-1) + for _, tr := range transports { + addr := tr.NetAddress() + persistentPeers = append(persistentPeers, &addr) } + for i, opt := range options { + opt = append(opt, p2p.WithPersistentPeers(persistentPeers)) + + options[i] = opt + } + + ctx, cancelFn = context.WithTimeout(context.Background(), 10*time.Second) + defer cancelFn() + + testingCfg = p2pTesting.TestingConfig{ + Count: 1, + P2PCfg: config.P2P, + SwitchOptions: map[int][]p2p.SwitchOption{0: {p2p.WithReactor("BLOCKCHAIN", lastReactorPair.reactor), p2p.WithPersistentPeers(persistentPeers)}}, + Channels: []byte{BlockchainChannel}, + } + + sw, _ := p2pTesting.MakeConnectedPeers(t, ctx, testingCfg) + switches = append(switches, sw...)
+ for { - if lastReactorPair.reactor.pool.IsCaughtUp() || lastReactorPair.reactor.Switch.Peers().Size() == 0 { + if lastReactorPair.reactor.pool.IsCaughtUp() || len(lastReactorPair.reactor.Switch.Peers().List()) == 0 { break } time.Sleep(1 * time.Second) } - assert.True(t, lastReactorPair.reactor.Switch.Peers().Size() < len(reactorPairs)-1) + assert.True(t, len(lastReactorPair.reactor.Switch.Peers().List()) < len(reactorPairs)-1) } func TestBcBlockRequestMessageValidateBasic(t *testing.T) { diff --git a/tm2/pkg/bft/config/config.go b/tm2/pkg/bft/config/config.go index f9e9a0cd899..1a01686f4bd 100644 --- a/tm2/pkg/bft/config/config.go +++ b/tm2/pkg/bft/config/config.go @@ -6,6 +6,7 @@ import ( "path/filepath" "regexp" "slices" + "time" "dario.cat/mergo" abci "github.com/gnolang/gno/tm2/pkg/bft/abci/types" @@ -17,6 +18,7 @@ import ( "github.com/gnolang/gno/tm2/pkg/errors" osm "github.com/gnolang/gno/tm2/pkg/os" p2p "github.com/gnolang/gno/tm2/pkg/p2p/config" + sdk "github.com/gnolang/gno/tm2/pkg/sdk/config" telemetry "github.com/gnolang/gno/tm2/pkg/telemetry/config" ) @@ -54,6 +56,7 @@ type Config struct { Consensus *cns.ConsensusConfig `json:"consensus" toml:"consensus" comment:"##### consensus configuration options #####"` TxEventStore *eventstore.Config `json:"tx_event_store" toml:"tx_event_store" comment:"##### event store #####"` Telemetry *telemetry.Config `json:"telemetry" toml:"telemetry" comment:"##### node telemetry #####"` + Application *sdk.AppConfig `json:"application" toml:"application" comment:"##### app settings #####"` } // DefaultConfig returns a default configuration for a Tendermint node @@ -66,6 +69,7 @@ func DefaultConfig() *Config { Consensus: cns.DefaultConsensusConfig(), TxEventStore: eventstore.DefaultEventStoreConfig(), Telemetry: telemetry.DefaultTelemetryConfig(), + Application: sdk.DefaultAppConfig(), } } @@ -163,16 +167,26 @@ func LoadOrMakeConfigWithOptions(root string, opts ...Option) (*Config, error) { return cfg, nil } +// 
testP2PConfig returns a configuration for testing the peer-to-peer layer +func testP2PConfig() *p2p.P2PConfig { + cfg := p2p.DefaultP2PConfig() + cfg.ListenAddress = "tcp://0.0.0.0:26656" + cfg.FlushThrottleTimeout = 10 * time.Millisecond + + return cfg +} + // TestConfig returns a configuration that can be used for testing func TestConfig() *Config { return &Config{ BaseConfig: testBaseConfig(), RPC: rpc.TestRPCConfig(), - P2P: p2p.TestP2PConfig(), + P2P: testP2PConfig(), Mempool: mem.TestMempoolConfig(), Consensus: cns.TestConsensusConfig(), TxEventStore: eventstore.DefaultEventStoreConfig(), Telemetry: telemetry.DefaultTelemetryConfig(), + Application: sdk.DefaultAppConfig(), } } @@ -228,6 +242,9 @@ func (cfg *Config) ValidateBasic() error { if err := cfg.Consensus.ValidateBasic(); err != nil { return errors.Wrap(err, "Error in [consensus] section") } + if err := cfg.Application.ValidateBasic(); err != nil { + return errors.Wrap(err, "Error in [application] section") + } return nil } @@ -318,10 +335,6 @@ type BaseConfig struct { // TCP or UNIX socket address for the profiling server to listen on ProfListenAddress string `toml:"prof_laddr" comment:"TCP or UNIX socket address for the profiling server to listen on"` - - // If true, query the ABCI app on connecting to a new peer - // so the app can decide if we should keep the connection or not - FilterPeers bool `toml:"filter_peers" comment:"If true, query the ABCI app on connecting to a new peer\n so the app can decide if we should keep the connection or not"` // false } // DefaultBaseConfig returns a default base configuration for a Tendermint node @@ -335,7 +348,6 @@ func DefaultBaseConfig() BaseConfig { ABCI: SocketABCI, ProfListenAddress: "", FastSyncMode: true, - FilterPeers: false, DBBackend: db.GoLevelDBBackend.String(), DBPath: DefaultDBDir, } @@ -372,6 +384,10 @@ func (cfg BaseConfig) NodeKeyFile() string { // DBDir returns the full path to the database directory func (cfg BaseConfig) DBDir() string { + 
if filepath.IsAbs(cfg.DBPath) { + return cfg.DBPath + } + return filepath.Join(cfg.RootDir, cfg.DBPath) } diff --git a/tm2/pkg/bft/config/config_test.go b/tm2/pkg/bft/config/config_test.go index 77f7c0d5e16..ea37e6e1763 100644 --- a/tm2/pkg/bft/config/config_test.go +++ b/tm2/pkg/bft/config/config_test.go @@ -185,3 +185,28 @@ func TestConfig_ValidateBaseConfig(t *testing.T) { assert.ErrorIs(t, c.BaseConfig.ValidateBasic(), errInvalidProfListenAddress) }) } + +func TestConfig_DBDir(t *testing.T) { + t.Parallel() + + t.Run("DB path is absolute", func(t *testing.T) { + t.Parallel() + + c := DefaultConfig() + c.RootDir = "/root" + c.DBPath = "/abs/path" + + assert.Equal(t, c.DBPath, c.DBDir()) + assert.NotEqual(t, filepath.Join(c.RootDir, c.DBPath), c.DBDir()) + }) + + t.Run("DB path is relative", func(t *testing.T) { + t.Parallel() + + c := DefaultConfig() + c.RootDir = "/root" + c.DBPath = "relative/path" + + assert.Equal(t, filepath.Join(c.RootDir, c.DBPath), c.DBDir()) + }) +} diff --git a/tm2/pkg/bft/consensus/reactor.go b/tm2/pkg/bft/consensus/reactor.go index aee695114f8..f39f012b289 100644 --- a/tm2/pkg/bft/consensus/reactor.go +++ b/tm2/pkg/bft/consensus/reactor.go @@ -35,7 +35,7 @@ const ( // ConsensusReactor defines a reactor for the consensus service. type ConsensusReactor struct { - p2p.BaseReactor // BaseService + p2p.Switch + p2p.BaseReactor // BaseService + p2p.MultiplexSwitch conS *ConsensusState @@ -157,7 +157,7 @@ func (conR *ConsensusReactor) GetChannels() []*p2p.ChannelDescriptor { } // InitPeer implements Reactor by creating a state for the peer. 
-func (conR *ConsensusReactor) InitPeer(peer p2p.Peer) p2p.Peer { +func (conR *ConsensusReactor) InitPeer(peer p2p.PeerConn) p2p.PeerConn { peerState := NewPeerState(peer).SetLogger(conR.Logger) peer.Set(types.PeerStateKey, peerState) return peer @@ -165,7 +165,7 @@ func (conR *ConsensusReactor) InitPeer(peer p2p.Peer) p2p.Peer { // AddPeer implements Reactor by spawning multiple gossiping goroutines for the // peer. -func (conR *ConsensusReactor) AddPeer(peer p2p.Peer) { +func (conR *ConsensusReactor) AddPeer(peer p2p.PeerConn) { if !conR.IsRunning() { return } @@ -187,7 +187,7 @@ func (conR *ConsensusReactor) AddPeer(peer p2p.Peer) { } // RemovePeer is a noop. -func (conR *ConsensusReactor) RemovePeer(peer p2p.Peer, reason interface{}) { +func (conR *ConsensusReactor) RemovePeer(peer p2p.PeerConn, reason interface{}) { if !conR.IsRunning() { return } @@ -205,7 +205,7 @@ func (conR *ConsensusReactor) RemovePeer(peer p2p.Peer, reason interface{}) { // Peer state updates can happen in parallel, but processing of // proposals, block parts, and votes are ordered by the receiveRoutine // NOTE: blocks on consensus state for proposals, block parts, and votes -func (conR *ConsensusReactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { +func (conR *ConsensusReactor) Receive(chID byte, src p2p.PeerConn, msgBytes []byte) { if !conR.IsRunning() { conR.Logger.Debug("Receive", "src", src, "chId", chID, "bytes", msgBytes) return @@ -417,7 +417,7 @@ func (conR *ConsensusReactor) broadcastHasVoteMessage(vote *types.Vote) { conR.Switch.Broadcast(StateChannel, amino.MustMarshalAny(msg)) /* // TODO: Make this broadcast more selective. 
- for _, peer := range conR.Switch.Peers().List() { + for _, peer := range conR.MultiplexSwitch.Peers().List() { ps, ok := peer.Get(PeerStateKey).(*PeerState) if !ok { panic(fmt.Sprintf("Peer %v has no state", peer)) @@ -446,13 +446,13 @@ func makeRoundStepMessage(event cstypes.EventNewRoundStep) (nrsMsg *NewRoundStep return } -func (conR *ConsensusReactor) sendNewRoundStepMessage(peer p2p.Peer) { +func (conR *ConsensusReactor) sendNewRoundStepMessage(peer p2p.PeerConn) { rs := conR.conS.GetRoundState() nrsMsg := makeRoundStepMessage(rs.EventNewRoundStep()) peer.Send(StateChannel, amino.MustMarshalAny(nrsMsg)) } -func (conR *ConsensusReactor) gossipDataRoutine(peer p2p.Peer, ps *PeerState) { +func (conR *ConsensusReactor) gossipDataRoutine(peer p2p.PeerConn, ps *PeerState) { logger := conR.Logger.With("peer", peer) OUTER_LOOP: @@ -547,7 +547,7 @@ OUTER_LOOP: } func (conR *ConsensusReactor) gossipDataForCatchup(logger *slog.Logger, rs *cstypes.RoundState, - prs *cstypes.PeerRoundState, ps *PeerState, peer p2p.Peer, + prs *cstypes.PeerRoundState, ps *PeerState, peer p2p.PeerConn, ) { if index, ok := prs.ProposalBlockParts.Not().PickRandom(); ok { // Ensure that the peer's PartSetHeader is correct @@ -589,7 +589,7 @@ func (conR *ConsensusReactor) gossipDataForCatchup(logger *slog.Logger, rs *csty time.Sleep(conR.conS.config.PeerGossipSleepDuration) } -func (conR *ConsensusReactor) gossipVotesRoutine(peer p2p.Peer, ps *PeerState) { +func (conR *ConsensusReactor) gossipVotesRoutine(peer p2p.PeerConn, ps *PeerState) { logger := conR.Logger.With("peer", peer) // Simple hack to throttle logs upon sleep. @@ -715,7 +715,7 @@ func (conR *ConsensusReactor) gossipVotesForHeight(logger *slog.Logger, rs *csty // NOTE: `queryMaj23Routine` has a simple crude design since it only comes // into play for liveness when there's a signature DDoS attack happening. 
-func (conR *ConsensusReactor) queryMaj23Routine(peer p2p.Peer, ps *PeerState) { +func (conR *ConsensusReactor) queryMaj23Routine(peer p2p.PeerConn, ps *PeerState) { logger := conR.Logger.With("peer", peer) OUTER_LOOP: @@ -826,12 +826,12 @@ func (conR *ConsensusReactor) peerStatsRoutine() { case *VoteMessage: if numVotes := ps.RecordVote(); numVotes%votesToContributeToBecomeGoodPeer == 0 { // TODO: peer metrics. - // conR.Switch.MarkPeerAsGood(peer) + // conR.MultiplexSwitch.MarkPeerAsGood(peer) } case *BlockPartMessage: if numParts := ps.RecordBlockPart(); numParts%blocksToContributeToBecomeGoodPeer == 0 { // TODO: peer metrics. - // conR.Switch.MarkPeerAsGood(peer) + // conR.MultiplexSwitch.MarkPeerAsGood(peer) } } case <-conR.conS.Quit(): @@ -878,7 +878,7 @@ var ( // NOTE: PeerStateExposed gets dumped with rpc/core/consensus.go. // Be mindful of what you Expose. type PeerState struct { - peer p2p.Peer + peer p2p.PeerConn logger *slog.Logger mtx sync.Mutex // NOTE: Modify below using setters, never directly. 
@@ -886,7 +886,7 @@ type PeerState struct { } // NewPeerState returns a new PeerState for the given Peer -func NewPeerState(peer p2p.Peer) *PeerState { +func NewPeerState(peer p2p.PeerConn) *PeerState { return &PeerState{ peer: peer, logger: log.NewNoopLogger(), diff --git a/tm2/pkg/bft/consensus/reactor_test.go b/tm2/pkg/bft/consensus/reactor_test.go index 42f944b7481..0e1d6249783 100644 --- a/tm2/pkg/bft/consensus/reactor_test.go +++ b/tm2/pkg/bft/consensus/reactor_test.go @@ -1,14 +1,13 @@ package consensus import ( + "context" "fmt" "log/slog" "sync" "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/gnolang/gno/tm2/pkg/amino" "github.com/gnolang/gno/tm2/pkg/bft/abci/example/kvstore" cfg "github.com/gnolang/gno/tm2/pkg/bft/config" @@ -18,27 +17,39 @@ import ( "github.com/gnolang/gno/tm2/pkg/bitarray" "github.com/gnolang/gno/tm2/pkg/crypto/tmhash" "github.com/gnolang/gno/tm2/pkg/events" + p2pTesting "github.com/gnolang/gno/tm2/pkg/internal/p2p" "github.com/gnolang/gno/tm2/pkg/log" osm "github.com/gnolang/gno/tm2/pkg/os" "github.com/gnolang/gno/tm2/pkg/p2p" - "github.com/gnolang/gno/tm2/pkg/p2p/mock" "github.com/gnolang/gno/tm2/pkg/testutils" + "github.com/stretchr/testify/assert" ) // ---------------------------------------------- // in-process testnets -func startConsensusNet(css []*ConsensusState, n int) ([]*ConsensusReactor, []<-chan events.Event, []events.EventSwitch, []*p2p.Switch) { +func startConsensusNet( + t *testing.T, + css []*ConsensusState, + n int, +) ([]*ConsensusReactor, []<-chan events.Event, []events.EventSwitch, []*p2p.MultiplexSwitch) { + t.Helper() + reactors := make([]*ConsensusReactor, n) blocksSubs := make([]<-chan events.Event, 0) eventSwitches := make([]events.EventSwitch, n) - p2pSwitches := ([]*p2p.Switch)(nil) + p2pSwitches := ([]*p2p.MultiplexSwitch)(nil) + options := make(map[int][]p2p.SwitchOption) for i := 0; i < n; i++ { /*logger, err := tmflags.ParseLogLevel("consensus:info,*:error", logger, "info") if err != 
nil { t.Fatal(err)}*/ reactors[i] = NewConsensusReactor(css[i], true) // so we dont start the consensus states reactors[i].SetLogger(css[i].Logger) + options[i] = []p2p.SwitchOption{ + p2p.WithReactor("CONSENSUS", reactors[i]), + } + // evsw is already started with the cs eventSwitches[i] = css[i].evsw reactors[i].SetEventSwitch(eventSwitches[i]) @@ -51,11 +62,22 @@ func startConsensusNet(css []*ConsensusState, n int) ([]*ConsensusReactor, []<-c } } // make connected switches and start all reactors - p2pSwitches = p2p.MakeConnectedSwitches(config.P2P, n, func(i int, s *p2p.Switch) *p2p.Switch { - s.AddReactor("CONSENSUS", reactors[i]) - s.SetLogger(reactors[i].conS.Logger.With("module", "p2p")) - return s - }, p2p.Connect2Switches) + ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second) + defer cancelFn() + + testingCfg := p2pTesting.TestingConfig{ + P2PCfg: config.P2P, + Count: n, + SwitchOptions: options, + Channels: []byte{ + StateChannel, + DataChannel, + VoteChannel, + VoteSetBitsChannel, + }, + } + + p2pSwitches, _ = p2pTesting.MakeConnectedPeers(t, ctx, testingCfg) // now that everyone is connected, start the state machines // If we started the state machines before everyone was connected, @@ -68,11 +90,15 @@ func startConsensusNet(css []*ConsensusState, n int) ([]*ConsensusReactor, []<-c return reactors, blocksSubs, eventSwitches, p2pSwitches } -func stopConsensusNet(logger *slog.Logger, reactors []*ConsensusReactor, eventSwitches []events.EventSwitch, p2pSwitches []*p2p.Switch) { +func stopConsensusNet( + logger *slog.Logger, + reactors []*ConsensusReactor, + eventSwitches []events.EventSwitch, + p2pSwitches []*p2p.MultiplexSwitch, +) { logger.Info("stopConsensusNet", "n", len(reactors)) - for i, r := range reactors { + for i := range reactors { logger.Info("stopConsensusNet: Stopping ConsensusReactor", "i", i) - r.Switch.Stop() } for i, b := range eventSwitches { logger.Info("stopConsensusNet: Stopping evsw", "i", i) @@ -92,7 +118,7 
@@ func TestReactorBasic(t *testing.T) { N := 4 css, cleanup := randConsensusNet(N, "consensus_reactor_test", newMockTickerFunc(true), newCounter) defer cleanup() - reactors, blocksSubs, eventSwitches, p2pSwitches := startConsensusNet(css, N) + reactors, blocksSubs, eventSwitches, p2pSwitches := startConsensusNet(t, css, N) defer stopConsensusNet(log.NewTestingLogger(t), reactors, eventSwitches, p2pSwitches) // wait till everyone makes the first new block timeoutWaitGroup(t, N, func(j int) { @@ -112,7 +138,7 @@ func TestReactorCreatesBlockWhenEmptyBlocksFalse(t *testing.T) { c.Consensus.CreateEmptyBlocks = false }) defer cleanup() - reactors, blocksSubs, eventSwitches, p2pSwitches := startConsensusNet(css, N) + reactors, blocksSubs, eventSwitches, p2pSwitches := startConsensusNet(t, css, N) defer stopConsensusNet(log.NewTestingLogger(t), reactors, eventSwitches, p2pSwitches) // send a tx @@ -132,12 +158,12 @@ func TestReactorReceiveDoesNotPanicIfAddPeerHasntBeenCalledYet(t *testing.T) { N := 1 css, cleanup := randConsensusNet(N, "consensus_reactor_test", newMockTickerFunc(true), newCounter) defer cleanup() - reactors, _, eventSwitches, p2pSwitches := startConsensusNet(css, N) + reactors, _, eventSwitches, p2pSwitches := startConsensusNet(t, css, N) defer stopConsensusNet(log.NewTestingLogger(t), reactors, eventSwitches, p2pSwitches) var ( reactor = reactors[0] - peer = mock.NewPeer(nil) + peer = p2pTesting.NewPeer(t) msg = amino.MustMarshalAny(&HasVoteMessage{Height: 1, Round: 1, Index: 1, Type: types.PrevoteType}) ) @@ -156,12 +182,12 @@ func TestReactorReceivePanicsIfInitPeerHasntBeenCalledYet(t *testing.T) { N := 1 css, cleanup := randConsensusNet(N, "consensus_reactor_test", newMockTickerFunc(true), newCounter) defer cleanup() - reactors, _, eventSwitches, p2pSwitches := startConsensusNet(css, N) + reactors, _, eventSwitches, p2pSwitches := startConsensusNet(t, css, N) defer stopConsensusNet(log.NewTestingLogger(t), reactors, eventSwitches, p2pSwitches) var ( 
reactor = reactors[0] - peer = mock.NewPeer(nil) + peer = p2pTesting.NewPeer(t) msg = amino.MustMarshalAny(&HasVoteMessage{Height: 1, Round: 1, Index: 1, Type: types.PrevoteType}) ) @@ -182,7 +208,7 @@ func TestFlappyReactorRecordsVotesAndBlockParts(t *testing.T) { N := 4 css, cleanup := randConsensusNet(N, "consensus_reactor_test", newMockTickerFunc(true), newCounter) defer cleanup() - reactors, blocksSubs, eventSwitches, p2pSwitches := startConsensusNet(css, N) + reactors, blocksSubs, eventSwitches, p2pSwitches := startConsensusNet(t, css, N) defer stopConsensusNet(log.NewTestingLogger(t), reactors, eventSwitches, p2pSwitches) // wait till everyone makes the first new block @@ -210,7 +236,7 @@ func TestReactorVotingPowerChange(t *testing.T) { css, cleanup := randConsensusNet(nVals, "consensus_voting_power_changes_test", newMockTickerFunc(true), newPersistentKVStore) defer cleanup() - reactors, blocksSubs, eventSwitches, p2pSwitches := startConsensusNet(css, nVals) + reactors, blocksSubs, eventSwitches, p2pSwitches := startConsensusNet(t, css, nVals) defer stopConsensusNet(logger, reactors, eventSwitches, p2pSwitches) // map of active validators @@ -276,7 +302,7 @@ func TestReactorValidatorSetChanges(t *testing.T) { logger := log.NewTestingLogger(t) - reactors, blocksSubs, eventSwitches, p2pSwitches := startConsensusNet(css, nPeers) + reactors, blocksSubs, eventSwitches, p2pSwitches := startConsensusNet(t, css, nPeers) defer stopConsensusNet(logger, reactors, eventSwitches, p2pSwitches) // map of active validators @@ -375,7 +401,7 @@ func TestReactorWithTimeoutCommit(t *testing.T) { css[i].config.SkipTimeoutCommit = false } - reactors, blocksSubs, eventSwitches, p2pSwitches := startConsensusNet(css, N-1) + reactors, blocksSubs, eventSwitches, p2pSwitches := startConsensusNet(t, css, N-1) defer stopConsensusNet(log.NewTestingLogger(t), reactors, eventSwitches, p2pSwitches) // wait till everyone makes the first new block diff --git a/tm2/pkg/bft/consensus/state.go 
b/tm2/pkg/bft/consensus/state.go index 8b2653813e3..d9c78ec1bdf 100644 --- a/tm2/pkg/bft/consensus/state.go +++ b/tm2/pkg/bft/consensus/state.go @@ -23,7 +23,7 @@ import ( "github.com/gnolang/gno/tm2/pkg/errors" "github.com/gnolang/gno/tm2/pkg/events" osm "github.com/gnolang/gno/tm2/pkg/os" - "github.com/gnolang/gno/tm2/pkg/p2p" + p2pTypes "github.com/gnolang/gno/tm2/pkg/p2p/types" "github.com/gnolang/gno/tm2/pkg/service" "github.com/gnolang/gno/tm2/pkg/telemetry" "github.com/gnolang/gno/tm2/pkg/telemetry/metrics" @@ -53,7 +53,7 @@ type newRoundStepInfo struct { // msgs from the reactor which may update the state type msgInfo struct { Msg ConsensusMessage `json:"msg"` - PeerID p2p.ID `json:"peer_key"` + PeerID p2pTypes.ID `json:"peer_key"` } // WAL message. @@ -399,7 +399,7 @@ func (cs *ConsensusState) OpenWAL(walFile string) (walm.WAL, error) { // TODO: should these return anything or let callers just use events? // AddVote inputs a vote. -func (cs *ConsensusState) AddVote(vote *types.Vote, peerID p2p.ID) (added bool, err error) { +func (cs *ConsensusState) AddVote(vote *types.Vote, peerID p2pTypes.ID) (added bool, err error) { if peerID == "" { cs.internalMsgQueue <- msgInfo{&VoteMessage{vote}, ""} } else { @@ -411,7 +411,7 @@ func (cs *ConsensusState) AddVote(vote *types.Vote, peerID p2p.ID) (added bool, } // SetProposal inputs a proposal. -func (cs *ConsensusState) SetProposal(proposal *types.Proposal, peerID p2p.ID) error { +func (cs *ConsensusState) SetProposal(proposal *types.Proposal, peerID p2pTypes.ID) error { if peerID == "" { cs.internalMsgQueue <- msgInfo{&ProposalMessage{proposal}, ""} } else { @@ -423,7 +423,7 @@ func (cs *ConsensusState) SetProposal(proposal *types.Proposal, peerID p2p.ID) e } // AddProposalBlockPart inputs a part of the proposal block. 
-func (cs *ConsensusState) AddProposalBlockPart(height int64, round int, part *types.Part, peerID p2p.ID) error { +func (cs *ConsensusState) AddProposalBlockPart(height int64, round int, part *types.Part, peerID p2pTypes.ID) error { if peerID == "" { cs.internalMsgQueue <- msgInfo{&BlockPartMessage{height, round, part}, ""} } else { @@ -435,7 +435,7 @@ func (cs *ConsensusState) AddProposalBlockPart(height int64, round int, part *ty } // SetProposalAndBlock inputs the proposal and all block parts. -func (cs *ConsensusState) SetProposalAndBlock(proposal *types.Proposal, block *types.Block, parts *types.PartSet, peerID p2p.ID) error { +func (cs *ConsensusState) SetProposalAndBlock(proposal *types.Proposal, block *types.Block, parts *types.PartSet, peerID p2pTypes.ID) error { if err := cs.SetProposal(proposal, peerID); err != nil { return err } @@ -1444,7 +1444,7 @@ func (cs *ConsensusState) defaultSetProposal(proposal *types.Proposal) error { // NOTE: block is not necessarily valid. // Asynchronously triggers either enterPrevote (before we timeout of propose) or tryFinalizeCommit, once we have the full block. -func (cs *ConsensusState) addProposalBlockPart(msg *BlockPartMessage, peerID p2p.ID) (added bool, err error) { +func (cs *ConsensusState) addProposalBlockPart(msg *BlockPartMessage, peerID p2pTypes.ID) (added bool, err error) { height, round, part := msg.Height, msg.Round, msg.Part // Blocks might be reused, so round mismatch is OK @@ -1514,7 +1514,7 @@ func (cs *ConsensusState) addProposalBlockPart(msg *BlockPartMessage, peerID p2p } // Attempt to add the vote. 
if its a duplicate signature, dupeout the validator -func (cs *ConsensusState) tryAddVote(vote *types.Vote, peerID p2p.ID) (bool, error) { +func (cs *ConsensusState) tryAddVote(vote *types.Vote, peerID p2pTypes.ID) (bool, error) { added, err := cs.addVote(vote, peerID) if err != nil { // If the vote height is off, we'll just ignore it, @@ -1547,7 +1547,7 @@ func (cs *ConsensusState) tryAddVote(vote *types.Vote, peerID p2p.ID) (bool, err // ----------------------------------------------------------------------------- -func (cs *ConsensusState) addVote(vote *types.Vote, peerID p2p.ID) (added bool, err error) { +func (cs *ConsensusState) addVote(vote *types.Vote, peerID p2pTypes.ID) (added bool, err error) { cs.Logger.Debug("addVote", "voteHeight", vote.Height, "voteType", vote.Type, "valIndex", vote.ValidatorIndex, "csHeight", cs.Height) // A precommit for the previous height? diff --git a/tm2/pkg/bft/consensus/state_test.go b/tm2/pkg/bft/consensus/state_test.go index 201cf8906b3..4f4b3e3eb05 100644 --- a/tm2/pkg/bft/consensus/state_test.go +++ b/tm2/pkg/bft/consensus/state_test.go @@ -1733,7 +1733,7 @@ func TestStateOutputsBlockPartsStats(t *testing.T) { // create dummy peer cs, _ := randConsensusState(1) - peer := p2pmock.NewPeer(nil) + peer := p2pmock.Peer{} // 1) new block part parts := types.NewPartSetFromData(random.RandBytes(100), 10) @@ -1777,7 +1777,7 @@ func TestStateOutputVoteStats(t *testing.T) { cs, vss := randConsensusState(2) // create dummy peer - peer := p2pmock.NewPeer(nil) + peer := p2pmock.Peer{} vote := signVote(vss[1], types.PrecommitType, []byte("test"), types.PartSetHeader{}) diff --git a/tm2/pkg/bft/consensus/types/height_vote_set.go b/tm2/pkg/bft/consensus/types/height_vote_set.go index b81937ebd1e..7f3d52022ad 100644 --- a/tm2/pkg/bft/consensus/types/height_vote_set.go +++ b/tm2/pkg/bft/consensus/types/height_vote_set.go @@ -8,7 +8,7 @@ import ( "github.com/gnolang/gno/tm2/pkg/amino" "github.com/gnolang/gno/tm2/pkg/bft/types" - 
"github.com/gnolang/gno/tm2/pkg/p2p" + p2pTypes "github.com/gnolang/gno/tm2/pkg/p2p/types" ) type RoundVoteSet struct { @@ -39,9 +39,9 @@ type HeightVoteSet struct { valSet *types.ValidatorSet mtx sync.Mutex - round int // max tracked round - roundVoteSets map[int]RoundVoteSet // keys: [0...round] - peerCatchupRounds map[p2p.ID][]int // keys: peer.ID; values: at most 2 rounds + round int // max tracked round + roundVoteSets map[int]RoundVoteSet // keys: [0...round] + peerCatchupRounds map[p2pTypes.ID][]int // keys: peer.ID; values: at most 2 rounds } func NewHeightVoteSet(chainID string, height int64, valSet *types.ValidatorSet) *HeightVoteSet { @@ -59,7 +59,7 @@ func (hvs *HeightVoteSet) Reset(height int64, valSet *types.ValidatorSet) { hvs.height = height hvs.valSet = valSet hvs.roundVoteSets = make(map[int]RoundVoteSet) - hvs.peerCatchupRounds = make(map[p2p.ID][]int) + hvs.peerCatchupRounds = make(map[p2pTypes.ID][]int) hvs.addRound(0) hvs.round = 0 @@ -108,7 +108,7 @@ func (hvs *HeightVoteSet) addRound(round int) { // Duplicate votes return added=false, err=nil. // By convention, peerID is "" if origin is self. -func (hvs *HeightVoteSet) AddVote(vote *types.Vote, peerID p2p.ID) (added bool, err error) { +func (hvs *HeightVoteSet) AddVote(vote *types.Vote, peerID p2pTypes.ID) (added bool, err error) { hvs.mtx.Lock() defer hvs.mtx.Unlock() if !types.IsVoteTypeValid(vote.Type) { @@ -176,7 +176,7 @@ func (hvs *HeightVoteSet) getVoteSet(round int, type_ types.SignedMsgType) *type // NOTE: if there are too many peers, or too much peer churn, // this can cause memory issues. 
// TODO: implement ability to remove peers too -func (hvs *HeightVoteSet) SetPeerMaj23(round int, type_ types.SignedMsgType, peerID p2p.ID, blockID types.BlockID) error { +func (hvs *HeightVoteSet) SetPeerMaj23(round int, type_ types.SignedMsgType, peerID p2pTypes.ID, blockID types.BlockID) error { hvs.mtx.Lock() defer hvs.mtx.Unlock() if !types.IsVoteTypeValid(type_) { diff --git a/tm2/pkg/bft/mempool/reactor.go b/tm2/pkg/bft/mempool/reactor.go index 3ef85b80a21..cf253999fb3 100644 --- a/tm2/pkg/bft/mempool/reactor.go +++ b/tm2/pkg/bft/mempool/reactor.go @@ -13,6 +13,7 @@ import ( "github.com/gnolang/gno/tm2/pkg/bft/types" "github.com/gnolang/gno/tm2/pkg/clist" "github.com/gnolang/gno/tm2/pkg/p2p" + p2pTypes "github.com/gnolang/gno/tm2/pkg/p2p/types" ) const ( @@ -39,25 +40,25 @@ type Reactor struct { type mempoolIDs struct { mtx sync.RWMutex - peerMap map[p2p.ID]uint16 + peerMap map[p2pTypes.ID]uint16 nextID uint16 // assumes that a node will never have over 65536 active peers - activeIDs map[uint16]struct{} // used to check if a given peerID key is used, the value doesn't matter + activeIDs map[uint16]struct{} // used to check if a given mempoolID key is used, the value doesn't matter } // Reserve searches for the next unused ID and assigns it to the // peer. -func (ids *mempoolIDs) ReserveForPeer(peer p2p.Peer) { +func (ids *mempoolIDs) ReserveForPeer(id p2pTypes.ID) { ids.mtx.Lock() defer ids.mtx.Unlock() - curID := ids.nextPeerID() - ids.peerMap[peer.ID()] = curID + curID := ids.nextMempoolPeerID() + ids.peerMap[id] = curID ids.activeIDs[curID] = struct{}{} } -// nextPeerID returns the next unused peer ID to use. +// nextMempoolPeerID returns the next unused peer ID to use. // This assumes that ids's mutex is already locked. 
-func (ids *mempoolIDs) nextPeerID() uint16 { +func (ids *mempoolIDs) nextMempoolPeerID() uint16 { if len(ids.activeIDs) == maxActiveIDs { panic(fmt.Sprintf("node has maximum %d active IDs and wanted to get one more", maxActiveIDs)) } @@ -73,28 +74,28 @@ func (ids *mempoolIDs) nextPeerID() uint16 { } // Reclaim returns the ID reserved for the peer back to unused pool. -func (ids *mempoolIDs) Reclaim(peer p2p.Peer) { +func (ids *mempoolIDs) Reclaim(id p2pTypes.ID) { ids.mtx.Lock() defer ids.mtx.Unlock() - removedID, ok := ids.peerMap[peer.ID()] + removedID, ok := ids.peerMap[id] if ok { delete(ids.activeIDs, removedID) - delete(ids.peerMap, peer.ID()) + delete(ids.peerMap, id) } } // GetForPeer returns an ID reserved for the peer. -func (ids *mempoolIDs) GetForPeer(peer p2p.Peer) uint16 { +func (ids *mempoolIDs) GetForPeer(id p2pTypes.ID) uint16 { ids.mtx.RLock() defer ids.mtx.RUnlock() - return ids.peerMap[peer.ID()] + return ids.peerMap[id] } func newMempoolIDs() *mempoolIDs { return &mempoolIDs{ - peerMap: make(map[p2p.ID]uint16), + peerMap: make(map[p2pTypes.ID]uint16), activeIDs: map[uint16]struct{}{0: {}}, nextID: 1, // reserve unknownPeerID(0) for mempoolReactor.BroadcastTx } @@ -138,20 +139,20 @@ func (memR *Reactor) GetChannels() []*p2p.ChannelDescriptor { // AddPeer implements Reactor. // It starts a broadcast routine ensuring all txs are forwarded to the given peer. -func (memR *Reactor) AddPeer(peer p2p.Peer) { - memR.ids.ReserveForPeer(peer) +func (memR *Reactor) AddPeer(peer p2p.PeerConn) { + memR.ids.ReserveForPeer(peer.ID()) go memR.broadcastTxRoutine(peer) } // RemovePeer implements Reactor. -func (memR *Reactor) RemovePeer(peer p2p.Peer, reason interface{}) { - memR.ids.Reclaim(peer) +func (memR *Reactor) RemovePeer(peer p2p.PeerConn, reason interface{}) { + memR.ids.Reclaim(peer.ID()) // broadcast routine checks if peer is gone and returns } // Receive implements Reactor. // It adds any received transactions to the mempool. 
-func (memR *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { +func (memR *Reactor) Receive(chID byte, src p2p.PeerConn, msgBytes []byte) { msg, err := memR.decodeMsg(msgBytes) if err != nil { memR.Logger.Error("Error decoding mempool message", "src", src, "chId", chID, "msg", msg, "err", err, "bytes", msgBytes) @@ -162,8 +163,8 @@ func (memR *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { switch msg := msg.(type) { case *TxMessage: - peerID := memR.ids.GetForPeer(src) - err := memR.mempool.CheckTxWithInfo(msg.Tx, nil, TxInfo{SenderID: peerID}) + mempoolID := memR.ids.GetForPeer(src.ID()) + err := memR.mempool.CheckTxWithInfo(msg.Tx, nil, TxInfo{SenderID: mempoolID}) if err != nil { memR.Logger.Info("Could not check tx", "tx", txID(msg.Tx), "err", err) } @@ -179,12 +180,12 @@ type PeerState interface { } // Send new mempool txs to peer. -func (memR *Reactor) broadcastTxRoutine(peer p2p.Peer) { +func (memR *Reactor) broadcastTxRoutine(peer p2p.PeerConn) { if !memR.config.Broadcast { return } - peerID := memR.ids.GetForPeer(peer) + mempoolID := memR.ids.GetForPeer(peer.ID()) var next *clist.CElement for { // In case of both next.NextWaitChan() and peer.Quit() are variable at the same time @@ -213,7 +214,7 @@ func (memR *Reactor) broadcastTxRoutine(peer p2p.Peer) { peerState, ok := peer.Get(types.PeerStateKey).(PeerState) if !ok { // Peer does not have a state yet. We set it in the consensus reactor, but - // when we add peer in Switch, the order we call reactors#AddPeer is + // when we add peer in MultiplexSwitch, the order we call reactors#AddPeer is // different every time due to us using a map. Sometimes other reactors // will be initialized before the consensus reactor. We should wait a few // milliseconds and retry. 
@@ -226,7 +227,7 @@ func (memR *Reactor) broadcastTxRoutine(peer p2p.Peer) { } // ensure peer hasn't already sent us this tx - if _, ok := memTx.senders.Load(peerID); !ok { + if _, ok := memTx.senders.Load(mempoolID); !ok { // send memTx msg := &TxMessage{Tx: memTx.tx} success := peer.Send(MempoolChannel, amino.MustMarshalAny(msg)) diff --git a/tm2/pkg/bft/mempool/reactor_test.go b/tm2/pkg/bft/mempool/reactor_test.go index e7a3c43a6b9..2d20fb252e2 100644 --- a/tm2/pkg/bft/mempool/reactor_test.go +++ b/tm2/pkg/bft/mempool/reactor_test.go @@ -1,26 +1,36 @@ package mempool import ( - "net" + "context" + "fmt" "sync" "testing" "time" "github.com/fortytw2/leaktest" - "github.com/stretchr/testify/assert" - "github.com/gnolang/gno/tm2/pkg/bft/abci/example/kvstore" memcfg "github.com/gnolang/gno/tm2/pkg/bft/mempool/config" "github.com/gnolang/gno/tm2/pkg/bft/proxy" "github.com/gnolang/gno/tm2/pkg/bft/types" "github.com/gnolang/gno/tm2/pkg/errors" + p2pTesting "github.com/gnolang/gno/tm2/pkg/internal/p2p" "github.com/gnolang/gno/tm2/pkg/log" "github.com/gnolang/gno/tm2/pkg/p2p" p2pcfg "github.com/gnolang/gno/tm2/pkg/p2p/config" - "github.com/gnolang/gno/tm2/pkg/p2p/mock" + p2pTypes "github.com/gnolang/gno/tm2/pkg/p2p/types" "github.com/gnolang/gno/tm2/pkg/testutils" + "github.com/stretchr/testify/assert" ) +// testP2PConfig returns a configuration for testing the peer-to-peer layer +func testP2PConfig() *p2pcfg.P2PConfig { + cfg := p2pcfg.DefaultP2PConfig() + cfg.ListenAddress = "tcp://0.0.0.0:26656" + cfg.FlushThrottleTimeout = 10 * time.Millisecond + + return cfg +} + type peerState struct { height int64 } @@ -30,65 +40,108 @@ func (ps peerState) GetHeight() int64 { } // connect N mempool reactors through N switches -func makeAndConnectReactors(mconfig *memcfg.MempoolConfig, pconfig *p2pcfg.P2PConfig, n int) []*Reactor { - reactors := make([]*Reactor, n) - logger := log.NewNoopLogger() +func makeAndConnectReactors(t *testing.T, mconfig *memcfg.MempoolConfig, pconfig 
*p2pcfg.P2PConfig, n int) []*Reactor { + t.Helper() + + var ( + reactors = make([]*Reactor, n) + logger = log.NewNoopLogger() + options = make(map[int][]p2p.SwitchOption) + ) + for i := 0; i < n; i++ { app := kvstore.NewKVStoreApplication() cc := proxy.NewLocalClientCreator(app) mempool, cleanup := newMempoolWithApp(cc) defer cleanup() - reactors[i] = NewReactor(mconfig, mempool) // so we dont start the consensus states - reactors[i].SetLogger(logger.With("validator", i)) + reactor := NewReactor(mconfig, mempool) // so we dont start the consensus states + reactor.SetLogger(logger.With("validator", i)) + + options[i] = []p2p.SwitchOption{ + p2p.WithReactor("MEMPOOL", reactor), + } + + reactors[i] = reactor + } + + // "Simulate" the networking layer + ctx, cancelFn := context.WithTimeout(context.Background(), 10*time.Second) + defer cancelFn() + + cfg := p2pTesting.TestingConfig{ + Count: n, + P2PCfg: pconfig, + SwitchOptions: options, + Channels: []byte{MempoolChannel}, } - p2p.MakeConnectedSwitches(pconfig, n, func(i int, s *p2p.Switch) *p2p.Switch { - s.AddReactor("MEMPOOL", reactors[i]) - return s - }, p2p.Connect2Switches) + p2pTesting.MakeConnectedPeers(t, ctx, cfg) + return reactors } -func waitForTxsOnReactors(t *testing.T, txs types.Txs, reactors []*Reactor) { +func waitForTxsOnReactors( + t *testing.T, + txs types.Txs, + reactors []*Reactor, +) { t.Helper() - // wait for the txs in all mempools - wg := new(sync.WaitGroup) + ctx, cancelFn := context.WithTimeout(context.Background(), 30*time.Second) + defer cancelFn() + + // Wait for the txs to propagate in all mempools + var wg sync.WaitGroup + for i, reactor := range reactors { wg.Add(1) + go func(r *Reactor, reactorIndex int) { defer wg.Done() - waitForTxsOnReactor(t, txs, r, reactorIndex) + + reapedTxs := waitForTxsOnReactor(t, ctx, len(txs), r) + + for i, tx := range txs { + assert.Equalf(t, tx, reapedTxs[i], + fmt.Sprintf( + "txs at index %d on reactor %d don't match: %v vs %v", + i, reactorIndex, + tx, 
+ reapedTxs[i], + ), + ) + } }(reactor, i) } - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() - - timer := time.After(timeout) - select { - case <-timer: - t.Fatal("Timed out waiting for txs") - case <-done: - } + wg.Wait() } -func waitForTxsOnReactor(t *testing.T, txs types.Txs, reactor *Reactor, reactorIndex int) { +func waitForTxsOnReactor( + t *testing.T, + ctx context.Context, + expectedLength int, + reactor *Reactor, +) types.Txs { t.Helper() - mempool := reactor.mempool - for mempool.Size() < len(txs) { - time.Sleep(time.Millisecond * 100) - } - - reapedTxs := mempool.ReapMaxTxs(len(txs)) - for i, tx := range txs { - assert.Equalf(t, tx, reapedTxs[i], - "txs at index %d on reactor %d don't match: %v vs %v", i, reactorIndex, tx, reapedTxs[i]) + var ( + mempool = reactor.mempool + ticker = time.NewTicker(100 * time.Millisecond) + ) + + for { + select { + case <-ctx.Done(): + t.Fatal("timed out waiting for txs") + case <-ticker.C: + if mempool.Size() < expectedLength { + continue + } + + return mempool.ReapMaxTxs(expectedLength) + } } } @@ -100,32 +153,29 @@ func ensureNoTxs(t *testing.T, reactor *Reactor, timeout time.Duration) { assert.Zero(t, reactor.mempool.Size()) } -const ( - numTxs = 1000 - timeout = 120 * time.Second // ridiculously high because CircleCI is slow -) - func TestReactorBroadcastTxMessage(t *testing.T) { t.Parallel() mconfig := memcfg.TestMempoolConfig() - pconfig := p2pcfg.TestP2PConfig() + pconfig := testP2PConfig() const N = 4 - reactors := makeAndConnectReactors(mconfig, pconfig, N) - defer func() { + reactors := makeAndConnectReactors(t, mconfig, pconfig, N) + t.Cleanup(func() { for _, r := range reactors { - r.Stop() + assert.NoError(t, r.Stop()) } - }() + }) + for _, r := range reactors { for _, peer := range r.Switch.Peers().List() { + fmt.Printf("Setting peer %s\n", peer.ID()) peer.Set(types.PeerStateKey, peerState{1}) } } // send a bunch of txs to the first reactor's mempool // and wait for them all to 
be received in the others - txs := checkTxs(t, reactors[0].mempool, numTxs, UnknownPeerID, true) + txs := checkTxs(t, reactors[0].mempool, 1000, UnknownPeerID, true) waitForTxsOnReactors(t, txs, reactors) } @@ -133,9 +183,9 @@ func TestReactorNoBroadcastToSender(t *testing.T) { t.Parallel() mconfig := memcfg.TestMempoolConfig() - pconfig := p2pcfg.TestP2PConfig() + pconfig := testP2PConfig() const N = 2 - reactors := makeAndConnectReactors(mconfig, pconfig, N) + reactors := makeAndConnectReactors(t, mconfig, pconfig, N) defer func() { for _, r := range reactors { r.Stop() @@ -144,7 +194,7 @@ func TestReactorNoBroadcastToSender(t *testing.T) { // send a bunch of txs to the first reactor's mempool, claiming it came from peer // ensure peer gets no txs - checkTxs(t, reactors[0].mempool, numTxs, 1, true) + checkTxs(t, reactors[0].mempool, 1000, 1, true) ensureNoTxs(t, reactors[1], 100*time.Millisecond) } @@ -158,9 +208,9 @@ func TestFlappyBroadcastTxForPeerStopsWhenPeerStops(t *testing.T) { } mconfig := memcfg.TestMempoolConfig() - pconfig := p2pcfg.TestP2PConfig() + pconfig := testP2PConfig() const N = 2 - reactors := makeAndConnectReactors(mconfig, pconfig, N) + reactors := makeAndConnectReactors(t, mconfig, pconfig, N) defer func() { for _, r := range reactors { r.Stop() @@ -186,9 +236,9 @@ func TestFlappyBroadcastTxForPeerStopsWhenReactorStops(t *testing.T) { } mconfig := memcfg.TestMempoolConfig() - pconfig := p2pcfg.TestP2PConfig() + pconfig := testP2PConfig() const N = 2 - reactors := makeAndConnectReactors(mconfig, pconfig, N) + reactors := makeAndConnectReactors(t, mconfig, pconfig, N) // stop reactors for _, r := range reactors { @@ -205,15 +255,15 @@ func TestMempoolIDsBasic(t *testing.T) { ids := newMempoolIDs() - peer := mock.NewPeer(net.IP{127, 0, 0, 1}) + id := p2pTypes.GenerateNodeKey().ID() - ids.ReserveForPeer(peer) - assert.EqualValues(t, 1, ids.GetForPeer(peer)) - ids.Reclaim(peer) + ids.ReserveForPeer(id) + assert.EqualValues(t, 1, 
ids.GetForPeer(id)) + ids.Reclaim(id) - ids.ReserveForPeer(peer) - assert.EqualValues(t, 2, ids.GetForPeer(peer)) - ids.Reclaim(peer) + ids.ReserveForPeer(id) + assert.EqualValues(t, 2, ids.GetForPeer(id)) + ids.Reclaim(id) } func TestMempoolIDsPanicsIfNodeRequestsOvermaxActiveIDs(t *testing.T) { @@ -227,12 +277,13 @@ func TestMempoolIDsPanicsIfNodeRequestsOvermaxActiveIDs(t *testing.T) { ids := newMempoolIDs() for i := 0; i < maxActiveIDs-1; i++ { - peer := mock.NewPeer(net.IP{127, 0, 0, 1}) - ids.ReserveForPeer(peer) + id := p2pTypes.GenerateNodeKey().ID() + ids.ReserveForPeer(id) } assert.Panics(t, func() { - peer := mock.NewPeer(net.IP{127, 0, 0, 1}) - ids.ReserveForPeer(peer) + id := p2pTypes.GenerateNodeKey().ID() + + ids.ReserveForPeer(id) }) } diff --git a/tm2/pkg/bft/node/node.go b/tm2/pkg/bft/node/node.go index e29de3dd1ae..c1afb2996fa 100644 --- a/tm2/pkg/bft/node/node.go +++ b/tm2/pkg/bft/node/node.go @@ -12,12 +12,16 @@ import ( "sync" "time" + goErrors "errors" + "github.com/gnolang/gno/tm2/pkg/bft/appconn" "github.com/gnolang/gno/tm2/pkg/bft/state/eventstore/file" + "github.com/gnolang/gno/tm2/pkg/p2p/conn" + "github.com/gnolang/gno/tm2/pkg/p2p/discovery" + p2pTypes "github.com/gnolang/gno/tm2/pkg/p2p/types" "github.com/rs/cors" "github.com/gnolang/gno/tm2/pkg/amino" - abci "github.com/gnolang/gno/tm2/pkg/bft/abci/types" bc "github.com/gnolang/gno/tm2/pkg/bft/blockchain" cfg "github.com/gnolang/gno/tm2/pkg/bft/config" cs "github.com/gnolang/gno/tm2/pkg/bft/consensus" @@ -43,6 +47,23 @@ import ( verset "github.com/gnolang/gno/tm2/pkg/versionset" ) +// Reactors are hooks for the p2p module, +// to alert of connecting / disconnecting peers +const ( + mempoolReactorName = "MEMPOOL" + blockchainReactorName = "BLOCKCHAIN" + consensusReactorName = "CONSENSUS" + discoveryReactorName = "DISCOVERY" +) + +const ( + mempoolModuleName = "mempool" + blockchainModuleName = "blockchain" + consensusModuleName = "consensus" + p2pModuleName = "p2p" + 
discoveryModuleName = "discovery" +) + // ------------------------------------------------------------------------------ // DBContext specifies config information for loading a new DB. @@ -87,7 +108,7 @@ func DefaultNewNode( logger *slog.Logger, ) (*Node, error) { // Generate node PrivKey - nodeKey, err := p2p.LoadOrGenNodeKey(config.NodeKeyFile()) + nodeKey, err := p2pTypes.LoadOrGenNodeKey(config.NodeKeyFile()) if err != nil { return nil, err } @@ -118,30 +139,6 @@ func DefaultNewNode( // Option sets a parameter for the node. type Option func(*Node) -// CustomReactors allows you to add custom reactors (name -> p2p.Reactor) to -// the node's Switch. -// -// WARNING: using any name from the below list of the existing reactors will -// result in replacing it with the custom one. -// -// - MEMPOOL -// - BLOCKCHAIN -// - CONSENSUS -// - EVIDENCE -// - PEX -func CustomReactors(reactors map[string]p2p.Reactor) Option { - return func(n *Node) { - for name, reactor := range reactors { - if existingReactor := n.sw.Reactor(name); existingReactor != nil { - n.sw.Logger.Info("Replacing existing reactor with a custom one", - "name", name, "existing", existingReactor, "custom", reactor) - n.sw.RemoveReactor(name, existingReactor) - } - n.sw.AddReactor(name, reactor) - } - } -} - // ------------------------------------------------------------------------------ // Node is the highest level interface to a full Tendermint node. 
@@ -155,11 +152,12 @@ type Node struct { privValidator types.PrivValidator // local node's validator key // network - transport *p2p.MultiplexTransport - sw *p2p.Switch // p2p connections - nodeInfo p2p.NodeInfo - nodeKey *p2p.NodeKey // our node privkey - isListening bool + transport *p2p.MultiplexTransport + sw *p2p.MultiplexSwitch // p2p connections + discoveryReactor *discovery.Reactor // discovery reactor + nodeInfo p2pTypes.NodeInfo + nodeKey *p2pTypes.NodeKey // our node privkey + isListening bool // services evsw events.EventSwitch @@ -279,7 +277,7 @@ func createMempoolAndMempoolReactor(config *cfg.Config, proxyApp appconn.AppConn state.ConsensusParams.Block.MaxTxBytes, mempl.WithPreCheck(sm.TxPreCheck(state)), ) - mempoolLogger := logger.With("module", "mempool") + mempoolLogger := logger.With("module", mempoolModuleName) mempoolReactor := mempl.NewReactor(config.Mempool, mempool) mempoolReactor.SetLogger(mempoolLogger) @@ -289,16 +287,23 @@ func createMempoolAndMempoolReactor(config *cfg.Config, proxyApp appconn.AppConn return mempoolReactor, mempool } -func createBlockchainReactor(config *cfg.Config, +func createBlockchainReactor( state sm.State, blockExec *sm.BlockExecutor, blockStore *store.BlockStore, fastSync bool, + switchToConsensusFn bc.SwitchToConsensusFn, logger *slog.Logger, ) (bcReactor p2p.Reactor, err error) { - bcReactor = bc.NewBlockchainReactor(state.Copy(), blockExec, blockStore, fastSync) + bcReactor = bc.NewBlockchainReactor( + state.Copy(), + blockExec, + blockStore, + fastSync, + switchToConsensusFn, + ) - bcReactor.SetLogger(logger.With("module", "blockchain")) + bcReactor.SetLogger(logger.With("module", blockchainModuleName)) return bcReactor, nil } @@ -331,93 +336,15 @@ func createConsensusReactor(config *cfg.Config, return consensusReactor, consensusState } -func createTransport(config *cfg.Config, nodeInfo p2p.NodeInfo, nodeKey *p2p.NodeKey, proxyApp appconn.AppConns) (*p2p.MultiplexTransport, []p2p.PeerFilterFunc) { - var ( - 
mConnConfig = p2p.MConnConfig(config.P2P) - transport = p2p.NewMultiplexTransport(nodeInfo, *nodeKey, mConnConfig) - connFilters = []p2p.ConnFilterFunc{} - peerFilters = []p2p.PeerFilterFunc{} - ) - - if !config.P2P.AllowDuplicateIP { - connFilters = append(connFilters, p2p.ConnDuplicateIPFilter()) - } - - // Filter peers by addr or pubkey with an ABCI query. - // If the query return code is OK, add peer. - if config.FilterPeers { - connFilters = append( - connFilters, - // ABCI query for address filtering. - func(_ p2p.ConnSet, c net.Conn, _ []net.IP) error { - res, err := proxyApp.Query().QuerySync(abci.RequestQuery{ - Path: fmt.Sprintf("/p2p/filter/addr/%s", c.RemoteAddr().String()), - }) - if err != nil { - return err - } - if res.IsErr() { - return fmt.Errorf("error querying abci app: %v", res) - } - - return nil - }, - ) - - peerFilters = append( - peerFilters, - // ABCI query for ID filtering. - func(_ p2p.IPeerSet, p p2p.Peer) error { - res, err := proxyApp.Query().QuerySync(abci.RequestQuery{ - Path: fmt.Sprintf("/p2p/filter/id/%s", p.ID()), - }) - if err != nil { - return err - } - if res.IsErr() { - return fmt.Errorf("error querying abci app: %v", res) - } - - return nil - }, - ) - } - - p2p.MultiplexTransportConnFilters(connFilters...)(transport) - return transport, peerFilters -} - -func createSwitch(config *cfg.Config, - transport *p2p.MultiplexTransport, - peerFilters []p2p.PeerFilterFunc, - mempoolReactor *mempl.Reactor, - bcReactor p2p.Reactor, - consensusReactor *cs.ConsensusReactor, - nodeInfo p2p.NodeInfo, - nodeKey *p2p.NodeKey, - p2pLogger *slog.Logger, -) *p2p.Switch { - sw := p2p.NewSwitch( - config.P2P, - transport, - p2p.SwitchPeerFilters(peerFilters...), - ) - sw.SetLogger(p2pLogger) - sw.AddReactor("MEMPOOL", mempoolReactor) - sw.AddReactor("BLOCKCHAIN", bcReactor) - sw.AddReactor("CONSENSUS", consensusReactor) - - sw.SetNodeInfo(nodeInfo) - sw.SetNodeKey(nodeKey) - - p2pLogger.Info("P2P Node ID", "ID", nodeKey.ID(), "file", 
config.NodeKeyFile()) - return sw +type nodeReactor struct { + name string + reactor p2p.Reactor } // NewNode returns a new, ready to go, Tendermint Node. func NewNode(config *cfg.Config, privValidator types.PrivValidator, - nodeKey *p2p.NodeKey, + nodeKey *p2pTypes.NodeKey, clientCreator appconn.ClientCreator, genesisDocProvider GenesisDocProvider, dbProvider DBProvider, @@ -463,7 +390,7 @@ func NewNode(config *cfg.Config, // Create the handshaker, which calls RequestInfo, sets the AppVersion on the state, // and replays any blocks as necessary to sync tendermint with the app. - consensusLogger := logger.With("module", "consensus") + consensusLogger := logger.With("module", consensusModuleName) if err := doHandshake(stateDB, state, blockStore, genDoc, evsw, proxyApp, consensusLogger); err != nil { return nil, err } @@ -506,38 +433,103 @@ func NewNode(config *cfg.Config, mempool, ) - // Make BlockchainReactor - bcReactor, err := createBlockchainReactor(config, state, blockExec, blockStore, fastSync, logger) - if err != nil { - return nil, errors.Wrap(err, "could not create blockchain reactor") - } - // Make ConsensusReactor consensusReactor, consensusState := createConsensusReactor( config, state, blockExec, blockStore, mempool, privValidator, fastSync, evsw, consensusLogger, ) + // Make BlockchainReactor + bcReactor, err := createBlockchainReactor( + state, + blockExec, + blockStore, + fastSync, + consensusReactor.SwitchToConsensus, + logger, + ) + if err != nil { + return nil, errors.Wrap(err, "could not create blockchain reactor") + } + + reactors := []nodeReactor{ + { + mempoolReactorName, mempoolReactor, + }, + { + blockchainReactorName, bcReactor, + }, + { + consensusReactorName, consensusReactor, + }, + } + nodeInfo, err := makeNodeInfo(config, nodeKey, txEventStore, genDoc, state) if err != nil { return nil, errors.Wrap(err, "error making NodeInfo") } - // Setup Transport. 
- transport, peerFilters := createTransport(config, nodeInfo, nodeKey, proxyApp) + p2pLogger := logger.With("module", p2pModuleName) - // Setup Switch. - p2pLogger := logger.With("module", "p2p") - sw := createSwitch( - config, transport, peerFilters, mempoolReactor, bcReactor, - consensusReactor, nodeInfo, nodeKey, p2pLogger, + // Setup the multiplex transport, used by the P2P switch + transport := p2p.NewMultiplexTransport( + nodeInfo, + *nodeKey, + conn.MConfigFromP2P(config.P2P), + p2pLogger.With("transport", "multiplex"), ) - err = sw.AddPersistentPeers(splitAndTrimEmpty(config.P2P.PersistentPeers, ",", " ")) - if err != nil { - return nil, errors.Wrap(err, "could not add peers from persistent_peers field") + var discoveryReactor *discovery.Reactor + + if config.P2P.PeerExchange { + discoveryReactor = discovery.NewReactor() + + discoveryReactor.SetLogger(logger.With("module", discoveryModuleName)) + + reactors = append(reactors, nodeReactor{ + name: discoveryReactorName, + reactor: discoveryReactor, + }) + } + + // Setup MultiplexSwitch. 
+ peerAddrs, errs := p2pTypes.NewNetAddressFromStrings( + splitAndTrimEmpty(config.P2P.PersistentPeers, ",", " "), + ) + for _, err = range errs { + p2pLogger.Error("invalid persistent peer address", "err", err) + } + + // Parse the private peer IDs + privatePeerIDs, errs := p2pTypes.NewIDFromStrings( + splitAndTrimEmpty(config.P2P.PrivatePeerIDs, ",", " "), + ) + for _, err = range errs { + p2pLogger.Error("invalid private peer ID", "err", err) + } + + // Prepare the misc switch options + opts := []p2p.SwitchOption{ + p2p.WithPersistentPeers(peerAddrs), + p2p.WithPrivatePeers(privatePeerIDs), + p2p.WithMaxInboundPeers(config.P2P.MaxNumInboundPeers), + p2p.WithMaxOutboundPeers(config.P2P.MaxNumOutboundPeers), + } + + // Prepare the reactor switch options + for _, r := range reactors { + opts = append(opts, p2p.WithReactor(r.name, r.reactor)) } + sw := p2p.NewMultiplexSwitch( + transport, + opts..., + ) + + sw.SetLogger(p2pLogger) + + p2pLogger.Info("P2P Node ID", "ID", nodeKey.ID(), "file", config.NodeKeyFile()) + if config.ProfListenAddress != "" { server := &http.Server{ Addr: config.ProfListenAddress, @@ -554,10 +546,11 @@ func NewNode(config *cfg.Config, genesisDoc: genDoc, privValidator: privValidator, - transport: transport, - sw: sw, - nodeInfo: nodeInfo, - nodeKey: nodeKey, + transport: transport, + sw: sw, + discoveryReactor: discoveryReactor, + nodeInfo: nodeInfo, + nodeKey: nodeKey, evsw: evsw, stateDB: stateDB, @@ -611,10 +604,16 @@ func (n *Node) OnStart() error { } // Start the transport. 
- addr, err := p2p.NewNetAddressFromString(p2p.NetAddressString(n.nodeKey.ID(), n.config.P2P.ListenAddress)) + lAddr := n.config.P2P.ExternalAddress + if lAddr == "" { + lAddr = n.config.P2P.ListenAddress + } + + addr, err := p2pTypes.NewNetAddressFromString(p2pTypes.NetAddressString(n.nodeKey.ID(), lAddr)) if err != nil { - return err + return fmt.Errorf("unable to parse network address, %w", err) } + if err := n.transport.Listen(*addr); err != nil { return err } @@ -639,11 +638,14 @@ func (n *Node) OnStart() error { } // Always connect to persistent peers - err = n.sw.DialPeersAsync(splitAndTrimEmpty(n.config.P2P.PersistentPeers, ",", " ")) - if err != nil { - return errors.Wrap(err, "could not dial peers from persistent_peers field") + peerAddrs, errs := p2pTypes.NewNetAddressFromStrings(splitAndTrimEmpty(n.config.P2P.PersistentPeers, ",", " ")) + for _, err := range errs { + n.Logger.Error("invalid persistent peer address", "err", err) } + // Dial the persistent peers + n.sw.DialPeers(peerAddrs...) + return nil } @@ -657,8 +659,15 @@ func (n *Node) OnStop() { n.evsw.Stop() n.eventStoreService.Stop() + // Stop the node p2p transport + if err := n.transport.Close(); err != nil { + n.Logger.Error("unable to gracefully close transport", "err", err) + } + // now stop the reactors - n.sw.Stop() + if err := n.sw.Stop(); err != nil { + n.Logger.Error("unable to gracefully close switch", "err", err) + } // stop mempool WAL if n.config.Mempool.WalEnabled() { @@ -791,7 +800,7 @@ func joinListenerAddresses(ll []net.Listener) string { } // Switch returns the Node's Switch. -func (n *Node) Switch() *p2p.Switch { +func (n *Node) Switch() *p2p.MultiplexSwitch { return n.sw } @@ -859,17 +868,17 @@ func (n *Node) IsListening() bool { } // NodeInfo returns the Node's Info from the Switch. 
-func (n *Node) NodeInfo() p2p.NodeInfo { +func (n *Node) NodeInfo() p2pTypes.NodeInfo { return n.nodeInfo } func makeNodeInfo( config *cfg.Config, - nodeKey *p2p.NodeKey, + nodeKey *p2pTypes.NodeKey, txEventStore eventstore.TxEventStore, genDoc *types.GenesisDoc, state sm.State, -) (p2p.NodeInfo, error) { +) (p2pTypes.NodeInfo, error) { txIndexerStatus := eventstore.StatusOff if txEventStore.GetType() != null.EventStoreType { txIndexerStatus = eventstore.StatusOn @@ -882,8 +891,9 @@ func makeNodeInfo( Version: state.AppVersion, }) - nodeInfo := p2p.NodeInfo{ + nodeInfo := p2pTypes.NodeInfo{ VersionSet: vset, + PeerID: nodeKey.ID(), Network: genDoc.ChainID, Version: version.Version, Channels: []byte{ @@ -892,24 +902,23 @@ func makeNodeInfo( mempl.MempoolChannel, }, Moniker: config.Moniker, - Other: p2p.NodeInfoOther{ + Other: p2pTypes.NodeInfoOther{ TxIndex: txIndexerStatus, RPCAddress: config.RPC.ListenAddress, }, } - lAddr := config.P2P.ExternalAddress - if lAddr == "" { - lAddr = config.P2P.ListenAddress + if config.P2P.PeerExchange { + nodeInfo.Channels = append(nodeInfo.Channels, discovery.Channel) } - addr, err := p2p.NewNetAddressFromString(p2p.NetAddressString(nodeKey.ID(), lAddr)) - if err != nil { - return nodeInfo, errors.Wrap(err, "invalid (local) node net address") + + // Validate the node info + err := nodeInfo.Validate() + if err != nil && !goErrors.Is(err, p2pTypes.ErrUnspecifiedIP) { + return p2pTypes.NodeInfo{}, fmt.Errorf("unable to validate node info, %w", err) } - nodeInfo.NetAddress = addr - err = nodeInfo.Validate() - return nodeInfo, err + return nodeInfo, nil } // ------------------------------------------------------------------------------ diff --git a/tm2/pkg/bft/node/node_test.go b/tm2/pkg/bft/node/node_test.go index 6e86a0bcc6f..1ea789d31c2 100644 --- a/tm2/pkg/bft/node/node_test.go +++ b/tm2/pkg/bft/node/node_test.go @@ -25,8 +25,6 @@ import ( "github.com/gnolang/gno/tm2/pkg/db/memdb" "github.com/gnolang/gno/tm2/pkg/events" 
"github.com/gnolang/gno/tm2/pkg/log" - "github.com/gnolang/gno/tm2/pkg/p2p" - p2pmock "github.com/gnolang/gno/tm2/pkg/p2p/mock" "github.com/gnolang/gno/tm2/pkg/random" ) @@ -40,8 +38,6 @@ func TestNodeStartStop(t *testing.T) { err = n.Start() require.NoError(t, err) - t.Logf("Started node %v", n.sw.NodeInfo()) - // wait for the node to produce a block blocksSub := events.SubscribeToEvent(n.EventSwitch(), "node_test", types.EventNewBlock{}) require.NoError(t, err) @@ -308,39 +304,6 @@ func TestCreateProposalBlock(t *testing.T) { assert.NoError(t, err) } -func TestNodeNewNodeCustomReactors(t *testing.T) { - config, genesisFile := cfg.ResetTestRoot("node_new_node_custom_reactors_test") - defer os.RemoveAll(config.RootDir) - - cr := p2pmock.NewReactor() - customBlockchainReactor := p2pmock.NewReactor() - - nodeKey, err := p2p.LoadOrGenNodeKey(config.NodeKeyFile()) - require.NoError(t, err) - - n, err := NewNode(config, - privval.LoadOrGenFilePV(config.PrivValidatorKeyFile(), config.PrivValidatorStateFile()), - nodeKey, - proxy.DefaultClientCreator(nil, config.ProxyApp, config.ABCI, config.DBDir()), - DefaultGenesisDocProviderFunc(genesisFile), - DefaultDBProvider, - events.NewEventSwitch(), - log.NewTestingLogger(t), - CustomReactors(map[string]p2p.Reactor{"FOO": cr, "BLOCKCHAIN": customBlockchainReactor}), - ) - require.NoError(t, err) - - err = n.Start() - require.NoError(t, err) - defer n.Stop() - - assert.True(t, cr.IsRunning()) - assert.Equal(t, cr, n.Switch().Reactor("FOO")) - - assert.True(t, customBlockchainReactor.IsRunning()) - assert.Equal(t, customBlockchainReactor, n.Switch().Reactor("BLOCKCHAIN")) -} - func state(nVals int, height int64) (sm.State, dbm.DB) { vals := make([]types.GenesisValidator, nVals) for i := 0; i < nVals; i++ { diff --git a/tm2/pkg/bft/rpc/client/batch_test.go b/tm2/pkg/bft/rpc/client/batch_test.go index 52930e5c372..fcd0f3f834d 100644 --- a/tm2/pkg/bft/rpc/client/batch_test.go +++ b/tm2/pkg/bft/rpc/client/batch_test.go @@ -10,7 +10,7 
@@ import ( ctypes "github.com/gnolang/gno/tm2/pkg/bft/rpc/core/types" types "github.com/gnolang/gno/tm2/pkg/bft/rpc/lib/types" bfttypes "github.com/gnolang/gno/tm2/pkg/bft/types" - "github.com/gnolang/gno/tm2/pkg/p2p" + p2pTypes "github.com/gnolang/gno/tm2/pkg/p2p/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -116,7 +116,7 @@ func TestRPCBatch_Send(t *testing.T) { var ( numRequests = 10 expectedStatus = &ctypes.ResultStatus{ - NodeInfo: p2p.NodeInfo{ + NodeInfo: p2pTypes.NodeInfo{ Moniker: "dummy", }, } @@ -160,7 +160,7 @@ func TestRPCBatch_Endpoints(t *testing.T) { { statusMethod, &ctypes.ResultStatus{ - NodeInfo: p2p.NodeInfo{ + NodeInfo: p2pTypes.NodeInfo{ Moniker: "dummy", }, }, diff --git a/tm2/pkg/bft/rpc/client/client_test.go b/tm2/pkg/bft/rpc/client/client_test.go index cb88c91fc5f..31889f59883 100644 --- a/tm2/pkg/bft/rpc/client/client_test.go +++ b/tm2/pkg/bft/rpc/client/client_test.go @@ -14,7 +14,7 @@ import ( ctypes "github.com/gnolang/gno/tm2/pkg/bft/rpc/core/types" types "github.com/gnolang/gno/tm2/pkg/bft/rpc/lib/types" bfttypes "github.com/gnolang/gno/tm2/pkg/bft/types" - "github.com/gnolang/gno/tm2/pkg/p2p" + p2pTypes "github.com/gnolang/gno/tm2/pkg/p2p/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -114,7 +114,7 @@ func TestRPCClient_Status(t *testing.T) { var ( expectedStatus = &ctypes.ResultStatus{ - NodeInfo: p2p.NodeInfo{ + NodeInfo: p2pTypes.NodeInfo{ Moniker: "dummy", }, } @@ -811,17 +811,17 @@ func TestRPCClient_Batch(t *testing.T) { var ( expectedStatuses = []*ctypes.ResultStatus{ { - NodeInfo: p2p.NodeInfo{ + NodeInfo: p2pTypes.NodeInfo{ Moniker: "dummy", }, }, { - NodeInfo: p2p.NodeInfo{ + NodeInfo: p2pTypes.NodeInfo{ Moniker: "dummy", }, }, { - NodeInfo: p2p.NodeInfo{ + NodeInfo: p2pTypes.NodeInfo{ Moniker: "dummy", }, }, diff --git a/tm2/pkg/bft/rpc/client/e2e_test.go b/tm2/pkg/bft/rpc/client/e2e_test.go index 08d4b9b735d..358c66b0b26 100644 --- 
a/tm2/pkg/bft/rpc/client/e2e_test.go +++ b/tm2/pkg/bft/rpc/client/e2e_test.go @@ -13,7 +13,7 @@ import ( ctypes "github.com/gnolang/gno/tm2/pkg/bft/rpc/core/types" types "github.com/gnolang/gno/tm2/pkg/bft/rpc/lib/types" bfttypes "github.com/gnolang/gno/tm2/pkg/bft/types" - "github.com/gnolang/gno/tm2/pkg/p2p" + p2pTypes "github.com/gnolang/gno/tm2/pkg/p2p/types" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -170,7 +170,7 @@ func TestRPCClient_E2E_Endpoints(t *testing.T) { { statusMethod, &ctypes.ResultStatus{ - NodeInfo: p2p.NodeInfo{ + NodeInfo: p2pTypes.NodeInfo{ Moniker: "dummy", }, }, diff --git a/tm2/pkg/bft/rpc/client/local.go b/tm2/pkg/bft/rpc/client/local.go index 59c4216a468..4bc724e7d70 100644 --- a/tm2/pkg/bft/rpc/client/local.go +++ b/tm2/pkg/bft/rpc/client/local.go @@ -106,14 +106,6 @@ func (c *Local) Health() (*ctypes.ResultHealth, error) { return core.Health(c.ctx) } -func (c *Local) DialSeeds(seeds []string) (*ctypes.ResultDialSeeds, error) { - return core.UnsafeDialSeeds(c.ctx, seeds) -} - -func (c *Local) DialPeers(peers []string, persistent bool) (*ctypes.ResultDialPeers, error) { - return core.UnsafeDialPeers(c.ctx, peers, persistent) -} - func (c *Local) BlockchainInfo(minHeight, maxHeight int64) (*ctypes.ResultBlockchainInfo, error) { return core.BlockchainInfo(c.ctx, minHeight, maxHeight) } diff --git a/tm2/pkg/bft/rpc/core/net.go b/tm2/pkg/bft/rpc/core/net.go index 975d5ed822f..f8839b7d91f 100644 --- a/tm2/pkg/bft/rpc/core/net.go +++ b/tm2/pkg/bft/rpc/core/net.go @@ -3,7 +3,6 @@ package core import ( ctypes "github.com/gnolang/gno/tm2/pkg/bft/rpc/core/types" rpctypes "github.com/gnolang/gno/tm2/pkg/bft/rpc/lib/types" - "github.com/gnolang/gno/tm2/pkg/errors" ) // Get network info. 
@@ -154,10 +153,14 @@ import ( // } // // ``` -func NetInfo(ctx *rpctypes.Context) (*ctypes.ResultNetInfo, error) { - out, in, _ := p2pPeers.NumPeers() +func NetInfo(_ *rpctypes.Context) (*ctypes.ResultNetInfo, error) { + var ( + set = p2pPeers.Peers() + out, in = set.NumOutbound(), set.NumInbound() + ) + peers := make([]ctypes.Peer, 0, out+in) - for _, peer := range p2pPeers.Peers().List() { + for _, peer := range set.List() { nodeInfo := peer.NodeInfo() peers = append(peers, ctypes.Peer{ NodeInfo: nodeInfo, @@ -166,9 +169,7 @@ func NetInfo(ctx *rpctypes.Context) (*ctypes.ResultNetInfo, error) { RemoteIP: peer.RemoteIP().String(), }) } - // TODO: Should we include PersistentPeers and Seeds in here? - // PRO: useful info - // CON: privacy + return &ctypes.ResultNetInfo{ Listening: p2pTransport.IsListening(), Listeners: p2pTransport.Listeners(), @@ -177,33 +178,6 @@ func NetInfo(ctx *rpctypes.Context) (*ctypes.ResultNetInfo, error) { }, nil } -func UnsafeDialSeeds(ctx *rpctypes.Context, seeds []string) (*ctypes.ResultDialSeeds, error) { - if len(seeds) == 0 { - return &ctypes.ResultDialSeeds{}, errors.New("No seeds provided") - } - logger.Info("DialSeeds", "seeds", seeds) - if err := p2pPeers.DialPeersAsync(seeds); err != nil { - return &ctypes.ResultDialSeeds{}, err - } - return &ctypes.ResultDialSeeds{Log: "Dialing seeds in progress. See /net_info for details"}, nil -} - -func UnsafeDialPeers(ctx *rpctypes.Context, peers []string, persistent bool) (*ctypes.ResultDialPeers, error) { - if len(peers) == 0 { - return &ctypes.ResultDialPeers{}, errors.New("No peers provided") - } - logger.Info("DialPeers", "peers", peers, "persistent", persistent) - if persistent { - if err := p2pPeers.AddPersistentPeers(peers); err != nil { - return &ctypes.ResultDialPeers{}, err - } - } - if err := p2pPeers.DialPeersAsync(peers); err != nil { - return &ctypes.ResultDialPeers{}, err - } - return &ctypes.ResultDialPeers{Log: "Dialing peers in progress. 
See /net_info for details"}, nil -} - // Get genesis file. // // ```shell diff --git a/tm2/pkg/bft/rpc/core/net_test.go b/tm2/pkg/bft/rpc/core/net_test.go deleted file mode 100644 index 3273837b6ce..00000000000 --- a/tm2/pkg/bft/rpc/core/net_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package core - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - rpctypes "github.com/gnolang/gno/tm2/pkg/bft/rpc/lib/types" - "github.com/gnolang/gno/tm2/pkg/log" - "github.com/gnolang/gno/tm2/pkg/p2p" - p2pcfg "github.com/gnolang/gno/tm2/pkg/p2p/config" -) - -func TestUnsafeDialSeeds(t *testing.T) { - t.Parallel() - - sw := p2p.MakeSwitch(p2pcfg.DefaultP2PConfig(), 1, "testing", "123.123.123", - func(n int, sw *p2p.Switch) *p2p.Switch { return sw }) - err := sw.Start() - require.NoError(t, err) - defer sw.Stop() - - logger = log.NewNoopLogger() - p2pPeers = sw - - testCases := []struct { - seeds []string - isErr bool - }{ - {[]string{}, true}, - {[]string{"g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:41198"}, false}, - {[]string{"127.0.0.1:41198"}, true}, - } - - for _, tc := range testCases { - res, err := UnsafeDialSeeds(&rpctypes.Context{}, tc.seeds) - if tc.isErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.NotNil(t, res) - } - } -} - -func TestUnsafeDialPeers(t *testing.T) { - t.Parallel() - - sw := p2p.MakeSwitch(p2pcfg.DefaultP2PConfig(), 1, "testing", "123.123.123", - func(n int, sw *p2p.Switch) *p2p.Switch { return sw }) - err := sw.Start() - require.NoError(t, err) - defer sw.Stop() - - logger = log.NewNoopLogger() - p2pPeers = sw - - testCases := []struct { - peers []string - isErr bool - }{ - {[]string{}, true}, - {[]string{"g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:41198"}, false}, - {[]string{"127.0.0.1:41198"}, true}, - } - - for _, tc := range testCases { - res, err := UnsafeDialPeers(&rpctypes.Context{}, tc.peers, false) - if tc.isErr { - assert.Error(t, err) - } else { - 
assert.NoError(t, err) - assert.NotNil(t, res) - } - } -} diff --git a/tm2/pkg/bft/rpc/core/pipe.go b/tm2/pkg/bft/rpc/core/pipe.go index 9493e7c5873..085fc35da55 100644 --- a/tm2/pkg/bft/rpc/core/pipe.go +++ b/tm2/pkg/bft/rpc/core/pipe.go @@ -15,6 +15,7 @@ import ( dbm "github.com/gnolang/gno/tm2/pkg/db" "github.com/gnolang/gno/tm2/pkg/events" "github.com/gnolang/gno/tm2/pkg/p2p" + p2pTypes "github.com/gnolang/gno/tm2/pkg/p2p/types" ) const ( @@ -38,14 +39,11 @@ type Consensus interface { type transport interface { Listeners() []string IsListening() bool - NodeInfo() p2p.NodeInfo + NodeInfo() p2pTypes.NodeInfo } type peers interface { - AddPersistentPeers([]string) error - DialPeersAsync([]string) error - NumPeers() (outbound, inbound, dialig int) - Peers() p2p.IPeerSet + Peers() p2p.PeerSet } // ---------------------------------------------- diff --git a/tm2/pkg/bft/rpc/core/routes.go b/tm2/pkg/bft/rpc/core/routes.go index 8d210f67985..76217a7cbd9 100644 --- a/tm2/pkg/bft/rpc/core/routes.go +++ b/tm2/pkg/bft/rpc/core/routes.go @@ -36,8 +36,6 @@ var Routes = map[string]*rpc.RPCFunc{ func AddUnsafeRoutes() { // control API - Routes["dial_seeds"] = rpc.NewRPCFunc(UnsafeDialSeeds, "seeds") - Routes["dial_peers"] = rpc.NewRPCFunc(UnsafeDialPeers, "peers,persistent") Routes["unsafe_flush_mempool"] = rpc.NewRPCFunc(UnsafeFlushMempool, "") // profiler API diff --git a/tm2/pkg/bft/rpc/core/types/responses.go b/tm2/pkg/bft/rpc/core/types/responses.go index 2874517147d..76474867b27 100644 --- a/tm2/pkg/bft/rpc/core/types/responses.go +++ b/tm2/pkg/bft/rpc/core/types/responses.go @@ -11,6 +11,7 @@ import ( "github.com/gnolang/gno/tm2/pkg/bft/types" "github.com/gnolang/gno/tm2/pkg/crypto" "github.com/gnolang/gno/tm2/pkg/p2p" + p2pTypes "github.com/gnolang/gno/tm2/pkg/p2p/types" ) // List of blocks @@ -74,9 +75,9 @@ type ValidatorInfo struct { // Node Status type ResultStatus struct { - NodeInfo p2p.NodeInfo `json:"node_info"` - SyncInfo SyncInfo `json:"sync_info"` - 
ValidatorInfo ValidatorInfo `json:"validator_info"` + NodeInfo p2pTypes.NodeInfo `json:"node_info"` + SyncInfo SyncInfo `json:"sync_info"` + ValidatorInfo ValidatorInfo `json:"validator_info"` } // Is TxIndexing enabled @@ -107,7 +108,7 @@ type ResultDialPeers struct { // A peer type Peer struct { - NodeInfo p2p.NodeInfo `json:"node_info"` + NodeInfo p2pTypes.NodeInfo `json:"node_info"` IsOutbound bool `json:"is_outbound"` ConnectionStatus p2p.ConnectionStatus `json:"connection_status"` RemoteIP string `json:"remote_ip"` diff --git a/tm2/pkg/bft/rpc/core/types/responses_test.go b/tm2/pkg/bft/rpc/core/types/responses_test.go index 268a8d25c34..7d03addc546 100644 --- a/tm2/pkg/bft/rpc/core/types/responses_test.go +++ b/tm2/pkg/bft/rpc/core/types/responses_test.go @@ -3,9 +3,8 @@ package core_types import ( "testing" + "github.com/gnolang/gno/tm2/pkg/p2p/types" "github.com/stretchr/testify/assert" - - "github.com/gnolang/gno/tm2/pkg/p2p" ) func TestStatusIndexer(t *testing.T) { @@ -17,17 +16,17 @@ func TestStatusIndexer(t *testing.T) { status = &ResultStatus{} assert.False(t, status.TxIndexEnabled()) - status.NodeInfo = p2p.NodeInfo{} + status.NodeInfo = types.NodeInfo{} assert.False(t, status.TxIndexEnabled()) cases := []struct { expected bool - other p2p.NodeInfoOther + other types.NodeInfoOther }{ - {false, p2p.NodeInfoOther{}}, - {false, p2p.NodeInfoOther{TxIndex: "aa"}}, - {false, p2p.NodeInfoOther{TxIndex: "off"}}, - {true, p2p.NodeInfoOther{TxIndex: "on"}}, + {false, types.NodeInfoOther{}}, + {false, types.NodeInfoOther{TxIndex: "aa"}}, + {false, types.NodeInfoOther{TxIndex: "off"}}, + {true, types.NodeInfoOther{TxIndex: "on"}}, } for _, tc := range cases { diff --git a/tm2/pkg/bft/rpc/lib/client/http/client.go b/tm2/pkg/bft/rpc/lib/client/http/client.go index aa4fc5c5392..288eca57300 100644 --- a/tm2/pkg/bft/rpc/lib/client/http/client.go +++ b/tm2/pkg/bft/rpc/lib/client/http/client.go @@ -166,14 +166,14 @@ func defaultHTTPClient(remoteAddr string) *http.Client 
{ Transport: &http.Transport{ // Set to true to prevent GZIP-bomb DoS attacks DisableCompression: true, - DialContext: func(_ context.Context, network, addr string) (net.Conn, error) { - return makeHTTPDialer(remoteAddr)(network, addr) + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return makeHTTPDialer(remoteAddr) }, }, } } -func makeHTTPDialer(remoteAddr string) func(string, string) (net.Conn, error) { +func makeHTTPDialer(remoteAddr string) (net.Conn, error) { protocol, address := parseRemoteAddr(remoteAddr) // net.Dial doesn't understand http/https, so change it to TCP @@ -182,9 +182,7 @@ func makeHTTPDialer(remoteAddr string) func(string, string) (net.Conn, error) { protocol = protoTCP } - return func(proto, addr string) (net.Conn, error) { - return net.Dial(protocol, address) - } + return net.Dial(protocol, address) } // protocol - client's protocol (for example, "http", "https", "wss", "ws", "tcp") diff --git a/tm2/pkg/bft/rpc/lib/client/http/client_test.go b/tm2/pkg/bft/rpc/lib/client/http/client_test.go index 4ccbfdc2d1e..0d88ee32650 100644 --- a/tm2/pkg/bft/rpc/lib/client/http/client_test.go +++ b/tm2/pkg/bft/rpc/lib/client/http/client_test.go @@ -76,7 +76,7 @@ func TestClient_makeHTTPDialer(t *testing.T) { t.Run("http", func(t *testing.T) { t.Parallel() - _, err := makeHTTPDialer("https://.")("hello", "world") + _, err := makeHTTPDialer("https://.") require.Error(t, err) assert.Contains(t, err.Error(), "dial tcp:", "should convert https to tcp") @@ -85,7 +85,7 @@ func TestClient_makeHTTPDialer(t *testing.T) { t.Run("udp", func(t *testing.T) { t.Parallel() - _, err := makeHTTPDialer("udp://.")("hello", "world") + _, err := makeHTTPDialer("udp://.") require.Error(t, err) assert.Contains(t, err.Error(), "dial udp:", "udp protocol should remain the same") diff --git a/tm2/pkg/bft/rpc/lib/server/handlers.go b/tm2/pkg/bft/rpc/lib/server/handlers.go index 88ee26da4a9..9e10596a975 100644 --- a/tm2/pkg/bft/rpc/lib/server/handlers.go +++ 
b/tm2/pkg/bft/rpc/lib/server/handlers.go @@ -939,7 +939,7 @@ func writeListOfEndpoints(w http.ResponseWriter, r *http.Request, funcMap map[st for _, name := range noArgNames { link := fmt.Sprintf("//%s/%s", r.Host, name) - buf.WriteString(fmt.Sprintf("%s
", link, link)) + fmt.Fprintf(buf, "%s
", link, link) } buf.WriteString("
Endpoints that require arguments:
") @@ -952,7 +952,7 @@ func writeListOfEndpoints(w http.ResponseWriter, r *http.Request, funcMap map[st link += "&" } } - buf.WriteString(fmt.Sprintf("%s
", link, link)) + fmt.Fprintf(buf, "%s
", link, link) } buf.WriteString("") w.Header().Set("Content-Type", "text/html") diff --git a/tm2/pkg/bft/rpc/lib/server/write_endpoints_test.go b/tm2/pkg/bft/rpc/lib/server/write_endpoints_test.go new file mode 100644 index 00000000000..b01144f9273 --- /dev/null +++ b/tm2/pkg/bft/rpc/lib/server/write_endpoints_test.go @@ -0,0 +1,33 @@ +package rpcserver + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + + types "github.com/gnolang/gno/tm2/pkg/bft/rpc/lib/types" +) + +func TestWriteListOfEndpoints(t *testing.T) { + funcMap := map[string]*RPCFunc{ + "c": NewWSRPCFunc(func(ctx *types.Context, s string, i int) (string, error) { return "foo", nil }, "s,i"), + "d": {}, + } + + req, _ := http.NewRequest("GET", "http://localhost/", nil) + rec := httptest.NewRecorder() + writeListOfEndpoints(rec, req, funcMap) + res := rec.Result() + assert.Equal(t, res.StatusCode, 200, "Should always return 200") + blob, err := io.ReadAll(res.Body) + assert.NoError(t, err) + gotResp := string(blob) + wantResp := `
Available endpoints:
//localhost/d

Endpoints that require arguments:
//localhost/c?s=_&i=_
` + if diff := cmp.Diff(gotResp, wantResp); diff != "" { + t.Fatalf("Mismatch response: got - want +\n%s", diff) + } +} diff --git a/tm2/pkg/crypto/crypto.go b/tm2/pkg/crypto/crypto.go index 7757b75354e..7908a082d3b 100644 --- a/tm2/pkg/crypto/crypto.go +++ b/tm2/pkg/crypto/crypto.go @@ -3,6 +3,7 @@ package crypto import ( "bytes" "encoding/json" + "errors" "fmt" "github.com/gnolang/gno/tm2/pkg/bech32" @@ -128,6 +129,8 @@ func (addr *Address) DecodeString(str string) error { // ---------------------------------------- // ID +var ErrZeroID = errors.New("address ID is zero") + // The bech32 representation w/ bech32 prefix. type ID string @@ -141,16 +144,12 @@ func (id ID) String() string { func (id ID) Validate() error { if id.IsZero() { - return fmt.Errorf("zero ID is invalid") + return ErrZeroID } + var addr Address - err := addr.DecodeID(id) - return err -} -func AddressFromID(id ID) (addr Address, err error) { - err = addr.DecodeString(string(id)) - return + return addr.DecodeID(id) } func (addr Address) ID() ID { diff --git a/tm2/pkg/crypto/ed25519/ed25519.go b/tm2/pkg/crypto/ed25519/ed25519.go index 8976994986c..f8b9529b788 100644 --- a/tm2/pkg/crypto/ed25519/ed25519.go +++ b/tm2/pkg/crypto/ed25519/ed25519.go @@ -68,11 +68,9 @@ func (privKey PrivKeyEd25519) PubKey() crypto.PubKey { // Equals - you probably don't need to use this. // Runs in constant time based on length of the keys. func (privKey PrivKeyEd25519) Equals(other crypto.PrivKey) bool { - if otherEd, ok := other.(PrivKeyEd25519); ok { - return subtle.ConstantTimeCompare(privKey[:], otherEd[:]) == 1 - } else { - return false - } + otherEd, ok := other.(PrivKeyEd25519) + + return ok && subtle.ConstantTimeCompare(privKey[:], otherEd[:]) == 1 } // GenPrivKey generates a new ed25519 private key. 
diff --git a/tm2/pkg/internal/p2p/p2p.go b/tm2/pkg/internal/p2p/p2p.go new file mode 100644 index 00000000000..1e650e0cd25 --- /dev/null +++ b/tm2/pkg/internal/p2p/p2p.go @@ -0,0 +1,255 @@ +// Package p2p contains testing code that is moved over, and adapted from p2p/test_utils.go. +// This isn't a good way to simulate the networking layer in TM2 modules. +// It actually isn't a good way to simulate the networking layer, in anything. +// +// Code is carried over to keep the testing code of p2p-dependent modules happy +// and "working". We should delete this entire package the second TM2 module unit tests don't +// need to rely on a live p2p cluster to pass. +package p2p + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "net" + "testing" + "time" + + "github.com/gnolang/gno/tm2/pkg/log" + "github.com/gnolang/gno/tm2/pkg/p2p" + p2pcfg "github.com/gnolang/gno/tm2/pkg/p2p/config" + "github.com/gnolang/gno/tm2/pkg/p2p/conn" + "github.com/gnolang/gno/tm2/pkg/p2p/events" + p2pTypes "github.com/gnolang/gno/tm2/pkg/p2p/types" + "github.com/gnolang/gno/tm2/pkg/service" + "github.com/gnolang/gno/tm2/pkg/versionset" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +// TestingConfig is the P2P cluster testing config +type TestingConfig struct { + P2PCfg *p2pcfg.P2PConfig // the common p2p configuration + Count int // the size of the cluster + SwitchOptions map[int][]p2p.SwitchOption // multiplex switch options + Channels []byte // the common p2p peer multiplex channels +} + +// MakeConnectedPeers creates a cluster of peers, with the given options. +// Used to simulate the networking layer for a TM2 module +func MakeConnectedPeers( + t *testing.T, + ctx context.Context, + cfg TestingConfig, +) ([]*p2p.MultiplexSwitch, []*p2p.MultiplexTransport) { + t.Helper() + + // Initialize collections for switches, transports, and addresses. 
+ var ( + sws = make([]*p2p.MultiplexSwitch, 0, cfg.Count) + ts = make([]*p2p.MultiplexTransport, 0, cfg.Count) + addrs = make([]*p2pTypes.NetAddress, 0, cfg.Count) + ) + + createTransport := func(index int) *p2p.MultiplexTransport { + // Generate a fresh key + key := p2pTypes.GenerateNodeKey() + + addr, err := p2pTypes.NewNetAddress( + key.ID(), + &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, // random free port + }, + ) + require.NoError(t, err) + + info := p2pTypes.NodeInfo{ + VersionSet: versionset.VersionSet{ + versionset.VersionInfo{Name: "p2p", Version: "v0.0.0"}, + }, + PeerID: key.ID(), + Network: "testing", + Software: "p2ptest", + Version: "v1.2.3-rc.0-deadbeef", + Channels: cfg.Channels, + Moniker: fmt.Sprintf("node-%d", index), + Other: p2pTypes.NodeInfoOther{ + TxIndex: "off", + RPCAddress: fmt.Sprintf("127.0.0.1:%d", 0), + }, + } + + transport := p2p.NewMultiplexTransport( + info, + *key, + conn.MConfigFromP2P(cfg.P2PCfg), + log.NewNoopLogger(), + ) + + require.NoError(t, transport.Listen(*addr)) + t.Cleanup(func() { assert.NoError(t, transport.Close()) }) + + return transport + } + + // Create transports and gather addresses + for i := 0; i < cfg.Count; i++ { + transport := createTransport(i) + addr := transport.NetAddress() + + addrs = append(addrs, &addr) + ts = append(ts, transport) + } + + // Connect switches and ensure all peers are connected + connectPeers := func(switchIndex int) error { + multiplexSwitch := p2p.NewMultiplexSwitch( + ts[switchIndex], + cfg.SwitchOptions[switchIndex]..., + ) + + ch, unsubFn := multiplexSwitch.Subscribe(func(event events.Event) bool { + return event.Type() == events.PeerConnected + }) + defer unsubFn() + + // Start the switch + require.NoError(t, multiplexSwitch.Start()) + + // Save it + sws = append(sws, multiplexSwitch) + + if cfg.Count == 1 { + // No peers to dial, switch is alone + return nil + } + + // Async dial the other peers + multiplexSwitch.DialPeers(addrs...) 
+ + // Set up an exit timer + timer := time.NewTimer(1 * time.Minute) + defer timer.Stop() + + var ( + connectedPeers = make(map[p2pTypes.ID]struct{}) + targetPeers = cfg.Count - 1 + ) + + for { + select { + case evRaw := <-ch: + ev := evRaw.(events.PeerConnectedEvent) + + connectedPeers[ev.PeerID] = struct{}{} + + if len(connectedPeers) == targetPeers { + return nil + } + case <-timer.C: + return errors.New("timed out waiting for peer switches to connect") + } + } + } + + g, _ := errgroup.WithContext(ctx) + for i := 0; i < cfg.Count; i++ { + g.Go(func() error { return connectPeers(i) }) + } + + require.NoError(t, g.Wait()) + + return sws, ts +} + +// createRoutableAddr generates a valid, routable NetAddress for the given node ID using a secure random IP +func createRoutableAddr(t *testing.T, id p2pTypes.ID) *p2pTypes.NetAddress { + t.Helper() + + generateIP := func() string { + ip := make([]byte, 4) + + _, err := rand.Read(ip) + require.NoError(t, err) + + return fmt.Sprintf("%d.%d.%d.%d", ip[0], ip[1], ip[2], ip[3]) + } + + for { + addrStr := fmt.Sprintf("%s@%s:26656", id, generateIP()) + + netAddr, err := p2pTypes.NewNetAddressFromString(addrStr) + require.NoError(t, err) + + if netAddr.Routable() { + return netAddr + } + } +} + +// Peer is a live peer, utilized for testing purposes. +// This Peer implementation is NOT thread safe +type Peer struct { + *service.BaseService + ip net.IP + id p2pTypes.ID + addr *p2pTypes.NetAddress + kv map[string]any + + Outbound, Persistent, Private bool +} + +// NewPeer creates and starts a new mock peer. 
+// It generates a new routable address for the peer +func NewPeer(t *testing.T) *Peer { + t.Helper() + + var ( + nodeKey = p2pTypes.GenerateNodeKey() + netAddr = createRoutableAddr(t, nodeKey.ID()) + ) + + mp := &Peer{ + ip: netAddr.IP, + id: nodeKey.ID(), + addr: netAddr, + kv: make(map[string]any), + } + + mp.BaseService = service.NewBaseService(nil, "MockPeer", mp) + + require.NoError(t, mp.Start()) + + return mp +} + +func (mp *Peer) FlushStop() { mp.Stop() } +func (mp *Peer) TrySend(_ byte, _ []byte) bool { return true } +func (mp *Peer) Send(_ byte, _ []byte) bool { return true } +func (mp *Peer) NodeInfo() p2pTypes.NodeInfo { + return p2pTypes.NodeInfo{ + PeerID: mp.id, + } +} +func (mp *Peer) Status() conn.ConnectionStatus { return conn.ConnectionStatus{} } +func (mp *Peer) ID() p2pTypes.ID { return mp.id } +func (mp *Peer) IsOutbound() bool { return mp.Outbound } +func (mp *Peer) IsPersistent() bool { return mp.Persistent } +func (mp *Peer) IsPrivate() bool { return mp.Private } +func (mp *Peer) Get(key string) interface{} { + if value, ok := mp.kv[key]; ok { + return value + } + return nil +} + +func (mp *Peer) Set(key string, value interface{}) { + mp.kv[key] = value +} +func (mp *Peer) RemoteIP() net.IP { return mp.ip } +func (mp *Peer) SocketAddr() *p2pTypes.NetAddress { return mp.addr } +func (mp *Peer) RemoteAddr() net.Addr { return &net.TCPAddr{IP: mp.ip, Port: 8800} } +func (mp *Peer) CloseConn() error { return nil } diff --git a/tm2/pkg/overflow/README.md b/tm2/pkg/overflow/README.md index 55a9ba4c327..26ba7dc9985 100644 --- a/tm2/pkg/overflow/README.md +++ b/tm2/pkg/overflow/README.md @@ -2,26 +2,25 @@ Check for int/int8/int16/int64/int32 integer overflow in Golang arithmetic. -Forked from https://github.com/JohnCGriffin/overflow +Originally forked from https://github.com/JohnCGriffin/overflow. 
### Install -``` -go get github.com/johncgriffin/overflow -``` -Note that because Go has no template types, the majority of repetitive code is -generated by overflow_template.sh. If you have to change an -algorithm, change it there and regenerate the Go code via: -``` + +The majority of repetitive code is generated by overflow_template.sh. If you +have to change an algorithm, change it there and regenerate the Go code via: + +```sh go generate ``` + ### Synopsis -``` +```go package main import "fmt" import "math" -import "github.com/JohnCGriffin/overflow" +import "github.com/gnolang/gno/tm2/pkg/overflow" func main() { @@ -29,38 +28,33 @@ func main() { for i := 0; i < 10; i++ { sum, ok := overflow.Add(addend, i) - fmt.Printf("%v+%v -> (%v,%v)\n", + fmt.Printf("%v+%v -> (%v, %v)\n", addend, i, sum, ok) } } ``` + yields the output -``` -9223372036854775802+0 -> (9223372036854775802,true) -9223372036854775802+1 -> (9223372036854775803,true) -9223372036854775802+2 -> (9223372036854775804,true) -9223372036854775802+3 -> (9223372036854775805,true) -9223372036854775802+4 -> (9223372036854775806,true) -9223372036854775802+5 -> (9223372036854775807,true) -9223372036854775802+6 -> (0,false) -9223372036854775802+7 -> (0,false) -9223372036854775802+8 -> (0,false) -9223372036854775802+9 -> (0,false) + +```console +9223372036854775802+0 -> (9223372036854775802, true) +9223372036854775802+1 -> (9223372036854775803, true) +9223372036854775802+2 -> (9223372036854775804, true) +9223372036854775802+3 -> (9223372036854775805, true) +9223372036854775802+4 -> (9223372036854775806, true) +9223372036854775802+5 -> (9223372036854775807, true) +9223372036854775802+6 -> (0, false) +9223372036854775802+7 -> (0, false) +9223372036854775802+8 -> (0, false) +9223372036854775802+9 -> (0, false) ``` For int, int64, and int32 types, provide Add, Add32, Add64, Sub, Sub32, Sub64, etc. -Unsigned types not covered at the moment, but such additions are welcome. 
### Stay calm and panic -There's a good case to be made that a panic is an unidiomatic but proper response. Iff you -believe that there's no valid way to continue your program after math goes wayward, you can -use the easier Addp, Mulp, Subp, and Divp versions which return the normal result or panic. - - - - - - - +There's a good case to be made that a panic is an unidiomatic but proper +response. If you believe that there's no valid way to continue your program +after math goes wayward, you can use the easier Addp, Mulp, Subp, and Divp +versions which return the normal result or panic. diff --git a/tm2/pkg/overflow/overflow_impl.go b/tm2/pkg/overflow/overflow_impl.go index a9a90c43835..0f057f65387 100644 --- a/tm2/pkg/overflow/overflow_impl.go +++ b/tm2/pkg/overflow/overflow_impl.go @@ -1,10 +1,8 @@ package overflow -// This is generated code, created by overflow_template.sh executed -// by "go generate" +// Code generated by overflow_template.sh from 'go generate'. DO NOT EDIT. -// Add8 performs + operation on two int8 operands -// returning a result and status +// Add8 performs + operation on two int8 operands, returning a result and status. func Add8(a, b int8) (int8, bool) { c := a + b if (c > a) == (b > 0) { @@ -13,7 +11,7 @@ func Add8(a, b int8) (int8, bool) { return c, false } -// Add8p is the unchecked panicing version of Add8 +// Add8p is the unchecked panicing version of Add8. func Add8p(a, b int8) int8 { r, ok := Add8(a, b) if !ok { @@ -22,8 +20,7 @@ func Add8p(a, b int8) int8 { return r } -// Sub8 performs - operation on two int8 operands -// returning a result and status +// Sub8 performs - operation on two int8 operands, returning a result and status. func Sub8(a, b int8) (int8, bool) { c := a - b if (c < a) == (b > 0) { @@ -32,7 +29,7 @@ func Sub8(a, b int8) (int8, bool) { return c, false } -// Sub8p is the unchecked panicing version of Sub8 +// Sub8p is the unchecked panicing version of Sub8. 
func Sub8p(a, b int8) int8 { r, ok := Sub8(a, b) if !ok { @@ -41,8 +38,7 @@ func Sub8p(a, b int8) int8 { return r } -// Mul8 performs * operation on two int8 operands -// returning a result and status +// Mul8 performs * operation on two int8 operands returning a result and status. func Mul8(a, b int8) (int8, bool) { if a == 0 || b == 0 { return 0, true @@ -56,7 +52,7 @@ func Mul8(a, b int8) (int8, bool) { return c, false } -// Mul8p is the unchecked panicing version of Mul8 +// Mul8p is the unchecked panicing version of Mul8. func Mul8p(a, b int8) int8 { r, ok := Mul8(a, b) if !ok { @@ -65,14 +61,13 @@ func Mul8p(a, b int8) int8 { return r } -// Div8 performs / operation on two int8 operands -// returning a result and status +// Div8 performs / operation on two int8 operands, returning a result and status. func Div8(a, b int8) (int8, bool) { q, _, ok := Quotient8(a, b) return q, ok } -// Div8p is the unchecked panicing version of Div8 +// Div8p is the unchecked panicing version of Div8. func Div8p(a, b int8) int8 { r, ok := Div8(a, b) if !ok { @@ -81,19 +76,19 @@ func Div8p(a, b int8) int8 { return r } -// Quotient8 performs + operation on two int8 operands -// returning a quotient, a remainder and status +// Quotient8 performs / operation on two int8 operands, returning a quotient, +// a remainder and status. func Quotient8(a, b int8) (int8, int8, bool) { if b == 0 { return 0, 0, false } c := a / b - status := (c < 0) == ((a < 0) != (b < 0)) - return c, a % b, status + status := (c < 0) == ((a < 0) != (b < 0)) || (c == 0) // no sign check for 0 quotient + return c, a%b, status } -// Add16 performs + operation on two int16 operands -// returning a result and status + +// Add16 performs + operation on two int16 operands, returning a result and status. 
func Add16(a, b int16) (int16, bool) { c := a + b if (c > a) == (b > 0) { @@ -102,7 +97,7 @@ func Add16(a, b int16) (int16, bool) { return c, false } -// Add16p is the unchecked panicing version of Add16 +// Add16p is the unchecked panicing version of Add16. func Add16p(a, b int16) int16 { r, ok := Add16(a, b) if !ok { @@ -111,8 +106,7 @@ func Add16p(a, b int16) int16 { return r } -// Sub16 performs - operation on two int16 operands -// returning a result and status +// Sub16 performs - operation on two int16 operands, returning a result and status. func Sub16(a, b int16) (int16, bool) { c := a - b if (c < a) == (b > 0) { @@ -121,7 +115,7 @@ func Sub16(a, b int16) (int16, bool) { return c, false } -// Sub16p is the unchecked panicing version of Sub16 +// Sub16p is the unchecked panicing version of Sub16. func Sub16p(a, b int16) int16 { r, ok := Sub16(a, b) if !ok { @@ -130,8 +124,7 @@ func Sub16p(a, b int16) int16 { return r } -// Mul16 performs * operation on two int16 operands -// returning a result and status +// Mul16 performs * operation on two int16 operands returning a result and status. func Mul16(a, b int16) (int16, bool) { if a == 0 || b == 0 { return 0, true @@ -145,7 +138,7 @@ func Mul16(a, b int16) (int16, bool) { return c, false } -// Mul16p is the unchecked panicing version of Mul16 +// Mul16p is the unchecked panicing version of Mul16. func Mul16p(a, b int16) int16 { r, ok := Mul16(a, b) if !ok { @@ -154,14 +147,13 @@ func Mul16p(a, b int16) int16 { return r } -// Div16 performs / operation on two int16 operands -// returning a result and status +// Div16 performs / operation on two int16 operands, returning a result and status. func Div16(a, b int16) (int16, bool) { q, _, ok := Quotient16(a, b) return q, ok } -// Div16p is the unchecked panicing version of Div16 +// Div16p is the unchecked panicing version of Div16. 
func Div16p(a, b int16) int16 { r, ok := Div16(a, b) if !ok { @@ -170,19 +162,19 @@ func Div16p(a, b int16) int16 { return r } -// Quotient16 performs + operation on two int16 operands -// returning a quotient, a remainder and status +// Quotient16 performs / operation on two int16 operands, returning a quotient, +// a remainder and status. func Quotient16(a, b int16) (int16, int16, bool) { if b == 0 { return 0, 0, false } c := a / b - status := (c < 0) == ((a < 0) != (b < 0)) - return c, a % b, status + status := (c < 0) == ((a < 0) != (b < 0)) || (c == 0) // no sign check for 0 quotient + return c, a%b, status } -// Add32 performs + operation on two int32 operands -// returning a result and status + +// Add32 performs + operation on two int32 operands, returning a result and status. func Add32(a, b int32) (int32, bool) { c := a + b if (c > a) == (b > 0) { @@ -191,7 +183,7 @@ func Add32(a, b int32) (int32, bool) { return c, false } -// Add32p is the unchecked panicing version of Add32 +// Add32p is the unchecked panicing version of Add32. func Add32p(a, b int32) int32 { r, ok := Add32(a, b) if !ok { @@ -200,8 +192,7 @@ func Add32p(a, b int32) int32 { return r } -// Sub32 performs - operation on two int32 operands -// returning a result and status +// Sub32 performs - operation on two int32 operands, returning a result and status. func Sub32(a, b int32) (int32, bool) { c := a - b if (c < a) == (b > 0) { @@ -210,7 +201,7 @@ func Sub32(a, b int32) (int32, bool) { return c, false } -// Sub32p is the unchecked panicing version of Sub32 +// Sub32p is the unchecked panicing version of Sub32. func Sub32p(a, b int32) int32 { r, ok := Sub32(a, b) if !ok { @@ -219,8 +210,7 @@ func Sub32p(a, b int32) int32 { return r } -// Mul32 performs * operation on two int32 operands -// returning a result and status +// Mul32 performs * operation on two int32 operands returning a result and status. 
func Mul32(a, b int32) (int32, bool) { if a == 0 || b == 0 { return 0, true @@ -234,7 +224,7 @@ func Mul32(a, b int32) (int32, bool) { return c, false } -// Mul32p is the unchecked panicing version of Mul32 +// Mul32p is the unchecked panicing version of Mul32. func Mul32p(a, b int32) int32 { r, ok := Mul32(a, b) if !ok { @@ -243,14 +233,13 @@ func Mul32p(a, b int32) int32 { return r } -// Div32 performs / operation on two int32 operands -// returning a result and status +// Div32 performs / operation on two int32 operands, returning a result and status. func Div32(a, b int32) (int32, bool) { q, _, ok := Quotient32(a, b) return q, ok } -// Div32p is the unchecked panicing version of Div32 +// Div32p is the unchecked panicing version of Div32. func Div32p(a, b int32) int32 { r, ok := Div32(a, b) if !ok { @@ -259,19 +248,19 @@ func Div32p(a, b int32) int32 { return r } -// Quotient32 performs + operation on two int32 operands -// returning a quotient, a remainder and status +// Quotient32 performs / operation on two int32 operands, returning a quotient, +// a remainder and status. func Quotient32(a, b int32) (int32, int32, bool) { if b == 0 { return 0, 0, false } c := a / b - status := (c < 0) == ((a < 0) != (b < 0)) - return c, a % b, status + status := (c < 0) == ((a < 0) != (b < 0)) || (c == 0) // no sign check for 0 quotient + return c, a%b, status } -// Add64 performs + operation on two int64 operands -// returning a result and status + +// Add64 performs + operation on two int64 operands, returning a result and status. func Add64(a, b int64) (int64, bool) { c := a + b if (c > a) == (b > 0) { @@ -280,7 +269,7 @@ func Add64(a, b int64) (int64, bool) { return c, false } -// Add64p is the unchecked panicing version of Add64 +// Add64p is the unchecked panicing version of Add64. 
func Add64p(a, b int64) int64 { r, ok := Add64(a, b) if !ok { @@ -289,8 +278,7 @@ func Add64p(a, b int64) int64 { return r } -// Sub64 performs - operation on two int64 operands -// returning a result and status +// Sub64 performs - operation on two int64 operands, returning a result and status. func Sub64(a, b int64) (int64, bool) { c := a - b if (c < a) == (b > 0) { @@ -299,7 +287,7 @@ func Sub64(a, b int64) (int64, bool) { return c, false } -// Sub64p is the unchecked panicing version of Sub64 +// Sub64p is the unchecked panicing version of Sub64. func Sub64p(a, b int64) int64 { r, ok := Sub64(a, b) if !ok { @@ -308,8 +296,7 @@ func Sub64p(a, b int64) int64 { return r } -// Mul64 performs * operation on two int64 operands -// returning a result and status +// Mul64 performs * operation on two int64 operands returning a result and status. func Mul64(a, b int64) (int64, bool) { if a == 0 || b == 0 { return 0, true @@ -323,7 +310,7 @@ func Mul64(a, b int64) (int64, bool) { return c, false } -// Mul64p is the unchecked panicing version of Mul64 +// Mul64p is the unchecked panicing version of Mul64. func Mul64p(a, b int64) int64 { r, ok := Mul64(a, b) if !ok { @@ -332,14 +319,13 @@ func Mul64p(a, b int64) int64 { return r } -// Div64 performs / operation on two int64 operands -// returning a result and status +// Div64 performs / operation on two int64 operands, returning a result and status. func Div64(a, b int64) (int64, bool) { q, _, ok := Quotient64(a, b) return q, ok } -// Div64p is the unchecked panicing version of Div64 +// Div64p is the unchecked panicing version of Div64. func Div64p(a, b int64) int64 { r, ok := Div64(a, b) if !ok { @@ -348,13 +334,14 @@ func Div64p(a, b int64) int64 { return r } -// Quotient64 performs + operation on two int64 operands -// returning a quotient, a remainder and status +// Quotient64 performs / operation on two int64 operands, returning a quotient, +// a remainder and status. 
func Quotient64(a, b int64) (int64, int64, bool) { if b == 0 { return 0, 0, false } c := a / b - status := (c < 0) == ((a < 0) != (b < 0)) - return c, a % b, status + status := (c < 0) == ((a < 0) != (b < 0)) || (c == 0) // no sign check for 0 quotient + return c, a%b, status } + diff --git a/tm2/pkg/overflow/overflow_template.sh b/tm2/pkg/overflow/overflow_template.sh index a2a85f2c581..0cc3c9595bf 100755 --- a/tm2/pkg/overflow/overflow_template.sh +++ b/tm2/pkg/overflow/overflow_template.sh @@ -4,109 +4,94 @@ exec > overflow_impl.go echo "package overflow -// This is generated code, created by overflow_template.sh executed -// by \"go generate\" - -" - +// Code generated by overflow_template.sh from 'go generate'. DO NOT EDIT." for SIZE in 8 16 32 64 do -echo " - -// Add${SIZE} performs + operation on two int${SIZE} operands -// returning a result and status + echo " +// Add${SIZE} performs + operation on two int${SIZE} operands, returning a result and status. func Add${SIZE}(a, b int${SIZE}) (int${SIZE}, bool) { - c := a + b - if (c > a) == (b > 0) { - return c, true - } - return c, false + c := a + b + if (c > a) == (b > 0) { + return c, true + } + return c, false } -// Add${SIZE}p is the unchecked panicing version of Add${SIZE} +// Add${SIZE}p is the unchecked panicing version of Add${SIZE}. func Add${SIZE}p(a, b int${SIZE}) int${SIZE} { - r, ok := Add${SIZE}(a, b) - if !ok { - panic(\"addition overflow\") - } - return r + r, ok := Add${SIZE}(a, b) + if !ok { + panic(\"addition overflow\") + } + return r } - -// Sub${SIZE} performs - operation on two int${SIZE} operands -// returning a result and status +// Sub${SIZE} performs - operation on two int${SIZE} operands, returning a result and status. 
func Sub${SIZE}(a, b int${SIZE}) (int${SIZE}, bool) { - c := a - b - if (c < a) == (b > 0) { - return c, true - } - return c, false + c := a - b + if (c < a) == (b > 0) { + return c, true + } + return c, false } -// Sub${SIZE}p is the unchecked panicing version of Sub${SIZE} +// Sub${SIZE}p is the unchecked panicing version of Sub${SIZE}. func Sub${SIZE}p(a, b int${SIZE}) int${SIZE} { - r, ok := Sub${SIZE}(a, b) - if !ok { - panic(\"subtraction overflow\") - } - return r + r, ok := Sub${SIZE}(a, b) + if !ok { + panic(\"subtraction overflow\") + } + return r } - -// Mul${SIZE} performs * operation on two int${SIZE} operands -// returning a result and status +// Mul${SIZE} performs * operation on two int${SIZE} operands returning a result and status. func Mul${SIZE}(a, b int${SIZE}) (int${SIZE}, bool) { - if a == 0 || b == 0 { - return 0, true - } - c := a * b - if (c < 0) == ((a < 0) != (b < 0)) { - if c/b == a { - return c, true - } - } - return c, false + if a == 0 || b == 0 { + return 0, true + } + c := a * b + if (c < 0) == ((a < 0) != (b < 0)) { + if c/b == a { + return c, true + } + } + return c, false } -// Mul${SIZE}p is the unchecked panicing version of Mul${SIZE} +// Mul${SIZE}p is the unchecked panicing version of Mul${SIZE}. func Mul${SIZE}p(a, b int${SIZE}) int${SIZE} { - r, ok := Mul${SIZE}(a, b) - if !ok { - panic(\"multiplication overflow\") - } - return r + r, ok := Mul${SIZE}(a, b) + if !ok { + panic(\"multiplication overflow\") + } + return r } - - -// Div${SIZE} performs / operation on two int${SIZE} operands -// returning a result and status +// Div${SIZE} performs / operation on two int${SIZE} operands, returning a result and status. func Div${SIZE}(a, b int${SIZE}) (int${SIZE}, bool) { - q, _, ok := Quotient${SIZE}(a, b) - return q, ok + q, _, ok := Quotient${SIZE}(a, b) + return q, ok } -// Div${SIZE}p is the unchecked panicing version of Div${SIZE} +// Div${SIZE}p is the unchecked panicing version of Div${SIZE}. 
func Div${SIZE}p(a, b int${SIZE}) int${SIZE} { - r, ok := Div${SIZE}(a, b) - if !ok { - panic(\"division failure\") - } - return r + r, ok := Div${SIZE}(a, b) + if !ok { + panic(\"division failure\") + } + return r } -// Quotient${SIZE} performs + operation on two int${SIZE} operands -// returning a quotient, a remainder and status +// Quotient${SIZE} performs / operation on two int${SIZE} operands, returning a quotient, +// a remainder and status. func Quotient${SIZE}(a, b int${SIZE}) (int${SIZE}, int${SIZE}, bool) { - if b == 0 { - return 0, 0, false - } - c := a / b - status := (c < 0) == ((a < 0) != (b < 0)) - return c, a % b, status + if b == 0 { + return 0, 0, false + } + c := a / b + status := (c < 0) == ((a < 0) != (b < 0)) || (c == 0) // no sign check for 0 quotient + return c, a%b, status } " done - -go run -modfile ../../../misc/devdeps/go.mod mvdan.cc/gofumpt -w overflow_impl.go diff --git a/tm2/pkg/overflow/overflow_test.go b/tm2/pkg/overflow/overflow_test.go index 2b2d345b55d..e6327c9e862 100644 --- a/tm2/pkg/overflow/overflow_test.go +++ b/tm2/pkg/overflow/overflow_test.go @@ -28,8 +28,7 @@ func TestAlgorithms(t *testing.T) { // now the verification result, ok := Add8(a8, b8) if ok && int64(result) != r64 { - t.Errorf("failed to fail on %v + %v = %v instead of %v\n", - a8, b8, result, r64) + t.Errorf("failed to fail on %v + %v = %v instead of %v\n", a8, b8, result, r64) errors++ } if !ok && int64(result) == r64 { @@ -45,8 +44,7 @@ func TestAlgorithms(t *testing.T) { // now the verification result, ok := Sub8(a8, b8) if ok && int64(result) != r64 { - t.Errorf("failed to fail on %v - %v = %v instead of %v\n", - a8, b8, result, r64) + t.Errorf("failed to fail on %v - %v = %v instead of %v\n", a8, b8, result, r64) } if !ok && int64(result) == r64 { t.Fail() @@ -61,8 +59,7 @@ func TestAlgorithms(t *testing.T) { // now the verification result, ok := Mul8(a8, b8) if ok && int64(result) != r64 { - t.Errorf("failed to fail on %v * %v = %v instead of %v\n", - 
a8, b8, result, r64) + t.Errorf("failed to fail on %v * %v = %v instead of %v\n", a8, b8, result, r64) errors++ } if !ok && int64(result) == r64 { @@ -78,11 +75,10 @@ func TestAlgorithms(t *testing.T) { // now the verification result, _, ok := Quotient8(a8, b8) if ok && int64(result) != r64 { - t.Errorf("failed to fail on %v / %v = %v instead of %v\n", - a8, b8, result, r64) + t.Errorf("failed to fail on %v / %v = %v instead of %v\n", a8, b8, result, r64) errors++ } - if !ok && result != 0 && int64(result) == r64 { + if !ok && int64(result) == r64 { t.Fail() errors++ } diff --git a/tm2/pkg/p2p/README.md b/tm2/pkg/p2p/README.md index 81888403e1c..f64fafb53e0 100644 --- a/tm2/pkg/p2p/README.md +++ b/tm2/pkg/p2p/README.md @@ -1,4 +1,604 @@ -# p2p +## Overview -The p2p package provides an abstraction around peer-to-peer communication. +The `p2p` package, and its “sub-packages” contain the required building blocks for Tendermint2’s networking layer. +This document aims to explain the `p2p` terminology, and better document the way the `p2p` module works within the TM2 +ecosystem, especially in relation to other modules like `consensus`, `blockchain` and `mempool`. + +## Common Types + +To fully understand the `Concepts` section of the `p2p` documentation, there must be at least a basic understanding of +the terminology of the `p2p` module, because there are types that keep popping up constantly, and it’s worth +understanding what they’re about. + +### `NetAddress` + +```go +package types + +// NetAddress defines information about a peer on the network +// including its ID, IP address, and port +type NetAddress struct { + ID ID `json:"id"` // unique peer identifier (public key address) + IP net.IP `json:"ip"` // the IP part of the dial address + Port uint16 `json:"port"` // the port part of the dial address +} +``` + +A `NetAddress` is simply a wrapper for a unique peer in the network. 
+ +This address consists of several parts: + +- the peer’s ID, derived from the peer’s public key (it’s the address). +- the peer’s dial address, used for executing TCP dials. + +### `ID` + +```go +// ID represents the cryptographically unique Peer ID +type ID = crypto.ID +``` + +The peer ID is the unique peer identifier. It is used for unambiguously resolving who a peer is, during communication. + +The reason the peer ID is utilized is because it is derived from the peer’s public key, used to encrypt communication, +and it needs to match the public key used in p2p communication. It can, and should be, considered unique. + +### `Reactor` + +Without going too much into detail in the terminology section, as a much more detailed explanation is discussed below: + +A `Reactor` is an abstraction of a Tendermint2 module, that needs to utilize the `p2p` layer. + +Currently active reactors in TM2, that utilize the p2p layer: + +- the consensus reactor, that handles consensus message passing +- the blockchain reactor, that handles block syncing +- the mempool reactor, that handles transaction gossiping + +All of these functionalities require a live p2p network to work, and `Reactor`s are the answer for how they can be aware +of things happening in the network (like new peers joining, for example). + +## Concepts + +### Peer + +`Peer` is an abstraction over a p2p connection that is: + +- **verified**, meaning it went through the handshaking process and the information the other peer shared checked out ( + this process is discussed in detail later). +- **multiplexed over TCP** (the only kind of p2p connections TM2 supports). 
+ +```go +package p2p + +// Peer is a wrapper for a connected peer +type Peer interface { + service.Service + + FlushStop() + + ID() types.ID // peer's cryptographic ID + RemoteIP() net.IP // remote IP of the connection + RemoteAddr() net.Addr // remote address of the connection + + IsOutbound() bool // did we dial the peer + IsPersistent() bool // do we redial this peer when we disconnect + IsPrivate() bool // do we share the peer + + CloseConn() error // close original connection + + NodeInfo() types.NodeInfo // peer's info + Status() ConnectionStatus + SocketAddr() *types.NetAddress // actual address of the socket + + Send(byte, []byte) bool + TrySend(byte, []byte) bool + + Set(string, any) + Get(string) any +} +``` + +There are more than a few things to break down here, so let’s tackle them individually. + +The `Peer` abstraction holds callbacks relating to information about the actual live peer connection, such as what kind +of direction it is, what is the connection status, and others. + +```go +package p2p + +type Peer interface { + // ... + + ID() types.ID // peer's cryptographic ID + RemoteIP() net.IP // remote IP of the connection + RemoteAddr() net.Addr // remote address of the connection + + NodeInfo() types.NodeInfo // peer's info + Status() ConnectionStatus + SocketAddr() *types.NetAddress // actual address of the socket + + IsOutbound() bool // did we dial the peer + IsPersistent() bool // do we redial this peer when we disconnect + IsPrivate() bool // do we share the peer + + // ... +} + +``` + +However, there is part of the `Peer` abstraction that outlines the flipped design of the entire `p2p` module, and a +severe limitation of this implementation. + +```go +package p2p + +type Peer interface { + // ... + + Send(byte, []byte) bool + TrySend(byte, []byte) bool + + // ... 
+} +``` + +The `Peer` abstraction is used internally in `p2p`, but also by other modules that need to interact with the networking +layer — this is in itself the biggest crux of the current `p2p` implementation: modules *need to understand* how to use +and communicate with peers, regardless of the protocol logic. Networking is not an abstraction for the modules, but a +spec requirement. What this essentially means is there is heavy implementation leaking to parts of the TM2 codebase that +shouldn’t need to know how to handle individual peer broadcasts, or how to trigger custom protocol communication (like +syncing for example). +If `module A` wants to broadcast something to the peer network of the node, it needs to do something like this: + +```go +package main + +func main() { + // ... + + peers := sw.Peers().List() // fetch the peer list + + for _, p := range peers { + p.Send(...) // directly message the peer (imitate broadcast) + } + + // ... +} +``` + +An additional odd choice in the `Peer` API is the ability to use the peer as a KV store: + +```go +package p2p + +type Peer interface { + // ... + + Set(string, any) + Get(string) any + + // ... +} +``` + +For example, these methods are used within the `consensus` and `mempool` modules to keep track of active peer states ( +like current HRS data, or current peer mempool metadata). Instead of the module handling individual peer state, this +responsibility is shifted to the peer implementation, causing an odd code dependency situation. + +The root of this “flipped” design (modules needing to understand how to interact with peers) stems from the fact that +peers are instantiated with a multiplex TCP connection under the hood, and basically just wrap that connection. The +`Peer` API is an abstraction for the multiplexed TCP connection, under the hood. 
+ +Changing this dependency stew would require a much larger rewrite of not just the `p2p` module, but other modules ( +`consensus`, `blockchain`, `mempool`) as well, and is as such left as-is. + +### Switch + +In short, a `Switch` is just the middleware layer that handles module <> `Transport` requests, and manages peers on a +high application level (that the `Transport` doesn’t concern itself with). + +The `Switch` is the entity that manages active peer connections. + +```go +package p2p + +// Switch is the abstraction in the p2p module that handles +// and manages peer connections thorough a Transport +type Switch interface { + // Broadcast publishes data on the given channel, to all peers + Broadcast(chID byte, data []byte) + + // Peers returns the latest peer set + Peers() PeerSet + + // Subscribe subscribes to active switch events + Subscribe(filterFn events.EventFilter) (<-chan events.Event, func()) + + // StopPeerForError stops the peer with the given reason + StopPeerForError(peer Peer, err error) + + // DialPeers marks the given peers as ready for async dialing + DialPeers(peerAddrs ...*types.NetAddress) +} + +``` + +The API of the `Switch` is relatively straightforward. Users of the `Switch` instantiate it with a `Transport`, and +utilize it as-is. + +The API of the `Switch` is geared towards asynchronicity, and as such users of the `Switch` need to adapt to some +limitations, such as not having synchronous dials, or synchronous broadcasts. + +#### Services + +There are 3 services that run on top of the `MultiplexSwitch`, upon startup: + +- **the accept service** +- **the dial service** +- **the redial service** + +```go +package p2p + +// OnStart implements BaseService. It starts all the reactors and peers. +func (sw *MultiplexSwitch) OnStart() error { + // Start reactors + for _, reactor := range sw.reactors { + if err := reactor.Start(); err != nil { + return fmt.Errorf("unable to start reactor %w", err) + } + } + + // Run the peer accept routine. 
+ // The accept routine asynchronously accepts + // and processes incoming peer connections + go sw.runAcceptLoop(sw.ctx) + + // Run the dial routine. + // The dial routine parses items in the dial queue + // and initiates outbound peer connections + go sw.runDialLoop(sw.ctx) + + // Run the redial routine. + // The redial routine monitors for important + // peer disconnects, and attempts to reconnect + // to them + go sw.runRedialLoop(sw.ctx) + + return nil +} +``` + +##### Accept Service + +The `MultiplexSwitch` needs to actively listen for incoming connections, and handle them accordingly. These situations +occur when a peer *Dials* (more on this later) another peer, and wants to establish a connection. This connection is +outbound for one peer, and inbound for the other. + +Depending on what kind of security policies or configuration the peer has in place, the connection can be accepted, or +rejected for a number of reasons: + +- the maximum number of inbound peers is reached +- the multiplex connection fails upon startup (rare) + +The `Switch` relies on the `Transport` to return a **verified and valid** peer connection. After the `Transport` +delivers, the `Switch` makes sure having the peer makes sense, given the p2p configuration of the node. + +```go +package p2p + +func (sw *MultiplexSwitch) runAcceptLoop(ctx context.Context) { + // ... + + p, err := sw.transport.Accept(ctx, sw.peerBehavior) + if err != nil { + sw.Logger.Error( + "error encountered during peer connection accept", + "err", err, + ) + + continue + } + + // Ignore connection if we already have enough peers. + if in := sw.Peers().NumInbound(); in >= sw.maxInboundPeers { + sw.Logger.Info( + "Ignoring inbound connection: already have enough inbound peers", + "address", p.SocketAddr(), + "have", in, + "max", sw.maxInboundPeers, + ) + + sw.transport.Remove(p) + + continue + } + + // ... +} + +``` + +In fact, this is the central point in the relationship between the `Switch` and `Transport`. 
+The `Transport` is responsible for establishing the connection, and the `Switch` is responsible for handling it after +it’s been established. + +When TM2 modules communicate with the `p2p` module, they communicate *with the `Switch`, not the `Transport`* to execute +peer-related actions. + +##### Dial Service + +Peers are dialed asynchronously in the `Switch`, as is suggested by the `Switch` API: + +```go +DialPeers(peerAddrs ...*types.NetAddress) +``` + +The `MultiplexSwitch` implementation utilizes a concept called a *dial queue*. + +A dial queue is a priority-based queue (sorted by dial time, ascending) from which dial requests are taken out of and +executed in the form of peer dialing (through the `Transport`, of course). + +The queue needs to be sorted by the dial time, since there are asynchronous dial requests that need to be executed as +soon as possible, while others can wait to be executed up until a certain point in time. + +```go +package p2p + +func (sw *MultiplexSwitch) runDialLoop(ctx context.Context) { + // ... + + // Grab a dial item + item := sw.dialQueue.Peek() + if item == nil { + // Nothing to dial + continue + } + + // Check if the dial time is right + // for the item + if time.Now().Before(item.Time) { + // Nothing to dial + continue + } + + // Pop the item from the dial queue + item = sw.dialQueue.Pop() + + // Dial the peer + sw.Logger.Info( + "dialing peer", + "address", item.Address.String(), + ) + + // ... +} +``` + +To follow the outcomes of dial requests, users of the `Switch` can subscribe to peer events (more on this later). + +##### Redial Service + +The TM2 `p2p` module has a concept of something called *persistent peers*. + +Persistent peers are specific peers whose connections must be preserved, at all costs. They are specified in the +top-level node P2P configuration, under `p2p.persistent_peers`. 
+ +These peer connections are special, as they don’t adhere to high-level configuration limits like the maximum peer cap, +instead, they are monitored and handled actively. + +A good candidate for a persistent peer is a bootnode, that bootstraps and facilitates peer discovery for the network. + +If a persistent peer connection is lost for whatever reason (for ex, the peer disconnects), the redial service of the +`MultiplexSwitch` will create a dial request for the dial service, and attempt to re-establish the lost connection. + +```go +package p2p + +func (sw *MultiplexSwitch) runRedialLoop(ctx context.Context) { + // ... + + var ( + peers = sw.Peers() + peersToDial = make([]*types.NetAddress, 0) + ) + + sw.persistentPeers.Range(func(key, value any) bool { + var ( + id = key.(types.ID) + addr = value.(*types.NetAddress) + ) + + // Check if the peer is part of the peer set + // or is scheduled for dialing + if peers.Has(id) || sw.dialQueue.Has(addr) { + return true + } + + peersToDial = append(peersToDial, addr) + + return true + }) + + if len(peersToDial) == 0 { + // No persistent peers are missing + return + } + + // Add the peers to the dial queue + sw.DialPeers(peersToDial...) + + // ... +} +``` + +#### Events + +The `Switch` is meant to be asynchronous. + +This means that processes like dialing peers, removing peers, doing broadcasts and more, is not a synchronous blocking +process for the `Switch` user. + +To be able to tap into the outcome of these asynchronous events, the `Switch` utilizes a simple event system, based on +event filters. + +```go +package main + +func main() { + // ... 
+ + // Subscribe to live switch events + ch, unsubFn := multiplexSwitch.Subscribe(func(event events.Event) bool { + // This subscription will only return "PeerConnected" events + return event.Type() == events.PeerConnected + }) + + defer unsubFn() // removes the subscription + + select { + // Events are sent to the channel as soon as + // they appear and pass the subscription filter + case ev <- ch: + e := ev.(*events.PeerConnectedEvent) + // use event data... + case <-ctx.Done(): + // ... + } + + // ... +} +``` + +An event setup like this is useful for example when the user of the `Switch` wants to capture successful peer dial +events, in realtime. + +#### What is “peer behavior”? + +```go +package p2p + +// PeerBehavior wraps the Reactor and MultiplexSwitch information a Transport would need when +// dialing or accepting new Peer connections. +// It is worth noting that the only reason why this information is required in the first place, +// is because Peers expose an API through which different TM modules can interact with them. +// In the future™, modules should not directly "Send" anything to Peers, but instead communicate through +// other mediums, such as the P2P module +type PeerBehavior interface { + // ReactorChDescriptors returns the Reactor channel descriptors + ReactorChDescriptors() []*conn.ChannelDescriptor + + // Reactors returns the node's active p2p Reactors (modules) + Reactors() map[byte]Reactor + + // HandlePeerError propagates a peer connection error for further processing + HandlePeerError(Peer, error) + + // IsPersistentPeer returns a flag indicating if the given peer is persistent + IsPersistentPeer(types.ID) bool + + // IsPrivatePeer returns a flag indicating if the given peer is private + IsPrivatePeer(types.ID) bool +} + +``` + +In short, the previously-mentioned crux of the `p2p` implementation (having `Peer`s be directly managed by different TM2 +modules) requires information on how to behave when interacting with other peers. 
+ +TM2 modules on `peer A` communicate through something called *channels* to the same modules on `peer B`. For example, if +the `mempool` module on `peer A` wants to share a transaction to the mempool module on `peer B`, it will utilize a +dedicated (and unique!) channel for it (ex. `0x30`). This is a protocol that lives on top of the already-established +multiplexed connection, and metadata relating to it is passed down through *peer behavior*. + +### Transport + +As previously mentioned, the `Transport` is the infrastructure layer of the `p2p` module. + +In contrast to the `Switch`, which is concerned with higher-level application logic (like the number of peers, peer +limits, etc), the `Transport` is concerned with actually establishing and maintaining peer connections on a much lower +level. + +```go +package p2p + +// Transport handles peer dialing and connection acceptance. Additionally, +// it is also responsible for any custom connection mechanisms (like handshaking). +// Peers returned by the transport are considered to be verified and sound +type Transport interface { + // NetAddress returns the Transport's dial address + NetAddress() types.NetAddress + + // Accept returns a newly connected inbound peer + Accept(context.Context, PeerBehavior) (Peer, error) + + // Dial dials a peer, and returns it + Dial(context.Context, types.NetAddress, PeerBehavior) (Peer, error) + + // Remove drops any resources associated + // with the Peer in the transport + Remove(Peer) +} +``` + +When peers dial other peers in TM2, they are in fact dialing their `Transport`s, and the connection is being handled +here. + +- `Accept` waits for an **incoming** connection, parses it and returns it. +- `Dial` attempts to establish an **outgoing** connection, parses it and returns it. + +There are a few important steps that happen when establishing a p2p connection in TM2, between 2 different peers: + +1. 
The peers go through a handshaking process, and establish something called a *secret connection*. The handshaking + process is based on the [STS protocol](https://github.com/tendermint/tendermint/blob/0.1/docs/sts-final.pdf), and + after it is completed successfully, all communication between the 2 peers is **encrypted**. +2. After establishing a secret connection, the peers exchange their respective node information. The purpose of this + step is to verify that the peers are indeed compatible with each other, and should be establishing a connection in + the first place (same network, common protocols , etc). +3. Once the secret connection is established, and the node information is exchanged, the connection to the peer is + considered valid and verified — it can now be used by the `Switch` (accepted, or rejected, based on `Switch` + high-level constraints). Note the distinction here that the `Transport` establishes and maintains the connection, but + it can ultimately be scraped by the `Switch` at any point in time. + +### Peer Discovery + +There is a final service that runs alongside the previously-mentioned `Switch` services — peer discovery. + +Every blockchain node needs an adequate amount of peers to communicate with, in order to ensure smooth functioning. For +validator nodes, they need to be *loosely connected* to at least 2/3+ of the validator set in order to participate and +not cause block misses or mis-votes (loosely connected means that there always exists a path between different peers in +the network topology, that allows them to be reachable to each other). + +The peer discovery service ensures that the given node is always learning more about the overall network topology, and +filling out any empty connection slots (outbound peers). + +This background service works in the following (albeit primitive) way: + +1. At specific intervals, `node A` checks its peer table, and picks a random peer `P`, from the active peer list. +2. 
When `P` is picked, `node A` initiates a discovery protocol process, in which: + - `node A` sends a request to peer `P` for his peer list (max 30 peers) + - peer `P` responds to the request + +3. Once `node A` has the peer list from `P`, it adds the entire peer list into the dial queue, to establish outbound + peer connections. + +This process repeats at specific intervals. It is worth nothing that if the limit of outbound peers is reached, the peer +dials have no effect. + +#### Bootnodes (Seeds) + +Bootnodes are specialized network nodes that play a critical role in the initial peer discovery process for new nodes +joining the network. + +When a blockchain client starts, it needs to find and connect to other nodes to synchronize data and participate in the +network. Bootnodes provide a predefined list of accessible and reliable entry points that act as a gateway for +discovering other active nodes (through peer discovery). + +These nodes are provided as part of the node’s p2p configuration. Once connected to a bootnode, the client uses peer +discovery to discover and connect to additional peers, enabling full participation and unlocking other client +protocols (consensus, mempool…). + +Bootnodes usually do not store the full blockchain or participate in consensus; their primary role is to facilitate +connectivity in the network (act as a peer relay). \ No newline at end of file diff --git a/tm2/pkg/p2p/base_reactor.go b/tm2/pkg/p2p/base_reactor.go index 91b3981d109..35b03f73be4 100644 --- a/tm2/pkg/p2p/base_reactor.go +++ b/tm2/pkg/p2p/base_reactor.go @@ -6,17 +6,17 @@ import ( ) // Reactor is responsible for handling incoming messages on one or more -// Channel. Switch calls GetChannels when reactor is added to it. When a new +// Channel. MultiplexSwitch calls GetChannels when reactor is added to it. When a new // peer joins our node, InitPeer and AddPeer are called. RemovePeer is called // when the peer is stopped. 
Receive is called when a message is received on a // channel associated with this reactor. // -// Peer#Send or Peer#TrySend should be used to send the message to a peer. +// PeerConn#Send or PeerConn#TrySend should be used to send the message to a peer. type Reactor interface { service.Service // Start, Stop // SetSwitch allows setting a switch. - SetSwitch(*Switch) + SetSwitch(Switch) // GetChannels returns the list of MConnection.ChannelDescriptor. Make sure // that each ID is unique across all the reactors added to the switch. @@ -28,15 +28,15 @@ type Reactor interface { // NOTE: The switch won't call AddPeer nor RemovePeer if it fails to start // the peer. Do not store any data associated with the peer in the reactor // itself unless you don't want to have a state, which is never cleaned up. - InitPeer(peer Peer) Peer + InitPeer(peer PeerConn) PeerConn // AddPeer is called by the switch after the peer is added and successfully // started. Use it to start goroutines communicating with the peer. - AddPeer(peer Peer) + AddPeer(peer PeerConn) // RemovePeer is called by the switch when the peer is stopped (due to error // or other reason). - RemovePeer(peer Peer, reason interface{}) + RemovePeer(peer PeerConn, reason interface{}) // Receive is called by the switch when msgBytes is received from the peer. // @@ -44,14 +44,14 @@ type Reactor interface { // copying. // // CONTRACT: msgBytes are not nil. 
- Receive(chID byte, peer Peer, msgBytes []byte) + Receive(chID byte, peer PeerConn, msgBytes []byte) } -//-------------------------------------- +// -------------------------------------- type BaseReactor struct { - service.BaseService // Provides Start, Stop, .Quit - Switch *Switch + service.BaseService // Provides Start, Stop, Quit + Switch Switch } func NewBaseReactor(name string, impl Reactor) *BaseReactor { @@ -61,11 +61,11 @@ func NewBaseReactor(name string, impl Reactor) *BaseReactor { } } -func (br *BaseReactor) SetSwitch(sw *Switch) { +func (br *BaseReactor) SetSwitch(sw Switch) { br.Switch = sw } -func (*BaseReactor) GetChannels() []*conn.ChannelDescriptor { return nil } -func (*BaseReactor) AddPeer(peer Peer) {} -func (*BaseReactor) RemovePeer(peer Peer, reason interface{}) {} -func (*BaseReactor) Receive(chID byte, peer Peer, msgBytes []byte) {} -func (*BaseReactor) InitPeer(peer Peer) Peer { return peer } +func (*BaseReactor) GetChannels() []*conn.ChannelDescriptor { return nil } +func (*BaseReactor) AddPeer(_ PeerConn) {} +func (*BaseReactor) RemovePeer(_ PeerConn, _ any) {} +func (*BaseReactor) Receive(_ byte, _ PeerConn, _ []byte) {} +func (*BaseReactor) InitPeer(peer PeerConn) PeerConn { return peer } diff --git a/tm2/pkg/p2p/cmd/stest/main.go b/tm2/pkg/p2p/cmd/stest/main.go deleted file mode 100644 index 2835e0cc1f0..00000000000 --- a/tm2/pkg/p2p/cmd/stest/main.go +++ /dev/null @@ -1,86 +0,0 @@ -package main - -import ( - "bufio" - "flag" - "fmt" - "net" - "os" - - "github.com/gnolang/gno/tm2/pkg/crypto/ed25519" - p2pconn "github.com/gnolang/gno/tm2/pkg/p2p/conn" -) - -var ( - remote string - listen string -) - -func init() { - flag.StringVar(&listen, "listen", "", "set to :port if server, eg :8080") - flag.StringVar(&remote, "remote", "", "remote ip:port") - flag.Parse() -} - -func main() { - if listen != "" { - fmt.Println("listening at", listen) - ln, err := net.Listen("tcp", listen) - if err != nil { - // handle error - } - conn, err := 
ln.Accept() - if err != nil { - panic(err) - } - handleConnection(conn) - } else { - // connect to remote. - if remote == "" { - panic("must specify remote ip:port unless server") - } - fmt.Println("connecting to", remote) - conn, err := net.Dial("tcp", remote) - if err != nil { - panic(err) - } - handleConnection(conn) - } -} - -func handleConnection(conn net.Conn) { - priv := ed25519.GenPrivKey() - pub := priv.PubKey() - fmt.Println("local pubkey:", pub) - fmt.Println("local pubkey addr:", pub.Address()) - - sconn, err := p2pconn.MakeSecretConnection(conn, priv) - if err != nil { - panic(err) - } - // Read line from sconn and print. - go func() { - sc := bufio.NewScanner(sconn) - for sc.Scan() { - line := sc.Text() // GET the line string - fmt.Println(">>", line) - } - if err := sc.Err(); err != nil { - panic(err) - } - }() - // Read line from stdin and write. - for { - sc := bufio.NewScanner(os.Stdin) - for sc.Scan() { - line := sc.Text() + "\n" - _, err := sconn.Write([]byte(line)) - if err != nil { - panic(err) - } - } - if err := sc.Err(); err != nil { - panic(err) - } - } -} diff --git a/tm2/pkg/p2p/config/config.go b/tm2/pkg/p2p/config/config.go index 07692145fee..380dadc4f6f 100644 --- a/tm2/pkg/p2p/config/config.go +++ b/tm2/pkg/p2p/config/config.go @@ -1,19 +1,15 @@ package config import ( + "errors" "time" - - "github.com/gnolang/gno/tm2/pkg/errors" ) -// ----------------------------------------------------------------------------- -// P2PConfig - -const ( - // FuzzModeDrop is a mode in which we randomly drop reads/writes, connections or sleep - FuzzModeDrop = iota - // FuzzModeDelay is a mode in which we randomly sleep - FuzzModeDelay +var ( + ErrInvalidFlushThrottleTimeout = errors.New("invalid flush throttle timeout") + ErrInvalidMaxPayloadSize = errors.New("invalid message payload size") + ErrInvalidSendRate = errors.New("invalid packet send rate") + ErrInvalidReceiveRate = errors.New("invalid packet receive rate") ) // P2PConfig defines the 
configuration options for the Tendermint peer-to-peer networking layer @@ -32,14 +28,11 @@ type P2PConfig struct { // Comma separated list of nodes to keep persistent connections to PersistentPeers string `json:"persistent_peers" toml:"persistent_peers" comment:"Comma separated list of nodes to keep persistent connections to"` - // UPNP port forwarding - UPNP bool `json:"upnp" toml:"upnp" comment:"UPNP port forwarding"` - // Maximum number of inbound peers - MaxNumInboundPeers int `json:"max_num_inbound_peers" toml:"max_num_inbound_peers" comment:"Maximum number of inbound peers"` + MaxNumInboundPeers uint64 `json:"max_num_inbound_peers" toml:"max_num_inbound_peers" comment:"Maximum number of inbound peers"` // Maximum number of outbound peers to connect to, excluding persistent peers - MaxNumOutboundPeers int `json:"max_num_outbound_peers" toml:"max_num_outbound_peers" comment:"Maximum number of outbound peers to connect to, excluding persistent peers"` + MaxNumOutboundPeers uint64 `json:"max_num_outbound_peers" toml:"max_num_outbound_peers" comment:"Maximum number of outbound peers to connect to, excluding persistent peers"` // Time to wait before flushing messages out on the connection FlushThrottleTimeout time.Duration `json:"flush_throttle_timeout" toml:"flush_throttle_timeout" comment:"Time to wait before flushing messages out on the connection"` @@ -54,105 +47,45 @@ type P2PConfig struct { RecvRate int64 `json:"recv_rate" toml:"recv_rate" comment:"Rate at which packets can be received, in bytes/second"` // Set true to enable the peer-exchange reactor - PexReactor bool `json:"pex" toml:"pex" comment:"Set true to enable the peer-exchange reactor"` + PeerExchange bool `json:"pex" toml:"pex" comment:"Set true to enable the peer-exchange reactor"` - // Seed mode, in which node constantly crawls the network and looks for - // peers. If another node asks it for addresses, it responds and disconnects. - // - // Does not work if the peer-exchange reactor is disabled. 
- SeedMode bool `json:"seed_mode" toml:"seed_mode" comment:"Seed mode, in which node constantly crawls the network and looks for\n peers. If another node asks it for addresses, it responds and disconnects.\n\n Does not work if the peer-exchange reactor is disabled."` - - // Comma separated list of peer IDs to keep private (will not be gossiped to - // other peers) + // Comma separated list of peer IDs to keep private (will not be gossiped to other peers) PrivatePeerIDs string `json:"private_peer_ids" toml:"private_peer_ids" comment:"Comma separated list of peer IDs to keep private (will not be gossiped to other peers)"` - - // Toggle to disable guard against peers connecting from the same ip. - AllowDuplicateIP bool `json:"allow_duplicate_ip" toml:"allow_duplicate_ip" comment:"Toggle to disable guard against peers connecting from the same ip."` - - // Peer connection configuration. - HandshakeTimeout time.Duration `json:"handshake_timeout" toml:"handshake_timeout" comment:"Peer connection configuration."` - DialTimeout time.Duration `json:"dial_timeout" toml:"dial_timeout"` - - // Testing params. 
- // Force dial to fail - TestDialFail bool `json:"test_dial_fail" toml:"test_dial_fail"` - // FUzz connection - TestFuzz bool `json:"test_fuzz" toml:"test_fuzz"` - TestFuzzConfig *FuzzConnConfig `json:"test_fuzz_config" toml:"test_fuzz_config"` } // DefaultP2PConfig returns a default configuration for the peer-to-peer layer func DefaultP2PConfig() *P2PConfig { return &P2PConfig{ ListenAddress: "tcp://0.0.0.0:26656", - ExternalAddress: "", - UPNP: false, + ExternalAddress: "", // nothing is advertised differently MaxNumInboundPeers: 40, MaxNumOutboundPeers: 10, FlushThrottleTimeout: 100 * time.Millisecond, MaxPacketMsgPayloadSize: 1024, // 1 kB SendRate: 5120000, // 5 mB/s RecvRate: 5120000, // 5 mB/s - PexReactor: true, - SeedMode: false, - AllowDuplicateIP: false, - HandshakeTimeout: 20 * time.Second, - DialTimeout: 3 * time.Second, - TestDialFail: false, - TestFuzz: false, - TestFuzzConfig: DefaultFuzzConnConfig(), + PeerExchange: true, } } -// TestP2PConfig returns a configuration for testing the peer-to-peer layer -func TestP2PConfig() *P2PConfig { - cfg := DefaultP2PConfig() - cfg.ListenAddress = "tcp://0.0.0.0:26656" - cfg.FlushThrottleTimeout = 10 * time.Millisecond - cfg.AllowDuplicateIP = true - return cfg -} - // ValidateBasic performs basic validation (checking param bounds, etc.) and // returns an error if any check fails. 
func (cfg *P2PConfig) ValidateBasic() error { - if cfg.MaxNumInboundPeers < 0 { - return errors.New("max_num_inbound_peers can't be negative") - } - if cfg.MaxNumOutboundPeers < 0 { - return errors.New("max_num_outbound_peers can't be negative") - } if cfg.FlushThrottleTimeout < 0 { - return errors.New("flush_throttle_timeout can't be negative") + return ErrInvalidFlushThrottleTimeout } + if cfg.MaxPacketMsgPayloadSize < 0 { - return errors.New("max_packet_msg_payload_size can't be negative") + return ErrInvalidMaxPayloadSize } + if cfg.SendRate < 0 { - return errors.New("send_rate can't be negative") + return ErrInvalidSendRate } + if cfg.RecvRate < 0 { - return errors.New("recv_rate can't be negative") + return ErrInvalidReceiveRate } - return nil -} - -// FuzzConnConfig is a FuzzedConnection configuration. -type FuzzConnConfig struct { - Mode int - MaxDelay time.Duration - ProbDropRW float64 - ProbDropConn float64 - ProbSleep float64 -} -// DefaultFuzzConnConfig returns the default config. 
-func DefaultFuzzConnConfig() *FuzzConnConfig { - return &FuzzConnConfig{ - Mode: FuzzModeDrop, - MaxDelay: 3 * time.Second, - ProbDropRW: 0.2, - ProbDropConn: 0.00, - ProbSleep: 0.00, - } + return nil } diff --git a/tm2/pkg/p2p/config/config_test.go b/tm2/pkg/p2p/config/config_test.go new file mode 100644 index 00000000000..f624563d984 --- /dev/null +++ b/tm2/pkg/p2p/config/config_test.go @@ -0,0 +1,59 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestP2PConfig_ValidateBasic(t *testing.T) { + t.Parallel() + + t.Run("invalid flush throttle timeout", func(t *testing.T) { + t.Parallel() + + cfg := DefaultP2PConfig() + + cfg.FlushThrottleTimeout = -1 + + assert.ErrorIs(t, cfg.ValidateBasic(), ErrInvalidFlushThrottleTimeout) + }) + + t.Run("invalid max packet payload size", func(t *testing.T) { + t.Parallel() + + cfg := DefaultP2PConfig() + + cfg.MaxPacketMsgPayloadSize = -1 + + assert.ErrorIs(t, cfg.ValidateBasic(), ErrInvalidMaxPayloadSize) + }) + + t.Run("invalid send rate", func(t *testing.T) { + t.Parallel() + + cfg := DefaultP2PConfig() + + cfg.SendRate = -1 + + assert.ErrorIs(t, cfg.ValidateBasic(), ErrInvalidSendRate) + }) + + t.Run("invalid receive rate", func(t *testing.T) { + t.Parallel() + + cfg := DefaultP2PConfig() + + cfg.RecvRate = -1 + + assert.ErrorIs(t, cfg.ValidateBasic(), ErrInvalidReceiveRate) + }) + + t.Run("valid configuration", func(t *testing.T) { + t.Parallel() + + cfg := DefaultP2PConfig() + + assert.NoError(t, cfg.ValidateBasic()) + }) +} diff --git a/tm2/pkg/p2p/conn/conn.go b/tm2/pkg/p2p/conn/conn.go new file mode 100644 index 00000000000..3215adc38ca --- /dev/null +++ b/tm2/pkg/p2p/conn/conn.go @@ -0,0 +1,22 @@ +package conn + +import ( + "net" + "time" +) + +// pipe wraps the networking conn interface +type pipe struct { + net.Conn +} + +func (p *pipe) SetDeadline(_ time.Time) error { + return nil +} + +func NetPipe() (net.Conn, net.Conn) { + p1, p2 := net.Pipe() + return &pipe{p1}, 
&pipe{p2} +} + +var _ net.Conn = (*pipe)(nil) diff --git a/tm2/pkg/p2p/conn/conn_go110.go b/tm2/pkg/p2p/conn/conn_go110.go deleted file mode 100644 index 37796ac791d..00000000000 --- a/tm2/pkg/p2p/conn/conn_go110.go +++ /dev/null @@ -1,15 +0,0 @@ -//go:build go1.10 - -package conn - -// Go1.10 has a proper net.Conn implementation that -// has the SetDeadline method implemented as per -// https://github.com/golang/go/commit/e2dd8ca946be884bb877e074a21727f1a685a706 -// lest we run into problems like -// https://github.com/tendermint/classic/issues/851 - -import "net" - -func NetPipe() (net.Conn, net.Conn) { - return net.Pipe() -} diff --git a/tm2/pkg/p2p/conn/conn_notgo110.go b/tm2/pkg/p2p/conn/conn_notgo110.go deleted file mode 100644 index f91b0c7ea63..00000000000 --- a/tm2/pkg/p2p/conn/conn_notgo110.go +++ /dev/null @@ -1,36 +0,0 @@ -//go:build !go1.10 - -package conn - -import ( - "net" - "time" -) - -// Only Go1.10 has a proper net.Conn implementation that -// has the SetDeadline method implemented as per -// -// https://github.com/golang/go/commit/e2dd8ca946be884bb877e074a21727f1a685a706 -// -// lest we run into problems like -// -// https://github.com/tendermint/classic/issues/851 -// -// so for go versions < Go1.10 use our custom net.Conn creator -// that doesn't return an `Unimplemented error` for net.Conn. -// Before https://github.com/tendermint/classic/commit/49faa79bdce5663894b3febbf4955fb1d172df04 -// we hadn't cared about errors from SetDeadline so swallow them up anyways. 
-type pipe struct { - net.Conn -} - -func (p *pipe) SetDeadline(t time.Time) error { - return nil -} - -func NetPipe() (net.Conn, net.Conn) { - p1, p2 := net.Pipe() - return &pipe{p1}, &pipe{p2} -} - -var _ net.Conn = (*pipe)(nil) diff --git a/tm2/pkg/p2p/conn/connection.go b/tm2/pkg/p2p/conn/connection.go index 6b7400600d3..acbffdb3cd7 100644 --- a/tm2/pkg/p2p/conn/connection.go +++ b/tm2/pkg/p2p/conn/connection.go @@ -17,6 +17,7 @@ import ( "github.com/gnolang/gno/tm2/pkg/amino" "github.com/gnolang/gno/tm2/pkg/errors" "github.com/gnolang/gno/tm2/pkg/flow" + "github.com/gnolang/gno/tm2/pkg/p2p/config" "github.com/gnolang/gno/tm2/pkg/service" "github.com/gnolang/gno/tm2/pkg/timer" ) @@ -31,7 +32,6 @@ const ( // some of these defaults are written in the user config // flushThrottle, sendRate, recvRate - // TODO: remove values present in config defaultFlushThrottle = 100 * time.Millisecond defaultSendQueueCapacity = 1 @@ -46,7 +46,7 @@ const ( type ( receiveCbFunc func(chID byte, msgBytes []byte) - errorCbFunc func(interface{}) + errorCbFunc func(error) ) /* @@ -147,6 +147,18 @@ func DefaultMConnConfig() MConnConfig { } } +// MConfigFromP2P returns a multiplex connection configuration +// with fields updated from the P2PConfig +func MConfigFromP2P(cfg *config.P2PConfig) MConnConfig { + mConfig := DefaultMConnConfig() + mConfig.FlushThrottle = cfg.FlushThrottleTimeout + mConfig.SendRate = cfg.SendRate + mConfig.RecvRate = cfg.RecvRate + mConfig.MaxPacketMsgPayloadSize = cfg.MaxPacketMsgPayloadSize + + return mConfig +} + // NewMConnection wraps net.Conn and creates multiplex connection func NewMConnection(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc) *MConnection { return NewMConnectionWithConfig( @@ -323,7 +335,7 @@ func (c *MConnection) _recover() { } } -func (c *MConnection) stopForError(r interface{}) { +func (c *MConnection) stopForError(r error) { c.Stop() if atomic.CompareAndSwapUint32(&c.errored, 0, 1) { if 
c.onError != nil { diff --git a/tm2/pkg/p2p/conn/connection_test.go b/tm2/pkg/p2p/conn/connection_test.go index 7bbe88ded22..58b363b7b78 100644 --- a/tm2/pkg/p2p/conn/connection_test.go +++ b/tm2/pkg/p2p/conn/connection_test.go @@ -21,7 +21,7 @@ func createTestMConnection(t *testing.T, conn net.Conn) *MConnection { onReceive := func(chID byte, msgBytes []byte) { } - onError := func(r interface{}) { + onError := func(r error) { } c := createMConnectionWithCallbacks(t, conn, onReceive, onError) c.SetLogger(log.NewTestingLogger(t)) @@ -32,7 +32,7 @@ func createMConnectionWithCallbacks( t *testing.T, conn net.Conn, onReceive func(chID byte, msgBytes []byte), - onError func(r interface{}), + onError func(r error), ) *MConnection { t.Helper() @@ -137,7 +137,7 @@ func TestMConnectionReceive(t *testing.T) { onReceive := func(chID byte, msgBytes []byte) { receivedCh <- msgBytes } - onError := func(r interface{}) { + onError := func(r error) { errorsCh <- r } mconn1 := createMConnectionWithCallbacks(t, client, onReceive, onError) @@ -192,7 +192,7 @@ func TestMConnectionPongTimeoutResultsInError(t *testing.T) { onReceive := func(chID byte, msgBytes []byte) { receivedCh <- msgBytes } - onError := func(r interface{}) { + onError := func(r error) { errorsCh <- r } mconn := createMConnectionWithCallbacks(t, client, onReceive, onError) @@ -233,7 +233,7 @@ func TestMConnectionMultiplePongsInTheBeginning(t *testing.T) { onReceive := func(chID byte, msgBytes []byte) { receivedCh <- msgBytes } - onError := func(r interface{}) { + onError := func(r error) { errorsCh <- r } mconn := createMConnectionWithCallbacks(t, client, onReceive, onError) @@ -288,7 +288,7 @@ func TestMConnectionMultiplePings(t *testing.T) { onReceive := func(chID byte, msgBytes []byte) { receivedCh <- msgBytes } - onError := func(r interface{}) { + onError := func(r error) { errorsCh <- r } mconn := createMConnectionWithCallbacks(t, client, onReceive, onError) @@ -331,7 +331,7 @@ func TestMConnectionPingPongs(t 
*testing.T) { onReceive := func(chID byte, msgBytes []byte) { receivedCh <- msgBytes } - onError := func(r interface{}) { + onError := func(r error) { errorsCh <- r } mconn := createMConnectionWithCallbacks(t, client, onReceive, onError) @@ -384,7 +384,7 @@ func TestMConnectionStopsAndReturnsError(t *testing.T) { onReceive := func(chID byte, msgBytes []byte) { receivedCh <- msgBytes } - onError := func(r interface{}) { + onError := func(r error) { errorsCh <- r } mconn := createMConnectionWithCallbacks(t, client, onReceive, onError) @@ -413,7 +413,7 @@ func newClientAndServerConnsForReadErrors(t *testing.T, chOnErr chan struct{}) ( server, client := NetPipe() onReceive := func(chID byte, msgBytes []byte) {} - onError := func(r interface{}) {} + onError := func(r error) {} // create client conn with two channels chDescs := []*ChannelDescriptor{ @@ -428,7 +428,7 @@ func newClientAndServerConnsForReadErrors(t *testing.T, chOnErr chan struct{}) ( // create server conn with 1 channel // it fires on chOnErr when there's an error serverLogger := log.NewNoopLogger().With("module", "server") - onError = func(r interface{}) { + onError = func(_ error) { chOnErr <- struct{}{} } mconnServer := createMConnectionWithCallbacks(t, server, onReceive, onError) diff --git a/tm2/pkg/p2p/conn/secret_connection.go b/tm2/pkg/p2p/conn/secret_connection.go index a37788b947d..d45b5b3846a 100644 --- a/tm2/pkg/p2p/conn/secret_connection.go +++ b/tm2/pkg/p2p/conn/secret_connection.go @@ -7,6 +7,7 @@ import ( "crypto/sha256" "crypto/subtle" "encoding/binary" + "fmt" "io" "math" "net" @@ -128,7 +129,10 @@ func MakeSecretConnection(conn io.ReadWriteCloser, locPrivKey crypto.PrivKey) (* } // Sign the challenge bytes for authentication. 
- locSignature := signChallenge(challenge, locPrivKey) + locSignature, err := locPrivKey.Sign(challenge[:]) + if err != nil { + return nil, fmt.Errorf("unable to sign challenge, %w", err) + } // Share (in secret) each other's pubkey & challenge signature authSigMsg, err := shareAuthSignature(sc, locPubKey, locSignature) @@ -424,15 +428,6 @@ func sort32(foo, bar *[32]byte) (lo, hi *[32]byte) { return } -func signChallenge(challenge *[32]byte, locPrivKey crypto.PrivKey) (signature []byte) { - signature, err := locPrivKey.Sign(challenge[:]) - // TODO(ismail): let signChallenge return an error instead - if err != nil { - panic(err) - } - return -} - type authSigMessage struct { Key crypto.PubKey Sig []byte diff --git a/tm2/pkg/p2p/conn_set.go b/tm2/pkg/p2p/conn_set.go deleted file mode 100644 index d646227831a..00000000000 --- a/tm2/pkg/p2p/conn_set.go +++ /dev/null @@ -1,81 +0,0 @@ -package p2p - -import ( - "net" - "sync" -) - -// ConnSet is a lookup table for connections and all their ips. -type ConnSet interface { - Has(net.Conn) bool - HasIP(net.IP) bool - Set(net.Conn, []net.IP) - Remove(net.Conn) - RemoveAddr(net.Addr) -} - -type connSetItem struct { - conn net.Conn - ips []net.IP -} - -type connSet struct { - sync.RWMutex - - conns map[string]connSetItem -} - -// NewConnSet returns a ConnSet implementation. 
-func NewConnSet() *connSet { - return &connSet{ - conns: map[string]connSetItem{}, - } -} - -func (cs *connSet) Has(c net.Conn) bool { - cs.RLock() - defer cs.RUnlock() - - _, ok := cs.conns[c.RemoteAddr().String()] - - return ok -} - -func (cs *connSet) HasIP(ip net.IP) bool { - cs.RLock() - defer cs.RUnlock() - - for _, c := range cs.conns { - for _, known := range c.ips { - if known.Equal(ip) { - return true - } - } - } - - return false -} - -func (cs *connSet) Remove(c net.Conn) { - cs.Lock() - defer cs.Unlock() - - delete(cs.conns, c.RemoteAddr().String()) -} - -func (cs *connSet) RemoveAddr(addr net.Addr) { - cs.Lock() - defer cs.Unlock() - - delete(cs.conns, addr.String()) -} - -func (cs *connSet) Set(c net.Conn, ips []net.IP) { - cs.Lock() - defer cs.Unlock() - - cs.conns[c.RemoteAddr().String()] = connSetItem{ - conn: c, - ips: ips, - } -} diff --git a/tm2/pkg/p2p/dial/dial.go b/tm2/pkg/p2p/dial/dial.go new file mode 100644 index 00000000000..e4a7d6fd445 --- /dev/null +++ b/tm2/pkg/p2p/dial/dial.go @@ -0,0 +1,83 @@ +package dial + +import ( + "sync" + "time" + + "github.com/gnolang/gno/tm2/pkg/p2p/types" + queue "github.com/sig-0/insertion-queue" +) + +// Item is a single dial queue item, wrapping +// the approximately appropriate dial time, and the +// peer dial address +type Item struct { + Time time.Time // appropriate dial time + Address *types.NetAddress // the dial address of the peer +} + +// Less is the comparison method for the dial queue Item (time ascending) +func (i Item) Less(item Item) bool { + return i.Time.Before(item.Time) +} + +// Queue is a time-sorted (ascending) dial queue +type Queue struct { + mux sync.RWMutex + + items queue.Queue[Item] // sorted dial queue (by time, ascending) +} + +// NewQueue creates a new dial queue +func NewQueue() *Queue { + return &Queue{ + items: queue.NewQueue[Item](), + } +} + +// Peek returns the first item in the dial queue, if any +func (q *Queue) Peek() *Item { + q.mux.RLock() + defer q.mux.RUnlock() 
+ + if q.items.Len() == 0 { + return nil + } + + item := q.items.Index(0) + + return &item +} + +// Push adds new items to the dial queue +func (q *Queue) Push(items ...Item) { + q.mux.Lock() + defer q.mux.Unlock() + + for _, item := range items { + q.items.Push(item) + } +} + +// Pop removes an item from the dial queue, if any +func (q *Queue) Pop() *Item { + q.mux.Lock() + defer q.mux.Unlock() + + return q.items.PopFront() +} + +// Has returns a flag indicating if the given +// address is in the dial queue +func (q *Queue) Has(addr *types.NetAddress) bool { + q.mux.RLock() + defer q.mux.RUnlock() + + for _, i := range q.items { + if addr.Equals(*i.Address) { + return true + } + } + + return false +} diff --git a/tm2/pkg/p2p/dial/dial_test.go b/tm2/pkg/p2p/dial/dial_test.go new file mode 100644 index 00000000000..5e85ec1f95e --- /dev/null +++ b/tm2/pkg/p2p/dial/dial_test.go @@ -0,0 +1,147 @@ +package dial + +import ( + "crypto/rand" + "math/big" + "slices" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// generateRandomTimes generates random time intervals +func generateRandomTimes(t *testing.T, count int) []time.Time { + t.Helper() + + const timeRange = 94608000 // 3 years + + var ( + maxRange = big.NewInt(time.Now().Unix() - timeRange) + times = make([]time.Time, 0, count) + ) + + for range count { + n, err := rand.Int(rand.Reader, maxRange) + require.NoError(t, err) + + randTime := time.Unix(n.Int64()+timeRange, 0) + + times = append(times, randTime) + } + + return times +} + +func TestQueue_Push(t *testing.T) { + t.Parallel() + + var ( + timestamps = generateRandomTimes(t, 10) + q = NewQueue() + ) + + // Add the dial items + for _, timestamp := range timestamps { + q.Push(Item{ + Time: timestamp, + }) + } + + assert.Len(t, q.items, len(timestamps)) +} + +func TestQueue_Peek(t *testing.T) { + t.Parallel() + + t.Run("empty queue", func(t *testing.T) { + t.Parallel() + + q := NewQueue() + + 
assert.Nil(t, q.Peek()) + }) + + t.Run("existing item", func(t *testing.T) { + t.Parallel() + + var ( + timestamps = generateRandomTimes(t, 100) + q = NewQueue() + ) + + // Add the dial items + for _, timestamp := range timestamps { + q.Push(Item{ + Time: timestamp, + }) + } + + // Sort the initial list to find the best timestamp + slices.SortFunc(timestamps, func(a, b time.Time) int { + if a.Before(b) { + return -1 + } + + if a.After(b) { + return 1 + } + + return 0 + }) + + assert.Equal(t, q.Peek().Time.Unix(), timestamps[0].Unix()) + }) +} + +func TestQueue_Pop(t *testing.T) { + t.Parallel() + + t.Run("empty queue", func(t *testing.T) { + t.Parallel() + + q := NewQueue() + + assert.Nil(t, q.Pop()) + }) + + t.Run("existing item", func(t *testing.T) { + t.Parallel() + + var ( + timestamps = generateRandomTimes(t, 100) + q = NewQueue() + ) + + // Add the dial items + for _, timestamp := range timestamps { + q.Push(Item{ + Time: timestamp, + }) + } + + assert.Len(t, q.items, len(timestamps)) + + // Sort the initial list to find the best timestamp + slices.SortFunc(timestamps, func(a, b time.Time) int { + if a.Before(b) { + return -1 + } + + if a.After(b) { + return 1 + } + + return 0 + }) + + for index, timestamp := range timestamps { + item := q.Pop() + + require.Len(t, q.items, len(timestamps)-1-index) + + assert.Equal(t, item.Time.Unix(), timestamp.Unix()) + } + }) +} diff --git a/tm2/pkg/p2p/dial/doc.go b/tm2/pkg/p2p/dial/doc.go new file mode 100644 index 00000000000..069160e73e6 --- /dev/null +++ b/tm2/pkg/p2p/dial/doc.go @@ -0,0 +1,10 @@ +// Package dial contains an implementation of a thread-safe priority dial queue. The queue is sorted by +// dial items, time ascending. +// The behavior of the dial queue is the following: +// +// - Peeking the dial queue will return the most urgent dial item, or nil if the queue is empty. +// +// - Popping the dial queue will return the most urgent dial item or nil if the queue is empty. Popping removes the dial item. 
+// +// - Push will push a new item to the dial queue, upon which the queue will find an adequate place for it. +package dial diff --git a/tm2/pkg/p2p/discovery/discovery.go b/tm2/pkg/p2p/discovery/discovery.go new file mode 100644 index 00000000000..d884b118c75 --- /dev/null +++ b/tm2/pkg/p2p/discovery/discovery.go @@ -0,0 +1,242 @@ +package discovery + +import ( + "context" + "crypto/rand" + "fmt" + "math/big" + "time" + + "github.com/gnolang/gno/tm2/pkg/amino" + "github.com/gnolang/gno/tm2/pkg/p2p" + "github.com/gnolang/gno/tm2/pkg/p2p/conn" + "github.com/gnolang/gno/tm2/pkg/p2p/types" + "golang.org/x/exp/slices" +) + +const ( + // Channel is the unique channel for the peer discovery protocol + Channel = byte(0x50) + + // discoveryInterval is the peer discovery interval, for random peers + discoveryInterval = time.Second * 3 + + // maxPeersShared is the maximum number of peers shared in the discovery request + maxPeersShared = 30 +) + +// descriptor is the constant peer discovery protocol descriptor +var descriptor = &conn.ChannelDescriptor{ + ID: Channel, + Priority: 1, // peer discovery is high priority + SendQueueCapacity: 20, // more than enough active conns + RecvMessageCapacity: 5242880, // 5MB +} + +// Reactor wraps the logic for the peer exchange protocol +type Reactor struct { + // This embed and the usage of "services" + // like the peer discovery reactor highlight the + // flipped design of the p2p package. + // The peer exchange service needs to be instantiated _outside_ + // the p2p module, because of this flipped design. 
+ // Peers communicate with each other through Reactor channels, + // which are instantiated outside the p2p module + p2p.BaseReactor + + ctx context.Context + cancelFn context.CancelFunc + + discoveryInterval time.Duration +} + +// NewReactor creates a new peer discovery reactor +func NewReactor(opts ...Option) *Reactor { + ctx, cancelFn := context.WithCancel(context.Background()) + + r := &Reactor{ + ctx: ctx, + cancelFn: cancelFn, + discoveryInterval: discoveryInterval, + } + + r.BaseReactor = *p2p.NewBaseReactor("Reactor", r) + + // Apply the options + for _, opt := range opts { + opt(r) + } + + return r +} + +// OnStart runs the peer discovery protocol +func (r *Reactor) OnStart() error { + go func() { + ticker := time.NewTicker(r.discoveryInterval) + defer ticker.Stop() + + for { + select { + case <-r.ctx.Done(): + r.Logger.Debug("discovery service stopped") + + return + case <-ticker.C: + // Run the discovery protocol // + + // Grab a random peer, and engage + // them for peer discovery + peers := r.Switch.Peers().List() + + if len(peers) == 0 { + // No discovery to run + continue + } + + // Generate a random peer index + randomPeer, _ := rand.Int( + rand.Reader, + big.NewInt(int64(len(peers))), + ) + + // Request peers, async + go r.requestPeers(peers[randomPeer.Int64()]) + } + } + }() + + return nil +} + +// OnStop stops the peer discovery protocol +func (r *Reactor) OnStop() { + r.cancelFn() +} + +// requestPeers requests the peer set from the given peer +func (r *Reactor) requestPeers(peer p2p.PeerConn) { + // Initiate peer discovery + r.Logger.Debug("running peer discovery", "peer", peer.ID()) + + // Prepare the request + // (empty, as it's a notification) + req := &Request{} + + reqBytes, err := amino.MarshalAny(req) + if err != nil { + r.Logger.Error("unable to marshal discovery request", "err", err) + + return + } + + // Send the request + if !peer.Send(Channel, reqBytes) { + r.Logger.Warn("unable to send discovery request", "peer", peer.ID()) + } +} 
+ +// GetChannels returns the channels associated with peer discovery +func (r *Reactor) GetChannels() []*conn.ChannelDescriptor { + return []*conn.ChannelDescriptor{descriptor} +} + +// Receive handles incoming messages for the peer discovery reactor +func (r *Reactor) Receive(chID byte, peer p2p.PeerConn, msgBytes []byte) { + r.Logger.Debug( + "received message", + "peerID", peer.ID(), + "chID", chID, + ) + + // Unmarshal the message + var msg Message + + if err := amino.UnmarshalAny(msgBytes, &msg); err != nil { + r.Logger.Error("unable to unmarshal discovery message", "err", err) + + return + } + + // Validate the message + if err := msg.ValidateBasic(); err != nil { + r.Logger.Error("unable to validate discovery message", "err", err) + + return + } + + switch msg := msg.(type) { + case *Request: + if err := r.handleDiscoveryRequest(peer); err != nil { + r.Logger.Error("unable to handle discovery request", "err", err) + } + case *Response: + // Make the peers available for dialing on the switch + r.Switch.DialPeers(msg.Peers...) 
+ default: + r.Logger.Warn("invalid message received", "msg", msgBytes) + } +} + +// handleDiscoveryRequest prepares a peer list that can be shared +// with the peer requesting discovery +func (r *Reactor) handleDiscoveryRequest(peer p2p.PeerConn) error { + var ( + localPeers = r.Switch.Peers().List() + peers = make([]*types.NetAddress, 0, len(localPeers)) + ) + + // Exclude the private peers from being shared + localPeers = slices.DeleteFunc(localPeers, func(p p2p.PeerConn) bool { + return p.IsPrivate() + }) + + // Check if there is anything to share, + // to avoid useless traffic + if len(localPeers) == 0 { + r.Logger.Warn("no peers to share in discovery request") + + return nil + } + + // Shuffle and limit the peers shared + shufflePeers(localPeers) + + if len(localPeers) > maxPeersShared { + localPeers = localPeers[:maxPeersShared] + } + + for _, p := range localPeers { + peers = append(peers, p.SocketAddr()) + } + + // Create the response, and marshal + // it to Amino binary + resp := &Response{ + Peers: peers, + } + + preparedResp, err := amino.MarshalAny(resp) + if err != nil { + return fmt.Errorf("unable to marshal discovery response, %w", err) + } + + // Send the response to the peer + if !peer.Send(Channel, preparedResp) { + return fmt.Errorf("unable to send discovery response to peer %s", peer.ID()) + } + + return nil +} + +// shufflePeers shuffles the peer list in-place +func shufflePeers(peers []p2p.PeerConn) { + for i := len(peers) - 1; i > 0; i-- { + jBig, _ := rand.Int(rand.Reader, big.NewInt(int64(i+1))) + + j := int(jBig.Int64()) + + // Swap elements + peers[i], peers[j] = peers[j], peers[i] + } +} diff --git a/tm2/pkg/p2p/discovery/discovery_test.go b/tm2/pkg/p2p/discovery/discovery_test.go new file mode 100644 index 00000000000..17404e6039a --- /dev/null +++ b/tm2/pkg/p2p/discovery/discovery_test.go @@ -0,0 +1,453 @@ +package discovery + +import ( + "slices" + "testing" + "time" + + "github.com/gnolang/gno/tm2/pkg/amino" + 
"github.com/gnolang/gno/tm2/pkg/p2p" + "github.com/gnolang/gno/tm2/pkg/p2p/mock" + "github.com/gnolang/gno/tm2/pkg/p2p/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestReactor_DiscoveryRequest(t *testing.T) { + t.Parallel() + + var ( + notifCh = make(chan struct{}, 1) + + capturedSend []byte + + mockPeer = &mock.Peer{ + SendFn: func(chID byte, data []byte) bool { + require.Equal(t, Channel, chID) + + capturedSend = data + + notifCh <- struct{}{} + + return true + }, + } + + ps = &mockPeerSet{ + listFn: func() []p2p.PeerConn { + return []p2p.PeerConn{mockPeer} + }, + } + + mockSwitch = &mockSwitch{ + peersFn: func() p2p.PeerSet { + return ps + }, + } + ) + + r := NewReactor( + WithDiscoveryInterval(10 * time.Millisecond), + ) + + // Set the mock switch + r.SetSwitch(mockSwitch) + + // Start the discovery service + require.NoError(t, r.Start()) + t.Cleanup(func() { + require.NoError(t, r.Stop()) + }) + + select { + case <-notifCh: + case <-time.After(5 * time.Second): + } + + // Make sure the adequate message was captured + require.NotNil(t, capturedSend) + + // Parse the message + var msg Message + + require.NoError(t, amino.Unmarshal(capturedSend, &msg)) + + // Make sure the base message is valid + require.NoError(t, msg.ValidateBasic()) + + _, ok := msg.(*Request) + + require.True(t, ok) +} + +func TestReactor_DiscoveryResponse(t *testing.T) { + t.Parallel() + + t.Run("discovery request received", func(t *testing.T) { + t.Parallel() + + var ( + peers = mock.GeneratePeers(t, 50) + notifCh = make(chan struct{}, 1) + + capturedSend []byte + + mockPeer = &mock.Peer{ + SendFn: func(chID byte, data []byte) bool { + require.Equal(t, Channel, chID) + + capturedSend = data + + notifCh <- struct{}{} + + return true + }, + } + + ps = &mockPeerSet{ + listFn: func() []p2p.PeerConn { + listed := make([]p2p.PeerConn, 0, len(peers)) + + for _, peer := range peers { + listed = append(listed, peer) + } + + return listed + }, + 
numInboundFn: func() uint64 { + return uint64(len(peers)) + }, + } + + mockSwitch = &mockSwitch{ + peersFn: func() p2p.PeerSet { + return ps + }, + } + ) + + r := NewReactor( + WithDiscoveryInterval(10 * time.Millisecond), + ) + + // Set the mock switch + r.SetSwitch(mockSwitch) + + // Prepare the message + req := &Request{} + + preparedReq, err := amino.MarshalAny(req) + require.NoError(t, err) + + // Receive the message + r.Receive(Channel, mockPeer, preparedReq) + + select { + case <-notifCh: + case <-time.After(5 * time.Second): + } + + // Make sure the adequate message was captured + require.NotNil(t, capturedSend) + + // Parse the message + var msg Message + + require.NoError(t, amino.Unmarshal(capturedSend, &msg)) + + // Make sure the base message is valid + require.NoError(t, msg.ValidateBasic()) + + resp, ok := msg.(*Response) + require.True(t, ok) + + // Make sure the peers are valid + require.Len(t, resp.Peers, maxPeersShared) + + assert.True(t, slices.ContainsFunc(resp.Peers, func(addr *types.NetAddress) bool { + for _, localP := range peers { + if localP.SocketAddr().Equals(*addr) { + return true + } + } + + return false + })) + }) + + t.Run("empty peers on discover", func(t *testing.T) { + t.Parallel() + + var ( + capturedSend []byte + + mockPeer = &mock.Peer{ + SendFn: func(chID byte, data []byte) bool { + require.Equal(t, Channel, chID) + + capturedSend = data + + return true + }, + } + + ps = &mockPeerSet{ + listFn: func() []p2p.PeerConn { + return make([]p2p.PeerConn, 0) + }, + } + + mockSwitch = &mockSwitch{ + peersFn: func() p2p.PeerSet { + return ps + }, + } + ) + + r := NewReactor( + WithDiscoveryInterval(10 * time.Millisecond), + ) + + // Set the mock switch + r.SetSwitch(mockSwitch) + + // Prepare the message + req := &Request{} + + preparedReq, err := amino.MarshalAny(req) + require.NoError(t, err) + + // Receive the message + r.Receive(Channel, mockPeer, preparedReq) + + // Make sure no message was captured + assert.Nil(t, capturedSend) + }) + + 
t.Run("private peers not shared", func(t *testing.T) { + t.Parallel() + + var ( + publicPeers = 1 + privatePeers = 50 + + peers = mock.GeneratePeers(t, publicPeers+privatePeers) + notifCh = make(chan struct{}, 1) + + capturedSend []byte + + mockPeer = &mock.Peer{ + SendFn: func(chID byte, data []byte) bool { + require.Equal(t, Channel, chID) + + capturedSend = data + + notifCh <- struct{}{} + + return true + }, + } + + ps = &mockPeerSet{ + listFn: func() []p2p.PeerConn { + listed := make([]p2p.PeerConn, 0, len(peers)) + + for _, peer := range peers { + listed = append(listed, peer) + } + + return listed + }, + numInboundFn: func() uint64 { + return uint64(len(peers)) + }, + } + + mockSwitch = &mockSwitch{ + peersFn: func() p2p.PeerSet { + return ps + }, + } + ) + + // Mark all except the last X peers as private + for _, peer := range peers[:privatePeers] { + peer.IsPrivateFn = func() bool { + return true + } + } + + r := NewReactor( + WithDiscoveryInterval(10 * time.Millisecond), + ) + + // Set the mock switch + r.SetSwitch(mockSwitch) + + // Prepare the message + req := &Request{} + + preparedReq, err := amino.MarshalAny(req) + require.NoError(t, err) + + // Receive the message + r.Receive(Channel, mockPeer, preparedReq) + + select { + case <-notifCh: + case <-time.After(5 * time.Second): + } + + // Make sure the adequate message was captured + require.NotNil(t, capturedSend) + + // Parse the message + var msg Message + + require.NoError(t, amino.Unmarshal(capturedSend, &msg)) + + // Make sure the base message is valid + require.NoError(t, msg.ValidateBasic()) + + resp, ok := msg.(*Response) + require.True(t, ok) + + // Make sure the peers are valid + require.Len(t, resp.Peers, publicPeers) + + assert.True(t, slices.ContainsFunc(resp.Peers, func(addr *types.NetAddress) bool { + for _, localP := range peers { + if localP.SocketAddr().Equals(*addr) { + return true + } + } + + return false + })) + }) + + t.Run("peer response received", func(t *testing.T) { + t.Parallel() + + var ( + 
peers = mock.GeneratePeers(t, 50) + notifCh = make(chan struct{}, 1) + + capturedDials []*types.NetAddress + + ps = &mockPeerSet{ + listFn: func() []p2p.PeerConn { + listed := make([]p2p.PeerConn, 0, len(peers)) + + for _, peer := range peers { + listed = append(listed, peer) + } + + return listed + }, + numInboundFn: func() uint64 { + return uint64(len(peers)) + }, + } + + mockSwitch = &mockSwitch{ + peersFn: func() p2p.PeerSet { + return ps + }, + dialPeersFn: func(addresses ...*types.NetAddress) { + capturedDials = append(capturedDials, addresses...) + + notifCh <- struct{}{} + }, + } + ) + + r := NewReactor( + WithDiscoveryInterval(10 * time.Millisecond), + ) + + // Set the mock switch + r.SetSwitch(mockSwitch) + + // Prepare the addresses + peerAddrs := make([]*types.NetAddress, 0, len(peers)) + + for _, p := range peers { + peerAddrs = append(peerAddrs, p.SocketAddr()) + } + + // Prepare the message + req := &Response{ + Peers: peerAddrs, + } + + preparedReq, err := amino.MarshalAny(req) + require.NoError(t, err) + + // Receive the message + r.Receive(Channel, &mock.Peer{}, preparedReq) + + select { + case <-notifCh: + case <-time.After(5 * time.Second): + } + + // Make sure the correct peers were dialed + assert.Equal(t, capturedDials, peerAddrs) + }) + + t.Run("invalid peer response received", func(t *testing.T) { + t.Parallel() + + var ( + peers = mock.GeneratePeers(t, 50) + + capturedDials []*types.NetAddress + + ps = &mockPeerSet{ + listFn: func() []p2p.PeerConn { + listed := make([]p2p.PeerConn, 0, len(peers)) + + for _, peer := range peers { + listed = append(listed, peer) + } + + return listed + }, + numInboundFn: func() uint64 { + return uint64(len(peers)) + }, + } + + mockSwitch = &mockSwitch{ + peersFn: func() p2p.PeerSet { + return ps + }, + dialPeersFn: func(addresses ...*types.NetAddress) { + capturedDials = append(capturedDials, addresses...) 
+ }, + } + ) + + r := NewReactor( + WithDiscoveryInterval(10 * time.Millisecond), + ) + + // Set the mock switch + r.SetSwitch(mockSwitch) + + // Prepare the message + req := &Response{ + Peers: make([]*types.NetAddress, 0), // empty + } + + preparedReq, err := amino.MarshalAny(req) + require.NoError(t, err) + + // Receive the message + r.Receive(Channel, &mock.Peer{}, preparedReq) + + // Make sure no peers were dialed + assert.Empty(t, capturedDials) + }) +} diff --git a/tm2/pkg/p2p/discovery/doc.go b/tm2/pkg/p2p/discovery/doc.go new file mode 100644 index 00000000000..5426bb41277 --- /dev/null +++ b/tm2/pkg/p2p/discovery/doc.go @@ -0,0 +1,9 @@ +// Package discovery contains the p2p peer discovery service (Reactor). +// The purpose of the peer discovery service is to gather peer lists from known peers, +// and attempt to fill out open peer connection slots in order to build out a fuller mesh. +// +// The implementation of the peer discovery protocol is relatively simple. +// In essence, it pings a random peer at a specific interval (3s), for a list of their known peers (max 30). +// After receiving the list, and verifying it, the node attempts to establish outbound connections to the +// given peers. 
+package discovery diff --git a/tm2/pkg/p2p/discovery/mock_test.go b/tm2/pkg/p2p/discovery/mock_test.go new file mode 100644 index 00000000000..cd543428f86 --- /dev/null +++ b/tm2/pkg/p2p/discovery/mock_test.go @@ -0,0 +1,135 @@ +package discovery + +import ( + "net" + + "github.com/gnolang/gno/tm2/pkg/p2p" + "github.com/gnolang/gno/tm2/pkg/p2p/events" + "github.com/gnolang/gno/tm2/pkg/p2p/types" +) + +type ( + broadcastDelegate func(byte, []byte) + peersDelegate func() p2p.PeerSet + stopPeerForErrorDelegate func(p2p.PeerConn, error) + dialPeersDelegate func(...*types.NetAddress) + subscribeDelegate func(events.EventFilter) (<-chan events.Event, func()) +) + +type mockSwitch struct { + broadcastFn broadcastDelegate + peersFn peersDelegate + stopPeerForErrorFn stopPeerForErrorDelegate + dialPeersFn dialPeersDelegate + subscribeFn subscribeDelegate +} + +func (m *mockSwitch) Broadcast(chID byte, data []byte) { + if m.broadcastFn != nil { + m.broadcastFn(chID, data) + } +} + +func (m *mockSwitch) Peers() p2p.PeerSet { + if m.peersFn != nil { + return m.peersFn() + } + + return nil +} + +func (m *mockSwitch) StopPeerForError(peer p2p.PeerConn, err error) { + if m.stopPeerForErrorFn != nil { + m.stopPeerForErrorFn(peer, err) + } +} + +func (m *mockSwitch) DialPeers(peerAddrs ...*types.NetAddress) { + if m.dialPeersFn != nil { + m.dialPeersFn(peerAddrs...) 
+ } +} + +func (m *mockSwitch) Subscribe(filter events.EventFilter) (<-chan events.Event, func()) { + if m.subscribeFn != nil { + return m.subscribeFn(filter) + } + + return nil, func() {} +} + +type ( + addDelegate func(p2p.PeerConn) + removeDelegate func(types.ID) bool + hasDelegate func(types.ID) bool + hasIPDelegate func(net.IP) bool + getPeerDelegate func(types.ID) p2p.PeerConn + listDelegate func() []p2p.PeerConn + numInboundDelegate func() uint64 + numOutboundDelegate func() uint64 +) + +type mockPeerSet struct { + addFn addDelegate + removeFn removeDelegate + hasFn hasDelegate + hasIPFn hasIPDelegate + getFn getPeerDelegate + listFn listDelegate + numInboundFn numInboundDelegate + numOutboundFn numOutboundDelegate +} + +func (m *mockPeerSet) Add(peer p2p.PeerConn) { + if m.addFn != nil { + m.addFn(peer) + } +} + +func (m *mockPeerSet) Remove(key types.ID) bool { + if m.removeFn != nil { + return m.removeFn(key) + } + + return false +} + +func (m *mockPeerSet) Has(key types.ID) bool { + if m.hasFn != nil { + return m.hasFn(key) + } + + return false +} + +func (m *mockPeerSet) Get(key types.ID) p2p.PeerConn { + if m.getFn != nil { + return m.getFn(key) + } + + return nil +} + +func (m *mockPeerSet) List() []p2p.PeerConn { + if m.listFn != nil { + return m.listFn() + } + + return nil +} + +func (m *mockPeerSet) NumInbound() uint64 { + if m.numInboundFn != nil { + return m.numInboundFn() + } + + return 0 +} + +func (m *mockPeerSet) NumOutbound() uint64 { + if m.numOutboundFn != nil { + return m.numOutboundFn() + } + + return 0 +} diff --git a/tm2/pkg/p2p/discovery/option.go b/tm2/pkg/p2p/discovery/option.go new file mode 100644 index 00000000000..dc0fb95b109 --- /dev/null +++ b/tm2/pkg/p2p/discovery/option.go @@ -0,0 +1,12 @@ +package discovery + +import "time" + +type Option func(*Reactor) + +// WithDiscoveryInterval sets the discovery crawl interval +func WithDiscoveryInterval(interval time.Duration) Option { + return func(r *Reactor) { + r.discoveryInterval = interval + }
+} diff --git a/tm2/pkg/p2p/discovery/package.go b/tm2/pkg/p2p/discovery/package.go new file mode 100644 index 00000000000..a3865fdf5d2 --- /dev/null +++ b/tm2/pkg/p2p/discovery/package.go @@ -0,0 +1,16 @@ +package discovery + +import ( + "github.com/gnolang/gno/tm2/pkg/amino" +) + +var Package = amino.RegisterPackage(amino.NewPackage( + "github.com/gnolang/gno/tm2/pkg/p2p/discovery", + "p2p", + amino.GetCallersDirname(), +). + WithTypes( + &Request{}, + &Response{}, + ), +) diff --git a/tm2/pkg/p2p/discovery/types.go b/tm2/pkg/p2p/discovery/types.go new file mode 100644 index 00000000000..87ea936ebb5 --- /dev/null +++ b/tm2/pkg/p2p/discovery/types.go @@ -0,0 +1,44 @@ +package discovery + +import ( + "github.com/gnolang/gno/tm2/pkg/errors" + "github.com/gnolang/gno/tm2/pkg/p2p/types" +) + +var errNoPeers = errors.New("no peers received") + +// Message is the wrapper for the discovery message +type Message interface { + ValidateBasic() error +} + +// Request is the peer discovery request. 
+// It is empty by design, since it's used as +// a notification type +type Request struct{} + +func (r *Request) ValidateBasic() error { + return nil +} + +// Response is the peer discovery response +type Response struct { + Peers []*types.NetAddress // the peer set returned by the peer +} + +func (r *Response) ValidateBasic() error { + // Make sure at least some peers were received + if len(r.Peers) == 0 { + return errNoPeers + } + + // Make sure the returned peer dial + // addresses are valid + for _, peer := range r.Peers { + if err := peer.Validate(); err != nil { + return err + } + } + + return nil +} diff --git a/tm2/pkg/p2p/discovery/types_test.go b/tm2/pkg/p2p/discovery/types_test.go new file mode 100644 index 00000000000..0ac2c16f4e5 --- /dev/null +++ b/tm2/pkg/p2p/discovery/types_test.go @@ -0,0 +1,80 @@ +package discovery + +import ( + "net" + "testing" + + "github.com/gnolang/gno/tm2/pkg/p2p/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// generateNetAddrs generates random net addresses +func generateNetAddrs(t *testing.T, count int) []*types.NetAddress { + t.Helper() + + addrs := make([]*types.NetAddress, count) + + for i := 0; i < count; i++ { + var ( + key = types.GenerateNodeKey() + address = "127.0.0.1:8080" + ) + + tcpAddr, err := net.ResolveTCPAddr("tcp", address) + require.NoError(t, err) + + addr, err := types.NewNetAddress(key.ID(), tcpAddr) + require.NoError(t, err) + + addrs[i] = addr + } + + return addrs +} + +func TestRequest_ValidateBasic(t *testing.T) { + t.Parallel() + + r := &Request{} + + assert.NoError(t, r.ValidateBasic()) +} + +func TestResponse_ValidateBasic(t *testing.T) { + t.Parallel() + + t.Run("empty peer set", func(t *testing.T) { + t.Parallel() + + r := &Response{ + Peers: make([]*types.NetAddress, 0), + } + + assert.ErrorIs(t, r.ValidateBasic(), errNoPeers) + }) + + t.Run("invalid peer dial address", func(t *testing.T) { + t.Parallel() + + r := &Response{ + Peers: 
[]*types.NetAddress{ + { + ID: "", // invalid ID + }, + }, + } + + assert.Error(t, r.ValidateBasic()) + }) + + t.Run("valid peer set", func(t *testing.T) { + t.Parallel() + + r := &Response{ + Peers: generateNetAddrs(t, 10), + } + + assert.NoError(t, r.ValidateBasic()) + }) +} diff --git a/tm2/pkg/p2p/errors.go b/tm2/pkg/p2p/errors.go deleted file mode 100644 index d4ad58e8ab5..00000000000 --- a/tm2/pkg/p2p/errors.go +++ /dev/null @@ -1,184 +0,0 @@ -package p2p - -import ( - "fmt" - "net" -) - -// FilterTimeoutError indicates that a filter operation timed out. -type FilterTimeoutError struct{} - -func (e FilterTimeoutError) Error() string { - return "filter timed out" -} - -// RejectedError indicates that a Peer was rejected carrying additional -// information as to the reason. -type RejectedError struct { - addr NetAddress - conn net.Conn - err error - id ID - isAuthFailure bool - isDuplicate bool - isFiltered bool - isIncompatible bool - isNodeInfoInvalid bool - isSelf bool -} - -// Addr returns the NetAddress for the rejected Peer. -func (e RejectedError) Addr() NetAddress { - return e.addr -} - -func (e RejectedError) Error() string { - if e.isAuthFailure { - return fmt.Sprintf("auth failure: %s", e.err) - } - - if e.isDuplicate { - if e.conn != nil { - return fmt.Sprintf( - "duplicate CONN<%s>", - e.conn.RemoteAddr().String(), - ) - } - if !e.id.IsZero() { - return fmt.Sprintf("duplicate ID<%v>", e.id) - } - } - - if e.isFiltered { - if e.conn != nil { - return fmt.Sprintf( - "filtered CONN<%s>: %s", - e.conn.RemoteAddr().String(), - e.err, - ) - } - - if !e.id.IsZero() { - return fmt.Sprintf("filtered ID<%v>: %s", e.id, e.err) - } - } - - if e.isIncompatible { - return fmt.Sprintf("incompatible: %s", e.err) - } - - if e.isNodeInfoInvalid { - return fmt.Sprintf("invalid NodeInfo: %s", e.err) - } - - if e.isSelf { - return fmt.Sprintf("self ID<%v>", e.id) - } - - return fmt.Sprintf("%s", e.err) -} - -// IsAuthFailure when Peer authentication was unsuccessful. 
-func (e RejectedError) IsAuthFailure() bool { return e.isAuthFailure } - -// IsDuplicate when Peer ID or IP are present already. -func (e RejectedError) IsDuplicate() bool { return e.isDuplicate } - -// IsFiltered when Peer ID or IP was filtered. -func (e RejectedError) IsFiltered() bool { return e.isFiltered } - -// IsIncompatible when Peer NodeInfo is not compatible with our own. -func (e RejectedError) IsIncompatible() bool { return e.isIncompatible } - -// IsNodeInfoInvalid when the sent NodeInfo is not valid. -func (e RejectedError) IsNodeInfoInvalid() bool { return e.isNodeInfoInvalid } - -// IsSelf when Peer is our own node. -func (e RejectedError) IsSelf() bool { return e.isSelf } - -// SwitchDuplicatePeerIDError to be raised when a peer is connecting with a known -// ID. -type SwitchDuplicatePeerIDError struct { - ID ID -} - -func (e SwitchDuplicatePeerIDError) Error() string { - return fmt.Sprintf("duplicate peer ID %v", e.ID) -} - -// SwitchDuplicatePeerIPError to be raised when a peer is connecting with a known -// IP. -type SwitchDuplicatePeerIPError struct { - IP net.IP -} - -func (e SwitchDuplicatePeerIPError) Error() string { - return fmt.Sprintf("duplicate peer IP %v", e.IP.String()) -} - -// SwitchConnectToSelfError to be raised when trying to connect to itself. -type SwitchConnectToSelfError struct { - Addr *NetAddress -} - -func (e SwitchConnectToSelfError) Error() string { - return fmt.Sprintf("connect to self: %v", e.Addr) -} - -type SwitchAuthenticationFailureError struct { - Dialed *NetAddress - Got ID -} - -func (e SwitchAuthenticationFailureError) Error() string { - return fmt.Sprintf( - "failed to authenticate peer. Dialed %v, but got peer with ID %s", - e.Dialed, - e.Got, - ) -} - -// TransportClosedError is raised when the Transport has been closed. 
-type TransportClosedError struct{} - -func (e TransportClosedError) Error() string { - return "transport has been closed" -} - -// ------------------------------------------------------------------- - -type NetAddressNoIDError struct { - Addr string -} - -func (e NetAddressNoIDError) Error() string { - return fmt.Sprintf("address (%s) does not contain ID", e.Addr) -} - -type NetAddressInvalidError struct { - Addr string - Err error -} - -func (e NetAddressInvalidError) Error() string { - return fmt.Sprintf("invalid address (%s): %v", e.Addr, e.Err) -} - -type NetAddressLookupError struct { - Addr string - Err error -} - -func (e NetAddressLookupError) Error() string { - return fmt.Sprintf("error looking up host (%s): %v", e.Addr, e.Err) -} - -// CurrentlyDialingOrExistingAddressError indicates that we're currently -// dialing this address or it belongs to an existing peer. -type CurrentlyDialingOrExistingAddressError struct { - Addr string -} - -func (e CurrentlyDialingOrExistingAddressError) Error() string { - return fmt.Sprintf("connection with %s has been established or dialed", e.Addr) -} diff --git a/tm2/pkg/p2p/events/doc.go b/tm2/pkg/p2p/events/doc.go new file mode 100644 index 00000000000..a624102379e --- /dev/null +++ b/tm2/pkg/p2p/events/doc.go @@ -0,0 +1,3 @@ +// Package events contains a simple p2p event system implementation, that simplifies asynchronous event flows in the +// p2p module. The event subscriptions allow for event filtering, which eases the load on the event notification flow. +package events diff --git a/tm2/pkg/p2p/events/events.go b/tm2/pkg/p2p/events/events.go new file mode 100644 index 00000000000..1eb4699fb45 --- /dev/null +++ b/tm2/pkg/p2p/events/events.go @@ -0,0 +1,112 @@ +package events + +import ( + "sync" + + "github.com/rs/xid" +) + +// EventFilter is the filter function used to +// filter incoming p2p events. 
A false flag will +// consider the event as irrelevant +type EventFilter func(Event) bool + +// Events is the p2p event switch +type Events struct { + subs subscriptions + subscriptionsMux sync.RWMutex +} + +// New creates a new event subscription manager +func New() *Events { + return &Events{ + subs: make(subscriptions), + } +} + +// Subscribe registers a new filtered event listener +func (es *Events) Subscribe(filterFn EventFilter) (<-chan Event, func()) { + es.subscriptionsMux.Lock() + defer es.subscriptionsMux.Unlock() + + // Create a new subscription + id, ch := es.subs.add(filterFn) + + // Create the unsubscribe callback + unsubscribeFn := func() { + es.subscriptionsMux.Lock() + defer es.subscriptionsMux.Unlock() + + es.subs.remove(id) + } + + return ch, unsubscribeFn +} + +// Notify notifies all subscribers of an incoming event [BLOCKING] +func (es *Events) Notify(event Event) { + es.subscriptionsMux.RLock() + defer es.subscriptionsMux.RUnlock() + + es.subs.notify(event) +} + +type ( + // subscriptions holds the corresponding subscription information + subscriptions map[string]subscription // subscription ID -> subscription + + // subscription wraps the subscription notification channel, + // and the event filter + subscription struct { + ch chan Event + filterFn EventFilter + } +) + +// add adds a new subscription to the subscription map. +// Returns the subscription ID, and update channel +func (s *subscriptions) add(filterFn EventFilter) (string, chan Event) { + var ( + id = xid.New().String() + // Since the event stream is non-blocking, + // the event buffer should be sufficiently + // large for most use-cases. 
Subscribers can + // handle large event load caller-side to mitigate + // events potentially being missed + ch = make(chan Event, 100) + ) + + (*s)[id] = subscription{ + ch: ch, + filterFn: filterFn, + } + + return id, ch +} + +// remove removes the given subscription +func (s *subscriptions) remove(id string) { + if sub, exists := (*s)[id]; exists { + // Close the notification channel + close(sub.ch) + } + + // Delete the subscription + delete(*s, id) +} + +// notify notifies all subscription listeners, +// if their filters pass +func (s *subscriptions) notify(event Event) { + // Notify the listeners + for _, sub := range *s { + if !sub.filterFn(event) { + continue + } + + select { + case sub.ch <- event: + default: // non-blocking + } + } +} diff --git a/tm2/pkg/p2p/events/events_test.go b/tm2/pkg/p2p/events/events_test.go new file mode 100644 index 00000000000..a0feafceddb --- /dev/null +++ b/tm2/pkg/p2p/events/events_test.go @@ -0,0 +1,94 @@ +package events + +import ( + "fmt" + "sync" + "testing" + "time" + + "github.com/gnolang/gno/tm2/pkg/p2p/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// generateEvents generates p2p events +func generateEvents(count int) []Event { + events := make([]Event, 0, count) + + for i := range count { + var event Event + + if i%2 == 0 { + event = PeerConnectedEvent{ + PeerID: types.ID(fmt.Sprintf("peer-%d", i)), + } + } else { + event = PeerDisconnectedEvent{ + PeerID: types.ID(fmt.Sprintf("peer-%d", i)), + } + } + + events = append(events, event) + } + + return events +} + +func TestEvents_Subscribe(t *testing.T) { + t.Parallel() + + var ( + capturedEvents []Event + + events = generateEvents(10) + subFn = func(e Event) bool { + return e.Type() == PeerDisconnected + } + ) + + // Create the events manager + e := New() + + // Subscribe to events + ch, unsubFn := e.Subscribe(subFn) + defer unsubFn() + + // Listen for the events + var wg sync.WaitGroup + + wg.Add(1) + + go func() { + defer 
wg.Done() + + timeout := time.After(5 * time.Second) + + for { + select { + case ev := <-ch: + capturedEvents = append(capturedEvents, ev) + + if len(capturedEvents) == len(events)/2 { + return + } + case <-timeout: + return + } + } + }() + + // Send out the events + for _, ev := range events { + e.Notify(ev) + } + + wg.Wait() + + // Make sure the events were captured + // and filtered properly + require.Len(t, capturedEvents, len(events)/2) + + for _, ev := range capturedEvents { + assert.Equal(t, ev.Type(), PeerDisconnected) + } +} diff --git a/tm2/pkg/p2p/events/types.go b/tm2/pkg/p2p/events/types.go new file mode 100644 index 00000000000..cbaac1816ff --- /dev/null +++ b/tm2/pkg/p2p/events/types.go @@ -0,0 +1,39 @@ +package events + +import ( + "net" + + "github.com/gnolang/gno/tm2/pkg/p2p/types" +) + +type EventType string + +const ( + PeerConnected EventType = "PeerConnected" // emitted when a fresh peer connects + PeerDisconnected EventType = "PeerDisconnected" // emitted when a peer disconnects +) + +// Event is a generic p2p event +type Event interface { + // Type returns the type information for the event + Type() EventType +} + +type PeerConnectedEvent struct { + PeerID types.ID // the ID of the peer + Address net.Addr // the remote address of the peer +} + +func (p PeerConnectedEvent) Type() EventType { + return PeerConnected +} + +type PeerDisconnectedEvent struct { + PeerID types.ID // the ID of the peer + Address net.Addr // the remote address of the peer + Reason error // the disconnect reason, if any +} + +func (p PeerDisconnectedEvent) Type() EventType { + return PeerDisconnected +} diff --git a/tm2/pkg/p2p/fuzz.go b/tm2/pkg/p2p/fuzz.go deleted file mode 100644 index 03cf88cf750..00000000000 --- a/tm2/pkg/p2p/fuzz.go +++ /dev/null @@ -1,131 +0,0 @@ -package p2p - -import ( - "net" - "sync" - "time" - - "github.com/gnolang/gno/tm2/pkg/p2p/config" - "github.com/gnolang/gno/tm2/pkg/random" -) - -// FuzzedConnection wraps any net.Conn and depending on 
the mode either delays -// reads/writes or randomly drops reads/writes/connections. -type FuzzedConnection struct { - conn net.Conn - - mtx sync.Mutex - start <-chan time.Time - active bool - - config *config.FuzzConnConfig -} - -// FuzzConnAfterFromConfig creates a new FuzzedConnection from a config. -// Fuzzing starts when the duration elapses. -func FuzzConnAfterFromConfig( - conn net.Conn, - d time.Duration, - config *config.FuzzConnConfig, -) net.Conn { - return &FuzzedConnection{ - conn: conn, - start: time.After(d), - active: false, - config: config, - } -} - -// Config returns the connection's config. -func (fc *FuzzedConnection) Config() *config.FuzzConnConfig { - return fc.config -} - -// Read implements net.Conn. -func (fc *FuzzedConnection) Read(data []byte) (n int, err error) { - if fc.fuzz() { - return 0, nil - } - return fc.conn.Read(data) -} - -// Write implements net.Conn. -func (fc *FuzzedConnection) Write(data []byte) (n int, err error) { - if fc.fuzz() { - return 0, nil - } - return fc.conn.Write(data) -} - -// Close implements net.Conn. -func (fc *FuzzedConnection) Close() error { return fc.conn.Close() } - -// LocalAddr implements net.Conn. -func (fc *FuzzedConnection) LocalAddr() net.Addr { return fc.conn.LocalAddr() } - -// RemoteAddr implements net.Conn. -func (fc *FuzzedConnection) RemoteAddr() net.Addr { return fc.conn.RemoteAddr() } - -// SetDeadline implements net.Conn. -func (fc *FuzzedConnection) SetDeadline(t time.Time) error { return fc.conn.SetDeadline(t) } - -// SetReadDeadline implements net.Conn. -func (fc *FuzzedConnection) SetReadDeadline(t time.Time) error { - return fc.conn.SetReadDeadline(t) -} - -// SetWriteDeadline implements net.Conn. 
-func (fc *FuzzedConnection) SetWriteDeadline(t time.Time) error { - return fc.conn.SetWriteDeadline(t) -} - -func (fc *FuzzedConnection) randomDuration() time.Duration { - maxDelayMillis := int(fc.config.MaxDelay.Nanoseconds() / 1000) - return time.Millisecond * time.Duration(random.RandInt()%maxDelayMillis) //nolint: gas -} - -// implements the fuzz (delay, kill conn) -// and returns whether or not the read/write should be ignored -func (fc *FuzzedConnection) fuzz() bool { - if !fc.shouldFuzz() { - return false - } - - switch fc.config.Mode { - case config.FuzzModeDrop: - // randomly drop the r/w, drop the conn, or sleep - r := random.RandFloat64() - switch { - case r <= fc.config.ProbDropRW: - return true - case r < fc.config.ProbDropRW+fc.config.ProbDropConn: - // XXX: can't this fail because machine precision? - // XXX: do we need an error? - fc.Close() //nolint: errcheck, gas - return true - case r < fc.config.ProbDropRW+fc.config.ProbDropConn+fc.config.ProbSleep: - time.Sleep(fc.randomDuration()) - } - case config.FuzzModeDelay: - // sleep a bit - time.Sleep(fc.randomDuration()) - } - return false -} - -func (fc *FuzzedConnection) shouldFuzz() bool { - if fc.active { - return true - } - - fc.mtx.Lock() - defer fc.mtx.Unlock() - - select { - case <-fc.start: - fc.active = true - return true - default: - return false - } -} diff --git a/tm2/pkg/p2p/key.go b/tm2/pkg/p2p/key.go deleted file mode 100644 index a41edeb07f8..00000000000 --- a/tm2/pkg/p2p/key.go +++ /dev/null @@ -1,94 +0,0 @@ -package p2p - -import ( - "bytes" - "fmt" - "os" - - "github.com/gnolang/gno/tm2/pkg/amino" - "github.com/gnolang/gno/tm2/pkg/crypto" - "github.com/gnolang/gno/tm2/pkg/crypto/ed25519" - osm "github.com/gnolang/gno/tm2/pkg/os" -) - -// ------------------------------------------------------------------------------ -// Persistent peer ID -// TODO: encrypt on disk - -// NodeKey is the persistent peer key. -// It contains the nodes private key for authentication. 
-// NOTE: keep in sync with gno.land/cmd/gnoland/secrets.go -type NodeKey struct { - crypto.PrivKey `json:"priv_key"` // our priv key -} - -func (nk NodeKey) ID() ID { - return nk.PubKey().Address().ID() -} - -// LoadOrGenNodeKey attempts to load the NodeKey from the given filePath. -// If the file does not exist, it generates and saves a new NodeKey. -func LoadOrGenNodeKey(filePath string) (*NodeKey, error) { - if osm.FileExists(filePath) { - nodeKey, err := LoadNodeKey(filePath) - if err != nil { - return nil, err - } - return nodeKey, nil - } - return genNodeKey(filePath) -} - -func LoadNodeKey(filePath string) (*NodeKey, error) { - jsonBytes, err := os.ReadFile(filePath) - if err != nil { - return nil, err - } - nodeKey := new(NodeKey) - err = amino.UnmarshalJSON(jsonBytes, nodeKey) - if err != nil { - return nil, fmt.Errorf("Error reading NodeKey from %v: %w", filePath, err) - } - return nodeKey, nil -} - -func genNodeKey(filePath string) (*NodeKey, error) { - privKey := ed25519.GenPrivKey() - nodeKey := &NodeKey{ - PrivKey: privKey, - } - - jsonBytes, err := amino.MarshalJSON(nodeKey) - if err != nil { - return nil, err - } - err = os.WriteFile(filePath, jsonBytes, 0o600) - if err != nil { - return nil, err - } - return nodeKey, nil -} - -// ------------------------------------------------------------------------------ - -// MakePoWTarget returns the big-endian encoding of 2^(targetBits - difficulty) - 1. -// It can be used as a Proof of Work target. -// NOTE: targetBits must be a multiple of 8 and difficulty must be less than targetBits. 
-func MakePoWTarget(difficulty, targetBits uint) []byte { - if targetBits%8 != 0 { - panic(fmt.Sprintf("targetBits (%d) not a multiple of 8", targetBits)) - } - if difficulty >= targetBits { - panic(fmt.Sprintf("difficulty (%d) >= targetBits (%d)", difficulty, targetBits)) - } - targetBytes := targetBits / 8 - zeroPrefixLen := (int(difficulty) / 8) - prefix := bytes.Repeat([]byte{0}, zeroPrefixLen) - mod := (difficulty % 8) - if mod > 0 { - nonZeroPrefix := byte(1<<(8-mod) - 1) - prefix = append(prefix, nonZeroPrefix) - } - tailLen := int(targetBytes) - len(prefix) - return append(prefix, bytes.Repeat([]byte{0xFF}, tailLen)...) -} diff --git a/tm2/pkg/p2p/key_test.go b/tm2/pkg/p2p/key_test.go deleted file mode 100644 index 4f67cc0a5da..00000000000 --- a/tm2/pkg/p2p/key_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package p2p - -import ( - "bytes" - "os" - "path/filepath" - "testing" - - "github.com/gnolang/gno/tm2/pkg/random" - "github.com/stretchr/testify/assert" -) - -func TestLoadOrGenNodeKey(t *testing.T) { - t.Parallel() - - filePath := filepath.Join(os.TempDir(), random.RandStr(12)+"_peer_id.json") - - nodeKey, err := LoadOrGenNodeKey(filePath) - assert.Nil(t, err) - - nodeKey2, err := LoadOrGenNodeKey(filePath) - assert.Nil(t, err) - - assert.Equal(t, nodeKey, nodeKey2) -} - -// ---------------------------------------------------------- - -func padBytes(bz []byte, targetBytes int) []byte { - return append(bz, bytes.Repeat([]byte{0xFF}, targetBytes-len(bz))...) 
-} - -func TestPoWTarget(t *testing.T) { - t.Parallel() - - targetBytes := 20 - cases := []struct { - difficulty uint - target []byte - }{ - {0, padBytes([]byte{}, targetBytes)}, - {1, padBytes([]byte{127}, targetBytes)}, - {8, padBytes([]byte{0}, targetBytes)}, - {9, padBytes([]byte{0, 127}, targetBytes)}, - {10, padBytes([]byte{0, 63}, targetBytes)}, - {16, padBytes([]byte{0, 0}, targetBytes)}, - {17, padBytes([]byte{0, 0, 127}, targetBytes)}, - } - - for _, c := range cases { - assert.Equal(t, MakePoWTarget(c.difficulty, 20*8), c.target) - } -} diff --git a/tm2/pkg/p2p/mock/peer.go b/tm2/pkg/p2p/mock/peer.go index 906c168c3a8..e5a01952831 100644 --- a/tm2/pkg/p2p/mock/peer.go +++ b/tm2/pkg/p2p/mock/peer.go @@ -1,68 +1,377 @@ package mock import ( + "fmt" + "io" + "log/slog" "net" + "testing" + "time" - "github.com/gnolang/gno/tm2/pkg/crypto/ed25519" - "github.com/gnolang/gno/tm2/pkg/p2p" "github.com/gnolang/gno/tm2/pkg/p2p/conn" + "github.com/gnolang/gno/tm2/pkg/p2p/types" "github.com/gnolang/gno/tm2/pkg/service" + "github.com/stretchr/testify/require" ) +type ( + flushStopDelegate func() + idDelegate func() types.ID + remoteIPDelegate func() net.IP + remoteAddrDelegate func() net.Addr + isOutboundDelegate func() bool + isPersistentDelegate func() bool + isPrivateDelegate func() bool + closeConnDelegate func() error + nodeInfoDelegate func() types.NodeInfo + statusDelegate func() conn.ConnectionStatus + socketAddrDelegate func() *types.NetAddress + sendDelegate func(byte, []byte) bool + trySendDelegate func(byte, []byte) bool + setDelegate func(string, any) + getDelegate func(string) any + stopDelegate func() error +) + +// GeneratePeers generates random peers +func GeneratePeers(t *testing.T, count int) []*Peer { + t.Helper() + + peers := make([]*Peer, count) + + for i := range count { + var ( + key = types.GenerateNodeKey() + address = "127.0.0.1:8080" + ) + + tcpAddr, err := net.ResolveTCPAddr("tcp", address) + require.NoError(t, err) + + addr, err := 
types.NewNetAddress(key.ID(), tcpAddr) + require.NoError(t, err) + + p := &Peer{ + IDFn: func() types.ID { + return key.ID() + }, + NodeInfoFn: func() types.NodeInfo { + return types.NodeInfo{ + PeerID: key.ID(), + } + }, + SocketAddrFn: func() *types.NetAddress { + return addr + }, + } + + p.BaseService = *service.NewBaseService( + slog.New(slog.NewTextHandler(io.Discard, nil)), + fmt.Sprintf("peer-%d", i), + p, + ) + + peers[i] = p + } + + return peers +} + type Peer struct { - *service.BaseService - ip net.IP - id p2p.ID - addr *p2p.NetAddress - kv map[string]interface{} - Outbound, Persistent bool -} - -// NewPeer creates and starts a new mock peer. If the ip -// is nil, random routable address is used. -func NewPeer(ip net.IP) *Peer { - var netAddr *p2p.NetAddress - if ip == nil { - _, netAddr = p2p.CreateRoutableAddr() - } else { - netAddr = p2p.NewNetAddressFromIPPort("", ip, 26656) - } - nodeKey := p2p.NodeKey{PrivKey: ed25519.GenPrivKey()} - netAddr.ID = nodeKey.ID() - mp := &Peer{ - ip: ip, - id: nodeKey.ID(), - addr: netAddr, - kv: make(map[string]interface{}), - } - mp.BaseService = service.NewBaseService(nil, "MockPeer", mp) - mp.Start() - return mp -} - -func (mp *Peer) FlushStop() { mp.Stop() } -func (mp *Peer) TrySend(chID byte, msgBytes []byte) bool { return true } -func (mp *Peer) Send(chID byte, msgBytes []byte) bool { return true } -func (mp *Peer) NodeInfo() p2p.NodeInfo { - return p2p.NodeInfo{ - NetAddress: mp.addr, - } -} -func (mp *Peer) Status() conn.ConnectionStatus { return conn.ConnectionStatus{} } -func (mp *Peer) ID() p2p.ID { return mp.id } -func (mp *Peer) IsOutbound() bool { return mp.Outbound } -func (mp *Peer) IsPersistent() bool { return mp.Persistent } -func (mp *Peer) Get(key string) interface{} { - if value, ok := mp.kv[key]; ok { - return value + service.BaseService + + FlushStopFn flushStopDelegate + IDFn idDelegate + RemoteIPFn remoteIPDelegate + RemoteAddrFn remoteAddrDelegate + IsOutboundFn isOutboundDelegate + 
IsPersistentFn isPersistentDelegate + IsPrivateFn isPrivateDelegate + CloseConnFn closeConnDelegate + NodeInfoFn nodeInfoDelegate + StopFn stopDelegate + StatusFn statusDelegate + SocketAddrFn socketAddrDelegate + SendFn sendDelegate + TrySendFn trySendDelegate + SetFn setDelegate + GetFn getDelegate +} + +func (m *Peer) FlushStop() { + if m.FlushStopFn != nil { + m.FlushStopFn() + } +} + +func (m *Peer) ID() types.ID { + if m.IDFn != nil { + return m.IDFn() + } + + return "" +} + +func (m *Peer) RemoteIP() net.IP { + if m.RemoteIPFn != nil { + return m.RemoteIPFn() + } + + return nil +} + +func (m *Peer) RemoteAddr() net.Addr { + if m.RemoteAddrFn != nil { + return m.RemoteAddrFn() + } + + return nil +} + +func (m *Peer) Stop() error { + if m.StopFn != nil { + return m.StopFn() + } + + return nil +} + +func (m *Peer) IsOutbound() bool { + if m.IsOutboundFn != nil { + return m.IsOutboundFn() + } + + return false +} + +func (m *Peer) IsPersistent() bool { + if m.IsPersistentFn != nil { + return m.IsPersistentFn() + } + + return false +} + +func (m *Peer) IsPrivate() bool { + if m.IsPrivateFn != nil { + return m.IsPrivateFn() + } + + return false +} + +func (m *Peer) CloseConn() error { + if m.CloseConnFn != nil { + return m.CloseConnFn() + } + + return nil +} + +func (m *Peer) NodeInfo() types.NodeInfo { + if m.NodeInfoFn != nil { + return m.NodeInfoFn() + } + + return types.NodeInfo{} +} + +func (m *Peer) Status() conn.ConnectionStatus { + if m.StatusFn != nil { + return m.StatusFn() + } + + return conn.ConnectionStatus{} +} + +func (m *Peer) SocketAddr() *types.NetAddress { + if m.SocketAddrFn != nil { + return m.SocketAddrFn() + } + + return nil +} + +func (m *Peer) Send(classifier byte, data []byte) bool { + if m.SendFn != nil { + return m.SendFn(classifier, data) + } + + return false +} + +func (m *Peer) TrySend(classifier byte, data []byte) bool { + if m.TrySendFn != nil { + return m.TrySendFn(classifier, data) + } + + return false +} + +func (m *Peer) Set(key 
string, data any) { + if m.SetFn != nil { + m.SetFn(key, data) + } +} + +func (m *Peer) Get(key string) any { + if m.GetFn != nil { + return m.GetFn(key) + } + + return nil +} + +type ( + readDelegate func([]byte) (int, error) + writeDelegate func([]byte) (int, error) + closeDelegate func() error + localAddrDelegate func() net.Addr + setDeadlineDelegate func(time.Time) error +) + +type Conn struct { + ReadFn readDelegate + WriteFn writeDelegate + CloseFn closeDelegate + LocalAddrFn localAddrDelegate + RemoteAddrFn remoteAddrDelegate + SetDeadlineFn setDeadlineDelegate + SetReadDeadlineFn setDeadlineDelegate + SetWriteDeadlineFn setDeadlineDelegate +} + +func (m *Conn) Read(b []byte) (int, error) { + if m.ReadFn != nil { + return m.ReadFn(b) + } + + return 0, nil +} + +func (m *Conn) Write(b []byte) (int, error) { + if m.WriteFn != nil { + return m.WriteFn(b) } + + return 0, nil +} + +func (m *Conn) Close() error { + if m.CloseFn != nil { + return m.CloseFn() + } + return nil } -func (mp *Peer) Set(key string, value interface{}) { - mp.kv[key] = value +func (m *Conn) LocalAddr() net.Addr { + if m.LocalAddrFn != nil { + return m.LocalAddrFn() + } + + return nil +} + +func (m *Conn) RemoteAddr() net.Addr { + if m.RemoteAddrFn != nil { + return m.RemoteAddrFn() + } + + return nil +} + +func (m *Conn) SetDeadline(t time.Time) error { + if m.SetDeadlineFn != nil { + return m.SetDeadlineFn(t) + } + + return nil +} + +func (m *Conn) SetReadDeadline(t time.Time) error { + if m.SetReadDeadlineFn != nil { + return m.SetReadDeadlineFn(t) + } + + return nil +} + +func (m *Conn) SetWriteDeadline(t time.Time) error { + if m.SetWriteDeadlineFn != nil { + return m.SetWriteDeadlineFn(t) + } + + return nil +} + +type ( + startDelegate func() error + stringDelegate func() string +) + +type MConn struct { + FlushFn flushStopDelegate + StartFn startDelegate + StopFn stopDelegate + SendFn sendDelegate + TrySendFn trySendDelegate + StatusFn statusDelegate + StringFn stringDelegate +} + 
+func (m *MConn) FlushStop() { + if m.FlushFn != nil { + m.FlushFn() + } +} + +func (m *MConn) Start() error { + if m.StartFn != nil { + return m.StartFn() + } + + return nil +} + +func (m *MConn) Stop() error { + if m.StopFn != nil { + return m.StopFn() + } + + return nil +} + +func (m *MConn) Send(ch byte, data []byte) bool { + if m.SendFn != nil { + return m.SendFn(ch, data) + } + + return false +} + +func (m *MConn) TrySend(ch byte, data []byte) bool { + if m.TrySendFn != nil { + return m.TrySendFn(ch, data) + } + + return false +} + +func (m *MConn) SetLogger(_ *slog.Logger) {} + +func (m *MConn) Status() conn.ConnectionStatus { + if m.StatusFn != nil { + return m.StatusFn() + } + + return conn.ConnectionStatus{} +} + +func (m *MConn) String() string { + if m.StringFn != nil { + return m.StringFn() + } + + return "" } -func (mp *Peer) RemoteIP() net.IP { return mp.ip } -func (mp *Peer) SocketAddr() *p2p.NetAddress { return mp.addr } -func (mp *Peer) RemoteAddr() net.Addr { return &net.TCPAddr{IP: mp.ip, Port: 8800} } -func (mp *Peer) CloseConn() error { return nil } diff --git a/tm2/pkg/p2p/mock/reactor.go b/tm2/pkg/p2p/mock/reactor.go deleted file mode 100644 index fe123fdc0b2..00000000000 --- a/tm2/pkg/p2p/mock/reactor.go +++ /dev/null @@ -1,23 +0,0 @@ -package mock - -import ( - "github.com/gnolang/gno/tm2/pkg/log" - "github.com/gnolang/gno/tm2/pkg/p2p" - "github.com/gnolang/gno/tm2/pkg/p2p/conn" -) - -type Reactor struct { - p2p.BaseReactor -} - -func NewReactor() *Reactor { - r := &Reactor{} - r.BaseReactor = *p2p.NewBaseReactor("Reactor", r) - r.SetLogger(log.NewNoopLogger()) - return r -} - -func (r *Reactor) GetChannels() []*conn.ChannelDescriptor { return []*conn.ChannelDescriptor{} } -func (r *Reactor) AddPeer(peer p2p.Peer) {} -func (r *Reactor) RemovePeer(peer p2p.Peer, reason interface{}) {} -func (r *Reactor) Receive(chID byte, peer p2p.Peer, msgBytes []byte) {} diff --git a/tm2/pkg/p2p/mock_test.go b/tm2/pkg/p2p/mock_test.go new file mode 100644 
index 00000000000..5fd7f947b71 --- /dev/null +++ b/tm2/pkg/p2p/mock_test.go @@ -0,0 +1,404 @@ +package p2p + +import ( + "context" + "log/slog" + "net" + "time" + + "github.com/gnolang/gno/tm2/pkg/p2p/conn" + "github.com/gnolang/gno/tm2/pkg/p2p/types" +) + +type ( + netAddressDelegate func() types.NetAddress + acceptDelegate func(context.Context, PeerBehavior) (PeerConn, error) + dialDelegate func(context.Context, types.NetAddress, PeerBehavior) (PeerConn, error) + removeDelegate func(PeerConn) +) + +type mockTransport struct { + netAddressFn netAddressDelegate + acceptFn acceptDelegate + dialFn dialDelegate + removeFn removeDelegate +} + +func (m *mockTransport) NetAddress() types.NetAddress { + if m.netAddressFn != nil { + return m.netAddressFn() + } + + return types.NetAddress{} +} + +func (m *mockTransport) Accept(ctx context.Context, behavior PeerBehavior) (PeerConn, error) { + if m.acceptFn != nil { + return m.acceptFn(ctx, behavior) + } + + return nil, nil +} + +func (m *mockTransport) Dial(ctx context.Context, address types.NetAddress, behavior PeerBehavior) (PeerConn, error) { + if m.dialFn != nil { + return m.dialFn(ctx, address, behavior) + } + + return nil, nil +} + +func (m *mockTransport) Remove(p PeerConn) { + if m.removeFn != nil { + m.removeFn(p) + } +} + +type ( + addDelegate func(PeerConn) + removePeerDelegate func(types.ID) bool + hasDelegate func(types.ID) bool + hasIPDelegate func(net.IP) bool + getDelegate func(types.ID) PeerConn + listDelegate func() []PeerConn + numInboundDelegate func() uint64 + numOutboundDelegate func() uint64 +) + +type mockSet struct { + addFn addDelegate + removeFn removePeerDelegate + hasFn hasDelegate + hasIPFn hasIPDelegate + listFn listDelegate + getFn getDelegate + numInboundFn numInboundDelegate + numOutboundFn numOutboundDelegate +} + +func (m *mockSet) Add(peer PeerConn) { + if m.addFn != nil { + m.addFn(peer) + } +} + +func (m *mockSet) Remove(key types.ID) bool { + if m.removeFn != nil { + m.removeFn(key) + 
} + + return false +} + +func (m *mockSet) Has(key types.ID) bool { + if m.hasFn != nil { + return m.hasFn(key) + } + + return false +} + +func (m *mockSet) Get(key types.ID) PeerConn { + if m.getFn != nil { + return m.getFn(key) + } + + return nil +} + +func (m *mockSet) List() []PeerConn { + if m.listFn != nil { + return m.listFn() + } + + return nil +} + +func (m *mockSet) NumInbound() uint64 { + if m.numInboundFn != nil { + return m.numInboundFn() + } + + return 0 +} + +func (m *mockSet) NumOutbound() uint64 { + if m.numOutboundFn != nil { + return m.numOutboundFn() + } + + return 0 +} + +type ( + listenerAcceptDelegate func() (net.Conn, error) + closeDelegate func() error + addrDelegate func() net.Addr +) + +type mockListener struct { + acceptFn listenerAcceptDelegate + closeFn closeDelegate + addrFn addrDelegate +} + +func (m *mockListener) Accept() (net.Conn, error) { + if m.acceptFn != nil { + return m.acceptFn() + } + + return nil, nil +} + +func (m *mockListener) Close() error { + if m.closeFn != nil { + return m.closeFn() + } + + return nil +} + +func (m *mockListener) Addr() net.Addr { + if m.addrFn != nil { + return m.addrFn() + } + + return nil +} + +type ( + readDelegate func([]byte) (int, error) + writeDelegate func([]byte) (int, error) + localAddrDelegate func() net.Addr + remoteAddrDelegate func() net.Addr + setDeadlineDelegate func(time.Time) error +) + +type mockConn struct { + readFn readDelegate + writeFn writeDelegate + closeFn closeDelegate + localAddrFn localAddrDelegate + remoteAddrFn remoteAddrDelegate + setDeadlineFn setDeadlineDelegate + setReadDeadlineFn setDeadlineDelegate + setWriteDeadlineFn setDeadlineDelegate +} + +func (m *mockConn) Read(buff []byte) (int, error) { + if m.readFn != nil { + return m.readFn(buff) + } + + return 0, nil +} + +func (m *mockConn) Write(buff []byte) (int, error) { + if m.writeFn != nil { + return m.writeFn(buff) + } + + return 0, nil +} + +func (m *mockConn) Close() error { + if m.closeFn != nil { + 
return m.closeFn() + } + + return nil +} + +func (m *mockConn) LocalAddr() net.Addr { + if m.localAddrFn != nil { + return m.localAddrFn() + } + + return nil +} + +func (m *mockConn) RemoteAddr() net.Addr { + if m.remoteAddrFn != nil { + return m.remoteAddrFn() + } + + return nil +} + +func (m *mockConn) SetDeadline(t time.Time) error { + if m.setDeadlineFn != nil { + return m.setDeadlineFn(t) + } + + return nil +} + +func (m *mockConn) SetReadDeadline(t time.Time) error { + if m.setReadDeadlineFn != nil { + return m.setReadDeadlineFn(t) + } + + return nil +} + +func (m *mockConn) SetWriteDeadline(t time.Time) error { + if m.setWriteDeadlineFn != nil { + return m.setWriteDeadlineFn(t) + } + + return nil +} + +type ( + startDelegate func() error + onStartDelegate func() error + stopDelegate func() error + onStopDelegate func() + resetDelegate func() error + onResetDelegate func() error + isRunningDelegate func() bool + quitDelegate func() <-chan struct{} + stringDelegate func() string + setLoggerDelegate func(*slog.Logger) + setSwitchDelegate func(Switch) + getChannelsDelegate func() []*conn.ChannelDescriptor + initPeerDelegate func(PeerConn) + addPeerDelegate func(PeerConn) + removeSwitchPeerDelegate func(PeerConn, any) + receiveDelegate func(byte, PeerConn, []byte) +) + +type mockReactor struct { + startFn startDelegate + onStartFn onStartDelegate + stopFn stopDelegate + onStopFn onStopDelegate + resetFn resetDelegate + onResetFn onResetDelegate + isRunningFn isRunningDelegate + quitFn quitDelegate + stringFn stringDelegate + setLoggerFn setLoggerDelegate + setSwitchFn setSwitchDelegate + getChannelsFn getChannelsDelegate + initPeerFn initPeerDelegate + addPeerFn addPeerDelegate + removePeerFn removeSwitchPeerDelegate + receiveFn receiveDelegate +} + +func (m *mockReactor) Start() error { + if m.startFn != nil { + return m.startFn() + } + + return nil +} + +func (m *mockReactor) OnStart() error { + if m.onStartFn != nil { + return m.onStartFn() + } + + return nil 
+} + +func (m *mockReactor) Stop() error { + if m.stopFn != nil { + return m.stopFn() + } + + return nil +} + +func (m *mockReactor) OnStop() { + if m.onStopFn != nil { + m.onStopFn() + } +} + +func (m *mockReactor) Reset() error { + if m.resetFn != nil { + return m.resetFn() + } + + return nil +} + +func (m *mockReactor) OnReset() error { + if m.onResetFn != nil { + return m.onResetFn() + } + + return nil +} + +func (m *mockReactor) IsRunning() bool { + if m.isRunningFn != nil { + return m.isRunningFn() + } + + return false +} + +func (m *mockReactor) Quit() <-chan struct{} { + if m.quitFn != nil { + return m.quitFn() + } + + return nil +} + +func (m *mockReactor) String() string { + if m.stringFn != nil { + return m.stringFn() + } + + return "" +} + +func (m *mockReactor) SetLogger(logger *slog.Logger) { + if m.setLoggerFn != nil { + m.setLoggerFn(logger) + } +} + +func (m *mockReactor) SetSwitch(s Switch) { + if m.setSwitchFn != nil { + m.setSwitchFn(s) + } +} + +func (m *mockReactor) GetChannels() []*conn.ChannelDescriptor { + if m.getChannelsFn != nil { + return m.getChannelsFn() + } + + return nil +} + +func (m *mockReactor) InitPeer(peer PeerConn) PeerConn { + if m.initPeerFn != nil { + m.initPeerFn(peer) + } + + return nil +} + +func (m *mockReactor) AddPeer(peer PeerConn) { + if m.addPeerFn != nil { + m.addPeerFn(peer) + } +} + +func (m *mockReactor) RemovePeer(peer PeerConn, reason any) { + if m.removePeerFn != nil { + m.removePeerFn(peer, reason) + } +} + +func (m *mockReactor) Receive(chID byte, peer PeerConn, msgBytes []byte) { + if m.receiveFn != nil { + m.receiveFn(chID, peer, msgBytes) + } +} diff --git a/tm2/pkg/p2p/netaddress_test.go b/tm2/pkg/p2p/netaddress_test.go deleted file mode 100644 index 413d020c153..00000000000 --- a/tm2/pkg/p2p/netaddress_test.go +++ /dev/null @@ -1,178 +0,0 @@ -package p2p - -import ( - "encoding/hex" - "net" - "testing" - - "github.com/gnolang/gno/tm2/pkg/crypto" - "github.com/stretchr/testify/assert" - 
"github.com/stretchr/testify/require" -) - -func TestAddress2ID(t *testing.T) { - t.Parallel() - - idbz, _ := hex.DecodeString("deadbeefdeadbeefdeadbeefdeadbeefdeadbeef") - id := crypto.AddressFromBytes(idbz).ID() - assert.Equal(t, crypto.ID("g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6"), id) - - idbz, _ = hex.DecodeString("deadbeefdeadbeefdeadbeefdeadbeefdead0000") - id = crypto.AddressFromBytes(idbz).ID() - assert.Equal(t, crypto.ID("g1m6kmam774klwlh4dhmhaatd7al026qqqq9c22r"), id) -} - -func TestNewNetAddress(t *testing.T) { - t.Parallel() - - tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:8080") - require.Nil(t, err) - - assert.Panics(t, func() { - NewNetAddress("", tcpAddr) - }) - - idbz, _ := hex.DecodeString("deadbeefdeadbeefdeadbeefdeadbeefdeadbeef") - id := crypto.AddressFromBytes(idbz).ID() - // ^-- is "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6" - - addr := NewNetAddress(id, tcpAddr) - assert.Equal(t, "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", addr.String()) - - assert.NotPanics(t, func() { - NewNetAddress("", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8000}) - }, "Calling NewNetAddress with UDPAddr should not panic in testing") -} - -func TestNewNetAddressFromString(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - addr string - expected string - correct bool - }{ - {"no node id and no protocol", "127.0.0.1:8080", "", false}, - {"no node id w/ tcp input", "tcp://127.0.0.1:8080", "", false}, - {"no node id w/ udp input", "udp://127.0.0.1:8080", "", false}, - - {"no protocol", "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", true}, - {"tcp input", "tcp://g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", true}, - {"udp input", "udp://g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", true}, - {"malformed tcp 
input", "tcp//g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", "", false}, - {"malformed udp input", "udp//g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", "", false}, - - // {"127.0.0:8080", false}, - {"invalid host", "notahost", "", false}, - {"invalid port", "127.0.0.1:notapath", "", false}, - {"invalid host w/ port", "notahost:8080", "", false}, - {"just a port", "8082", "", false}, - {"non-existent port", "127.0.0:8080000", "", false}, - - {"too short nodeId", "deadbeef@127.0.0.1:8080", "", false}, - {"too short, not hex nodeId", "this-isnot-hex@127.0.0.1:8080", "", false}, - {"not bech32 nodeId", "xxxm6kmam774klwlh4dhmhaatd7al02m0h0hdap9l@127.0.0.1:8080", "", false}, - - {"too short nodeId w/tcp", "tcp://deadbeef@127.0.0.1:8080", "", false}, - {"too short notHex nodeId w/tcp", "tcp://this-isnot-hex@127.0.0.1:8080", "", false}, - {"not bech32 nodeId w/tcp", "tcp://xxxxm6kmam774klwlh4dhmhaatd7al02m0h0hdap9l@127.0.0.1:8080", "", false}, - {"correct nodeId w/tcp", "tcp://g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", true}, - - {"no node id", "tcp://@127.0.0.1:8080", "", false}, - {"no node id or IP", "tcp://@", "", false}, - {"tcp no host, w/ port", "tcp://:26656", "", false}, - {"empty", "", "", false}, - {"node id delimiter 1", "@", "", false}, - {"node id delimiter 2", " @", "", false}, - {"node id delimiter 3", " @ ", "", false}, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - addr, err := NewNetAddressFromString(tc.addr) - if tc.correct { - if assert.Nil(t, err, tc.addr) { - assert.Equal(t, tc.expected, addr.String()) - } - } else { - assert.NotNil(t, err, tc.addr) - } - }) - } -} - -func TestNewNetAddressFromStrings(t *testing.T) { - t.Parallel() - - addrs, errs := NewNetAddressFromStrings([]string{ - "127.0.0.1:8080", - "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", - 
"g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.2:8080", - }) - assert.Len(t, errs, 1) - assert.Equal(t, 2, len(addrs)) -} - -func TestNewNetAddressFromIPPort(t *testing.T) { - t.Parallel() - - addr := NewNetAddressFromIPPort("", net.ParseIP("127.0.0.1"), 8080) - assert.Equal(t, "127.0.0.1:8080", addr.String()) -} - -func TestNetAddressProperties(t *testing.T) { - t.Parallel() - - // TODO add more test cases - testCases := []struct { - addr string - valid bool - local bool - routable bool - }{ - {"g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", true, true, false}, - {"g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@ya.ru:80", true, false, true}, - } - - for _, tc := range testCases { - addr, err := NewNetAddressFromString(tc.addr) - require.Nil(t, err) - - err = addr.Validate() - if tc.valid { - assert.NoError(t, err) - } else { - assert.Error(t, err) - } - assert.Equal(t, tc.local, addr.Local()) - assert.Equal(t, tc.routable, addr.Routable()) - } -} - -func TestNetAddressReachabilityTo(t *testing.T) { - t.Parallel() - - // TODO add more test cases - testCases := []struct { - addr string - other string - reachability int - }{ - {"g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8081", 0}, - {"g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@ya.ru:80", "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", 1}, - } - - for _, tc := range testCases { - addr, err := NewNetAddressFromString(tc.addr) - require.Nil(t, err) - - other, err := NewNetAddressFromString(tc.other) - require.Nil(t, err) - - assert.Equal(t, tc.reachability, addr.ReachabilityTo(other)) - } -} diff --git a/tm2/pkg/p2p/node_info.go b/tm2/pkg/p2p/node_info.go deleted file mode 100644 index 48ba8f7776b..00000000000 --- a/tm2/pkg/p2p/node_info.go +++ /dev/null @@ -1,156 +0,0 @@ -package p2p - -import ( - "fmt" - - "github.com/gnolang/gno/tm2/pkg/bft/state/eventstore" - "github.com/gnolang/gno/tm2/pkg/strings" - 
"github.com/gnolang/gno/tm2/pkg/versionset" -) - -const ( - maxNodeInfoSize = 10240 // 10KB - maxNumChannels = 16 // plenty of room for upgrades, for now -) - -// Max size of the NodeInfo struct -func MaxNodeInfoSize() int { - return maxNodeInfoSize -} - -// ------------------------------------------------------------- - -// NodeInfo is the basic node information exchanged -// between two peers during the Tendermint P2P handshake. -type NodeInfo struct { - // Set of protocol versions - VersionSet versionset.VersionSet `json:"version_set"` - - // Authenticate - NetAddress *NetAddress `json:"net_address"` - - // Check compatibility. - // Channels are HexBytes so easier to read as JSON - Network string `json:"network"` // network/chain ID - Software string `json:"software"` // name of immediate software - Version string `json:"version"` // software major.minor.revision - Channels []byte `json:"channels"` // channels this node knows about - - // ASCIIText fields - Moniker string `json:"moniker"` // arbitrary moniker - Other NodeInfoOther `json:"other"` // other application specific data -} - -// NodeInfoOther is the misc. application specific data -type NodeInfoOther struct { - TxIndex string `json:"tx_index"` - RPCAddress string `json:"rpc_address"` -} - -// Validate checks the self-reported NodeInfo is safe. -// It returns an error if there -// are too many Channels, if there are any duplicate Channels, -// if the ListenAddr is malformed, or if the ListenAddr is a host name -// that can not be resolved to some IP. -// TODO: constraints for Moniker/Other? Or is that for the UI ? -// JAE: It needs to be done on the client, but to prevent ambiguous -// unicode characters, maybe it's worth sanitizing it here. -// In the future we might want to validate these, once we have a -// name-resolution system up. -// International clients could then use punycode (or we could use -// url-encoding), and we just need to be careful with how we handle that in our -// clients. (e.g. 
off by default). -func (info NodeInfo) Validate() error { - // ID is already validated. TODO validate - - // Validate ListenAddr. - if info.NetAddress == nil { - return fmt.Errorf("info.NetAddress cannot be nil") - } - if err := info.NetAddress.ValidateLocal(); err != nil { - return err - } - - // Network is validated in CompatibleWith. - - // Validate Version - if len(info.Version) > 0 && - (!strings.IsASCIIText(info.Version) || strings.ASCIITrim(info.Version) == "") { - return fmt.Errorf("info.Version must be valid ASCII text without tabs, but got %v", info.Version) - } - - // Validate Channels - ensure max and check for duplicates. - if len(info.Channels) > maxNumChannels { - return fmt.Errorf("info.Channels is too long (%v). Max is %v", len(info.Channels), maxNumChannels) - } - channels := make(map[byte]struct{}) - for _, ch := range info.Channels { - _, ok := channels[ch] - if ok { - return fmt.Errorf("info.Channels contains duplicate channel id %v", ch) - } - channels[ch] = struct{}{} - } - - // Validate Moniker. - if !strings.IsASCIIText(info.Moniker) || strings.ASCIITrim(info.Moniker) == "" { - return fmt.Errorf("info.Moniker must be valid non-empty ASCII text without tabs, but got %v", info.Moniker) - } - - // Validate Other. - other := info.Other - txIndex := other.TxIndex - switch txIndex { - case "", eventstore.StatusOn, eventstore.StatusOff: - default: - return fmt.Errorf("info.Other.TxIndex should be either 'on', 'off', or empty string, got '%v'", txIndex) - } - // XXX: Should we be more strict about address formats? - rpcAddr := other.RPCAddress - if len(rpcAddr) > 0 && (!strings.IsASCIIText(rpcAddr) || strings.ASCIITrim(rpcAddr) == "") { - return fmt.Errorf("info.Other.RPCAddress=%v must be valid ASCII text without tabs", rpcAddr) - } - - return nil -} - -func (info NodeInfo) ID() ID { - return info.NetAddress.ID -} - -// CompatibleWith checks if two NodeInfo are compatible with eachother. 
-// CONTRACT: two nodes are compatible if the Block version and network match -// and they have at least one channel in common. -func (info NodeInfo) CompatibleWith(other NodeInfo) error { - // check protocol versions - _, err := info.VersionSet.CompatibleWith(other.VersionSet) - if err != nil { - return err - } - - // nodes must be on the same network - if info.Network != other.Network { - return fmt.Errorf("Peer is on a different network. Got %v, expected %v", other.Network, info.Network) - } - - // if we have no channels, we're just testing - if len(info.Channels) == 0 { - return nil - } - - // for each of our channels, check if they have it - found := false -OUTER_LOOP: - for _, ch1 := range info.Channels { - for _, ch2 := range other.Channels { - if ch1 == ch2 { - found = true - break OUTER_LOOP // only need one - } - } - } - if !found { - return fmt.Errorf("Peer has no common channels. Our channels: %v ; Peer channels: %v", info.Channels, other.Channels) - } - return nil -} diff --git a/tm2/pkg/p2p/node_info_test.go b/tm2/pkg/p2p/node_info_test.go deleted file mode 100644 index 58f1dab8854..00000000000 --- a/tm2/pkg/p2p/node_info_test.go +++ /dev/null @@ -1,134 +0,0 @@ -package p2p - -import ( - "fmt" - "net" - "testing" - - "github.com/gnolang/gno/tm2/pkg/crypto/ed25519" - "github.com/gnolang/gno/tm2/pkg/versionset" - "github.com/stretchr/testify/assert" -) - -func TestNodeInfoValidate(t *testing.T) { - t.Parallel() - - // empty fails - ni := NodeInfo{} - assert.Error(t, ni.Validate()) - - channels := make([]byte, maxNumChannels) - for i := 0; i < maxNumChannels; i++ { - channels[i] = byte(i) - } - dupChannels := make([]byte, 5) - copy(dupChannels, channels[:5]) - dupChannels = append(dupChannels, testCh) - - nonAscii := "¢§µ" - emptyTab := fmt.Sprintf("\t") - emptySpace := fmt.Sprintf(" ") - - testCases := []struct { - testName string - malleateNodeInfo func(*NodeInfo) - expectErr bool - }{ - {"Too Many Channels", func(ni *NodeInfo) { ni.Channels = 
append(channels, byte(maxNumChannels)) }, true}, //nolint: gocritic - {"Duplicate Channel", func(ni *NodeInfo) { ni.Channels = dupChannels }, true}, - {"Good Channels", func(ni *NodeInfo) { ni.Channels = ni.Channels[:5] }, false}, - - {"Nil NetAddress", func(ni *NodeInfo) { ni.NetAddress = nil }, true}, - {"Zero NetAddress ID", func(ni *NodeInfo) { ni.NetAddress.ID = "" }, true}, - {"Invalid NetAddress IP", func(ni *NodeInfo) { ni.NetAddress.IP = net.IP([]byte{0x00}) }, true}, - - {"Non-ASCII Version", func(ni *NodeInfo) { ni.Version = nonAscii }, true}, - {"Empty tab Version", func(ni *NodeInfo) { ni.Version = emptyTab }, true}, - {"Empty space Version", func(ni *NodeInfo) { ni.Version = emptySpace }, true}, - {"Empty Version", func(ni *NodeInfo) { ni.Version = "" }, false}, - - {"Non-ASCII Moniker", func(ni *NodeInfo) { ni.Moniker = nonAscii }, true}, - {"Empty tab Moniker", func(ni *NodeInfo) { ni.Moniker = emptyTab }, true}, - {"Empty space Moniker", func(ni *NodeInfo) { ni.Moniker = emptySpace }, true}, - {"Empty Moniker", func(ni *NodeInfo) { ni.Moniker = "" }, true}, - {"Good Moniker", func(ni *NodeInfo) { ni.Moniker = "hey its me" }, false}, - - {"Non-ASCII TxIndex", func(ni *NodeInfo) { ni.Other.TxIndex = nonAscii }, true}, - {"Empty tab TxIndex", func(ni *NodeInfo) { ni.Other.TxIndex = emptyTab }, true}, - {"Empty space TxIndex", func(ni *NodeInfo) { ni.Other.TxIndex = emptySpace }, true}, - {"Empty TxIndex", func(ni *NodeInfo) { ni.Other.TxIndex = "" }, false}, - {"Off TxIndex", func(ni *NodeInfo) { ni.Other.TxIndex = "off" }, false}, - - {"Non-ASCII RPCAddress", func(ni *NodeInfo) { ni.Other.RPCAddress = nonAscii }, true}, - {"Empty tab RPCAddress", func(ni *NodeInfo) { ni.Other.RPCAddress = emptyTab }, true}, - {"Empty space RPCAddress", func(ni *NodeInfo) { ni.Other.RPCAddress = emptySpace }, true}, - {"Empty RPCAddress", func(ni *NodeInfo) { ni.Other.RPCAddress = "" }, false}, - {"Good RPCAddress", func(ni *NodeInfo) { ni.Other.RPCAddress = 
"0.0.0.0:26657" }, false}, - } - - nodeKey := NodeKey{PrivKey: ed25519.GenPrivKey()} - name := "testing" - - // test case passes - ni = testNodeInfo(nodeKey.ID(), name) - ni.Channels = channels - assert.NoError(t, ni.Validate()) - - for _, tc := range testCases { - ni := testNodeInfo(nodeKey.ID(), name) - ni.Channels = channels - tc.malleateNodeInfo(&ni) - err := ni.Validate() - if tc.expectErr { - assert.Error(t, err, tc.testName) - } else { - assert.NoError(t, err, tc.testName) - } - } -} - -func TestNodeInfoCompatible(t *testing.T) { - t.Parallel() - - nodeKey1 := NodeKey{PrivKey: ed25519.GenPrivKey()} - nodeKey2 := NodeKey{PrivKey: ed25519.GenPrivKey()} - name := "testing" - - var newTestChannel byte = 0x2 - - // test NodeInfo is compatible - ni1 := testNodeInfo(nodeKey1.ID(), name) - ni2 := testNodeInfo(nodeKey2.ID(), name) - assert.NoError(t, ni1.CompatibleWith(ni2)) - - // add another channel; still compatible - ni2.Channels = []byte{newTestChannel, testCh} - assert.NoError(t, ni1.CompatibleWith(ni2)) - - // wrong NodeInfo type is not compatible - _, netAddr := CreateRoutableAddr() - ni3 := NodeInfo{NetAddress: netAddr} - assert.Error(t, ni1.CompatibleWith(ni3)) - - testCases := []struct { - testName string - malleateNodeInfo func(*NodeInfo) - }{ - {"Bad block version", func(ni *NodeInfo) { - ni.VersionSet.Set(versionset.VersionInfo{Name: "Block", Version: "badversion"}) - }}, - {"Wrong block version", func(ni *NodeInfo) { - ni.VersionSet.Set(versionset.VersionInfo{Name: "Block", Version: "v999.999.999-wrong"}) - }}, - {"Wrong network", func(ni *NodeInfo) { ni.Network += "-wrong" }}, - {"No common channels", func(ni *NodeInfo) { ni.Channels = []byte{newTestChannel} }}, - } - - for i, tc := range testCases { - t.Logf("case #%v", i) - ni := testNodeInfo(nodeKey2.ID(), name) - tc.malleateNodeInfo(&ni) - fmt.Printf("case #%v\n", i) - assert.Error(t, ni1.CompatibleWith(ni)) - } -} diff --git a/tm2/pkg/p2p/peer.go b/tm2/pkg/p2p/peer.go index 
ef2ddcf2c25..135bf4b250c 100644 --- a/tm2/pkg/p2p/peer.go +++ b/tm2/pkg/p2p/peer.go @@ -6,169 +6,132 @@ import ( "net" "github.com/gnolang/gno/tm2/pkg/cmap" - connm "github.com/gnolang/gno/tm2/pkg/p2p/conn" + "github.com/gnolang/gno/tm2/pkg/p2p/conn" + "github.com/gnolang/gno/tm2/pkg/p2p/types" "github.com/gnolang/gno/tm2/pkg/service" ) -// Peer is an interface representing a peer connected on a reactor. -type Peer interface { - service.Service - FlushStop() - - ID() ID // peer's cryptographic ID - RemoteIP() net.IP // remote IP of the connection - RemoteAddr() net.Addr // remote address of the connection - - IsOutbound() bool // did we dial the peer - IsPersistent() bool // do we redial this peer when we disconnect - - CloseConn() error // close original connection - - NodeInfo() NodeInfo // peer's info - Status() connm.ConnectionStatus - SocketAddr() *NetAddress // actual address of the socket - - Send(byte, []byte) bool - TrySend(byte, []byte) bool - - Set(string, interface{}) - Get(string) interface{} +type ConnConfig struct { + MConfig conn.MConnConfig + ReactorsByCh map[byte]Reactor + ChDescs []*conn.ChannelDescriptor + OnPeerError func(PeerConn, error) } -// ---------------------------------------------------------- - -// peerConn contains the raw connection and its config. 
-type peerConn struct { - outbound bool - persistent bool - conn net.Conn // source connection - - socketAddr *NetAddress - - // cached RemoteIP() - ip net.IP +// ConnInfo wraps the remote peer connection +type ConnInfo struct { + Outbound bool // flag indicating if the connection is dialed + Persistent bool // flag indicating if the connection is persistent + Private bool // flag indicating if the peer is private (not shared) + Conn net.Conn // the source connection + RemoteIP net.IP // the remote IP of the peer + SocketAddr *types.NetAddress } -func newPeerConn( - outbound, persistent bool, - conn net.Conn, - socketAddr *NetAddress, -) peerConn { - return peerConn{ - outbound: outbound, - persistent: persistent, - conn: conn, - socketAddr: socketAddr, - } -} - -// ID only exists for SecretConnection. -// NOTE: Will panic if conn is not *SecretConnection. -func (pc peerConn) ID() ID { - return (pc.conn.(*connm.SecretConnection).RemotePubKey()).Address().ID() -} - -// Return the IP from the connection RemoteAddr -func (pc peerConn) RemoteIP() net.IP { - if pc.ip != nil { - return pc.ip - } - - host, _, err := net.SplitHostPort(pc.conn.RemoteAddr().String()) - if err != nil { - panic(err) - } - - ips, err := net.LookupIP(host) - if err != nil { - panic(err) - } - - pc.ip = ips[0] - - return pc.ip +type multiplexConn interface { + FlushStop() + Start() error + Stop() error + Send(byte, []byte) bool + TrySend(byte, []byte) bool + SetLogger(*slog.Logger) + Status() conn.ConnectionStatus + String() string } -// peer implements Peer. -// +// peer is a wrapper for a remote peer // Before using a peer, you will need to perform a handshake on connection. 
type peer struct { service.BaseService - // raw peerConn and the multiplex connection - peerConn - mconn *connm.MConnection - - // peer's node info and the channel it knows about - // channels = nodeInfo.Channels - // cached to avoid copying nodeInfo in hasChannel - nodeInfo NodeInfo - channels []byte + connInfo *ConnInfo // Metadata about the connection + nodeInfo types.NodeInfo // Information about the peer's node + mConn multiplexConn // The multiplexed connection - // User data - Data *cmap.CMap + data *cmap.CMap // Arbitrary data store associated with the peer } -type PeerOption func(*peer) - +// newPeer creates an uninitialized peer instance func newPeer( - pc peerConn, - mConfig connm.MConnConfig, - nodeInfo NodeInfo, - reactorsByCh map[byte]Reactor, - chDescs []*connm.ChannelDescriptor, - onPeerError func(Peer, interface{}), - options ...PeerOption, -) *peer { + connInfo *ConnInfo, + nodeInfo types.NodeInfo, + mConfig *ConnConfig, +) PeerConn { p := &peer{ - peerConn: pc, + connInfo: connInfo, nodeInfo: nodeInfo, - channels: nodeInfo.Channels, // TODO - Data: cmap.NewCMap(), + data: cmap.NewCMap(), } - p.mconn = createMConnection( - pc.conn, - p, - reactorsByCh, - chDescs, - onPeerError, + p.mConn = p.createMConnection( + connInfo.Conn, mConfig, ) + p.BaseService = *service.NewBaseService(nil, "Peer", p) - for _, option := range options { - option(p) - } return p } -// String representation. 
+// RemoteIP returns the IP from the remote connection +func (p *peer) RemoteIP() net.IP { + return p.connInfo.RemoteIP +} + +// RemoteAddr returns the address from the remote connection +func (p *peer) RemoteAddr() net.Addr { + return p.connInfo.Conn.RemoteAddr() +} + func (p *peer) String() string { - if p.outbound { - return fmt.Sprintf("Peer{%v %v out}", p.mconn, p.ID()) + if p.connInfo.Outbound { + return fmt.Sprintf("Peer{%s %s out}", p.mConn, p.ID()) } - return fmt.Sprintf("Peer{%v %v in}", p.mconn, p.ID()) + return fmt.Sprintf("Peer{%s %s in}", p.mConn, p.ID()) } -// --------------------------------------------------- -// Implements service.Service +// IsOutbound returns true if the connection is outbound, false otherwise. +func (p *peer) IsOutbound() bool { + return p.connInfo.Outbound +} + +// IsPersistent returns true if the peer is persistent, false otherwise. +func (p *peer) IsPersistent() bool { + return p.connInfo.Persistent +} + +// IsPrivate returns true if the peer is private, false otherwise. +func (p *peer) IsPrivate() bool { + return p.connInfo.Private +} + +// SocketAddr returns the address of the socket. +// For outbound peers, it's the address dialed (after DNS resolution). +// For inbound peers, it's the address returned by the underlying connection +// (not what's reported in the peer's NodeInfo). +func (p *peer) SocketAddr() *types.NetAddress { + return p.connInfo.SocketAddr +} + +// CloseConn closes original connection. +// Used for cleaning up in cases where the peer had not been started at all. +func (p *peer) CloseConn() error { + return p.connInfo.Conn.Close() +} -// SetLogger implements BaseService. func (p *peer) SetLogger(l *slog.Logger) { p.Logger = l - p.mconn.SetLogger(l) + p.mConn.SetLogger(l) } -// OnStart implements BaseService. 
func (p *peer) OnStart() error { if err := p.BaseService.OnStart(); err != nil { - return err + return fmt.Errorf("unable to start base service, %w", err) } - if err := p.mconn.Start(); err != nil { - return err + if err := p.mConn.Start(); err != nil { + return fmt.Errorf("unable to start multiplex connection, %w", err) } return nil @@ -179,164 +142,105 @@ func (p *peer) OnStart() error { // NOTE: it is not safe to call this method more than once. func (p *peer) FlushStop() { p.BaseService.OnStop() - p.mconn.FlushStop() // stop everything and close the conn + p.mConn.FlushStop() // stop everything and close the conn } // OnStop implements BaseService. func (p *peer) OnStop() { p.BaseService.OnStop() - p.mconn.Stop() // stop everything and close the conn -} -// --------------------------------------------------- -// Implements Peer - -// ID returns the peer's ID - the hex encoded hash of its pubkey. -func (p *peer) ID() ID { - return p.nodeInfo.NetAddress.ID -} - -// IsOutbound returns true if the connection is outbound, false otherwise. -func (p *peer) IsOutbound() bool { - return p.peerConn.outbound + if err := p.mConn.Stop(); err != nil { + p.Logger.Error( + "unable to gracefully close mConn", + "err", + err, + ) + } } -// IsPersistent returns true if the peer is persistent, false otherwise. -func (p *peer) IsPersistent() bool { - return p.peerConn.persistent +// ID returns the peer's ID - the hex encoded hash of its pubkey. +func (p *peer) ID() types.ID { + return p.nodeInfo.PeerID } // NodeInfo returns a copy of the peer's NodeInfo. -func (p *peer) NodeInfo() NodeInfo { +func (p *peer) NodeInfo() types.NodeInfo { return p.nodeInfo } -// SocketAddr returns the address of the socket. -// For outbound peers, it's the address dialed (after DNS resolution). -// For inbound peers, it's the address returned by the underlying connection -// (not what's reported in the peer's NodeInfo). 
-func (p *peer) SocketAddr() *NetAddress { - return p.peerConn.socketAddr -} - // Status returns the peer's ConnectionStatus. -func (p *peer) Status() connm.ConnectionStatus { - return p.mconn.Status() +func (p *peer) Status() conn.ConnectionStatus { + return p.mConn.Status() } // Send msg bytes to the channel identified by chID byte. Returns false if the // send queue is full after timeout, specified by MConnection. func (p *peer) Send(chID byte, msgBytes []byte) bool { - if !p.IsRunning() { - // see Switch#Broadcast, where we fetch the list of peers and loop over + if !p.IsRunning() || !p.hasChannel(chID) { + // see MultiplexSwitch#Broadcast, where we fetch the list of peers and loop over // them - while we're looping, one peer may be removed and stopped. return false - } else if !p.hasChannel(chID) { - return false } - res := p.mconn.Send(chID, msgBytes) - return res + + return p.mConn.Send(chID, msgBytes) } // TrySend msg bytes to the channel identified by chID byte. Immediately returns // false if the send queue is full. func (p *peer) TrySend(chID byte, msgBytes []byte) bool { - if !p.IsRunning() { - return false - } else if !p.hasChannel(chID) { + if !p.IsRunning() || !p.hasChannel(chID) { return false } - res := p.mconn.TrySend(chID, msgBytes) - return res + + return p.mConn.TrySend(chID, msgBytes) } // Get the data for a given key. -func (p *peer) Get(key string) interface{} { - return p.Data.Get(key) +func (p *peer) Get(key string) any { + return p.data.Get(key) } // Set sets the data for the given key. -func (p *peer) Set(key string, data interface{}) { - p.Data.Set(key, data) +func (p *peer) Set(key string, data any) { + p.data.Set(key, data) } // hasChannel returns true if the peer reported // knowing about the given chID. 
func (p *peer) hasChannel(chID byte) bool { - for _, ch := range p.channels { + for _, ch := range p.nodeInfo.Channels { if ch == chID { return true } } - // NOTE: probably will want to remove this - // but could be helpful while the feature is new - p.Logger.Debug( - "Unknown channel for peer", - "channel", - chID, - "channels", - p.channels, - ) - return false -} - -// CloseConn closes original connection. Used for cleaning up in cases where the peer had not been started at all. -func (p *peer) CloseConn() error { - return p.peerConn.conn.Close() -} - -// --------------------------------------------------- -// methods only used for testing -// TODO: can we remove these? - -// CloseConn closes the underlying connection -func (pc *peerConn) CloseConn() { - pc.conn.Close() //nolint: errcheck -} -// RemoteAddr returns peer's remote network address. -func (p *peer) RemoteAddr() net.Addr { - return p.peerConn.conn.RemoteAddr() -} - -// CanSend returns true if the send queue is not full, false otherwise. -func (p *peer) CanSend(chID byte) bool { - if !p.IsRunning() { - return false - } - return p.mconn.CanSend(chID) + return false } -// ------------------------------------------------------------------ -// helper funcs - -func createMConnection( - conn net.Conn, - p *peer, - reactorsByCh map[byte]Reactor, - chDescs []*connm.ChannelDescriptor, - onPeerError func(Peer, interface{}), - config connm.MConnConfig, -) *connm.MConnection { +func (p *peer) createMConnection( + c net.Conn, + config *ConnConfig, +) *conn.MConnection { onReceive := func(chID byte, msgBytes []byte) { - reactor := reactorsByCh[chID] + reactor := config.ReactorsByCh[chID] if reactor == nil { // Note that its ok to panic here as it's caught in the connm._recover, // which does onPeerError. 
panic(fmt.Sprintf("Unknown channel %X", chID)) } + reactor.Receive(chID, p, msgBytes) } - onError := func(r interface{}) { - onPeerError(p, r) + onError := func(r error) { + config.OnPeerError(p, r) } - return connm.NewMConnectionWithConfig( - conn, - chDescs, + return conn.NewMConnectionWithConfig( + c, + config.ChDescs, onReceive, onError, - config, + config.MConfig, ) } diff --git a/tm2/pkg/p2p/peer_set.go b/tm2/pkg/p2p/peer_set.go deleted file mode 100644 index 396ba56da11..00000000000 --- a/tm2/pkg/p2p/peer_set.go +++ /dev/null @@ -1,147 +0,0 @@ -package p2p - -import ( - "net" - "sync" -) - -// IPeerSet has a (immutable) subset of the methods of PeerSet. -type IPeerSet interface { - Has(key ID) bool - HasIP(ip net.IP) bool - Get(key ID) Peer - List() []Peer - Size() int -} - -// ----------------------------------------------------------------------------- - -// PeerSet is a special structure for keeping a table of peers. -// Iteration over the peers is super fast and thread-safe. -type PeerSet struct { - mtx sync.Mutex - lookup map[ID]*peerSetItem - list []Peer -} - -type peerSetItem struct { - peer Peer - index int -} - -// NewPeerSet creates a new peerSet with a list of initial capacity of 256 items. -func NewPeerSet() *PeerSet { - return &PeerSet{ - lookup: make(map[ID]*peerSetItem), - list: make([]Peer, 0, 256), - } -} - -// Add adds the peer to the PeerSet. -// It returns an error carrying the reason, if the peer is already present. -func (ps *PeerSet) Add(peer Peer) error { - ps.mtx.Lock() - defer ps.mtx.Unlock() - - if ps.lookup[peer.ID()] != nil { - return SwitchDuplicatePeerIDError{peer.ID()} - } - - index := len(ps.list) - // Appending is safe even with other goroutines - // iterating over the ps.list slice. - ps.list = append(ps.list, peer) - ps.lookup[peer.ID()] = &peerSetItem{peer, index} - return nil -} - -// Has returns true if the set contains the peer referred to by this -// peerKey, otherwise false. 
-func (ps *PeerSet) Has(peerKey ID) bool { - ps.mtx.Lock() - _, ok := ps.lookup[peerKey] - ps.mtx.Unlock() - return ok -} - -// HasIP returns true if the set contains the peer referred to by this IP -// address, otherwise false. -func (ps *PeerSet) HasIP(peerIP net.IP) bool { - ps.mtx.Lock() - defer ps.mtx.Unlock() - - return ps.hasIP(peerIP) -} - -// hasIP does not acquire a lock so it can be used in public methods which -// already lock. -func (ps *PeerSet) hasIP(peerIP net.IP) bool { - for _, item := range ps.lookup { - if item.peer.RemoteIP().Equal(peerIP) { - return true - } - } - - return false -} - -// Get looks up a peer by the provided peerKey. Returns nil if peer is not -// found. -func (ps *PeerSet) Get(peerKey ID) Peer { - ps.mtx.Lock() - defer ps.mtx.Unlock() - item, ok := ps.lookup[peerKey] - if ok { - return item.peer - } - return nil -} - -// Remove discards peer by its Key, if the peer was previously memoized. -// Returns true if the peer was removed, and false if it was not found. -// in the set. -func (ps *PeerSet) Remove(peer Peer) bool { - ps.mtx.Lock() - defer ps.mtx.Unlock() - - item := ps.lookup[peer.ID()] - if item == nil { - return false - } - - index := item.index - // Create a new copy of the list but with one less item. - // (we must copy because we'll be mutating the list). - newList := make([]Peer, len(ps.list)-1) - copy(newList, ps.list) - // If it's the last peer, that's an easy special case. - if index == len(ps.list)-1 { - ps.list = newList - delete(ps.lookup, peer.ID()) - return true - } - - // Replace the popped item with the last item in the old list. - lastPeer := ps.list[len(ps.list)-1] - lastPeerKey := lastPeer.ID() - lastPeerItem := ps.lookup[lastPeerKey] - newList[index] = lastPeer - lastPeerItem.index = index - ps.list = newList - delete(ps.lookup, peer.ID()) - return true -} - -// Size returns the number of unique items in the peerSet. 
-func (ps *PeerSet) Size() int { - ps.mtx.Lock() - defer ps.mtx.Unlock() - return len(ps.list) -} - -// List returns the threadsafe list of peers. -func (ps *PeerSet) List() []Peer { - ps.mtx.Lock() - defer ps.mtx.Unlock() - return ps.list -} diff --git a/tm2/pkg/p2p/peer_set_test.go b/tm2/pkg/p2p/peer_set_test.go deleted file mode 100644 index 7aca84d59b0..00000000000 --- a/tm2/pkg/p2p/peer_set_test.go +++ /dev/null @@ -1,190 +0,0 @@ -package p2p - -import ( - "net" - "sync" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/gnolang/gno/tm2/pkg/crypto/ed25519" - "github.com/gnolang/gno/tm2/pkg/service" -) - -// mockPeer for testing the PeerSet -type mockPeer struct { - service.BaseService - ip net.IP - id ID -} - -func (mp *mockPeer) FlushStop() { mp.Stop() } -func (mp *mockPeer) TrySend(chID byte, msgBytes []byte) bool { return true } -func (mp *mockPeer) Send(chID byte, msgBytes []byte) bool { return true } -func (mp *mockPeer) NodeInfo() NodeInfo { return NodeInfo{} } -func (mp *mockPeer) Status() ConnectionStatus { return ConnectionStatus{} } -func (mp *mockPeer) ID() ID { return mp.id } -func (mp *mockPeer) IsOutbound() bool { return false } -func (mp *mockPeer) IsPersistent() bool { return true } -func (mp *mockPeer) Get(s string) interface{} { return s } -func (mp *mockPeer) Set(string, interface{}) {} -func (mp *mockPeer) RemoteIP() net.IP { return mp.ip } -func (mp *mockPeer) SocketAddr() *NetAddress { return nil } -func (mp *mockPeer) RemoteAddr() net.Addr { return &net.TCPAddr{IP: mp.ip, Port: 8800} } -func (mp *mockPeer) CloseConn() error { return nil } - -// Returns a mock peer -func newMockPeer(ip net.IP) *mockPeer { - if ip == nil { - ip = net.IP{127, 0, 0, 1} - } - nodeKey := NodeKey{PrivKey: ed25519.GenPrivKey()} - return &mockPeer{ - ip: ip, - id: nodeKey.ID(), - } -} - -func TestPeerSetAddRemoveOne(t *testing.T) { - t.Parallel() - - peerSet := NewPeerSet() - - var peerList []Peer - for i := 0; i < 5; i++ { - p := 
newMockPeer(net.IP{127, 0, 0, byte(i)}) - if err := peerSet.Add(p); err != nil { - t.Error(err) - } - peerList = append(peerList, p) - } - - n := len(peerList) - // 1. Test removing from the front - for i, peerAtFront := range peerList { - removed := peerSet.Remove(peerAtFront) - assert.True(t, removed) - wantSize := n - i - 1 - for j := 0; j < 2; j++ { - assert.Equal(t, false, peerSet.Has(peerAtFront.ID()), "#%d Run #%d: failed to remove peer", i, j) - assert.Equal(t, wantSize, peerSet.Size(), "#%d Run #%d: failed to remove peer and decrement size", i, j) - // Test the route of removing the now non-existent element - removed := peerSet.Remove(peerAtFront) - assert.False(t, removed) - } - } - - // 2. Next we are testing removing the peer at the end - // a) Replenish the peerSet - for _, peer := range peerList { - if err := peerSet.Add(peer); err != nil { - t.Error(err) - } - } - - // b) In reverse, remove each element - for i := n - 1; i >= 0; i-- { - peerAtEnd := peerList[i] - removed := peerSet.Remove(peerAtEnd) - assert.True(t, removed) - assert.Equal(t, false, peerSet.Has(peerAtEnd.ID()), "#%d: failed to remove item at end", i) - assert.Equal(t, i, peerSet.Size(), "#%d: differing sizes after peerSet.Remove(atEndPeer)", i) - } -} - -func TestPeerSetAddRemoveMany(t *testing.T) { - t.Parallel() - peerSet := NewPeerSet() - - peers := []Peer{} - N := 100 - for i := 0; i < N; i++ { - peer := newMockPeer(net.IP{127, 0, 0, byte(i)}) - if err := peerSet.Add(peer); err != nil { - t.Errorf("Failed to add new peer") - } - if peerSet.Size() != i+1 { - t.Errorf("Failed to add new peer and increment size") - } - peers = append(peers, peer) - } - - for i, peer := range peers { - removed := peerSet.Remove(peer) - assert.True(t, removed) - if peerSet.Has(peer.ID()) { - t.Errorf("Failed to remove peer") - } - if peerSet.Size() != len(peers)-i-1 { - t.Errorf("Failed to remove peer and decrement size") - } - } -} - -func TestPeerSetAddDuplicate(t *testing.T) { - t.Parallel() - 
peerSet := NewPeerSet() - peer := newMockPeer(nil) - - n := 20 - errsChan := make(chan error) - // Add the same asynchronously to test the - // concurrent guarantees of our APIs, and - // our expectation in the end is that only - // one addition succeeded, but the rest are - // instances of ErrSwitchDuplicatePeer. - for i := 0; i < n; i++ { - go func() { - errsChan <- peerSet.Add(peer) - }() - } - - // Now collect and tally the results - errsTally := make(map[string]int) - for i := 0; i < n; i++ { - err := <-errsChan - - switch err.(type) { - case SwitchDuplicatePeerIDError: - errsTally["duplicateID"]++ - default: - errsTally["other"]++ - } - } - - // Our next procedure is to ensure that only one addition - // succeeded and that the rest are each ErrSwitchDuplicatePeer. - wantErrCount, gotErrCount := n-1, errsTally["duplicateID"] - assert.Equal(t, wantErrCount, gotErrCount, "invalid ErrSwitchDuplicatePeer count") - - wantNilErrCount, gotNilErrCount := 1, errsTally["other"] - assert.Equal(t, wantNilErrCount, gotNilErrCount, "invalid nil errCount") -} - -func TestPeerSetGet(t *testing.T) { - t.Parallel() - - var ( - peerSet = NewPeerSet() - peer = newMockPeer(nil) - ) - - assert.Nil(t, peerSet.Get(peer.ID()), "expecting a nil lookup, before .Add") - - if err := peerSet.Add(peer); err != nil { - t.Fatalf("Failed to add new peer: %v", err) - } - - var wg sync.WaitGroup - for i := 0; i < 10; i++ { - // Add them asynchronously to test the - // concurrent guarantees of our APIs. 
- wg.Add(1) - go func(i int) { - defer wg.Done() - have, want := peerSet.Get(peer.ID()), peer - assert.Equal(t, have, want, "%d: have %v, want %v", i, have, want) - }(i) - } - wg.Wait() -} diff --git a/tm2/pkg/p2p/peer_test.go b/tm2/pkg/p2p/peer_test.go index 28217c4486e..a74ea9e96a4 100644 --- a/tm2/pkg/p2p/peer_test.go +++ b/tm2/pkg/p2p/peer_test.go @@ -1,233 +1,630 @@ package p2p import ( + "errors" "fmt" - golog "log" + "io" + "log/slog" "net" "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/gnolang/gno/tm2/pkg/crypto" - "github.com/gnolang/gno/tm2/pkg/crypto/ed25519" - "github.com/gnolang/gno/tm2/pkg/errors" - "github.com/gnolang/gno/tm2/pkg/log" + "github.com/gnolang/gno/tm2/pkg/cmap" "github.com/gnolang/gno/tm2/pkg/p2p/config" "github.com/gnolang/gno/tm2/pkg/p2p/conn" + "github.com/gnolang/gno/tm2/pkg/p2p/mock" + "github.com/gnolang/gno/tm2/pkg/p2p/types" + "github.com/gnolang/gno/tm2/pkg/service" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestPeerBasic(t *testing.T) { +func TestPeer_Properties(t *testing.T) { t.Parallel() - assert, require := assert.New(t), require.New(t) + t.Run("connection info", func(t *testing.T) { + t.Parallel() - // simulate remote peer - rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} - rp.Start() - defer rp.Stop() + t.Run("remote IP", func(t *testing.T) { + t.Parallel() - p, err := createOutboundPeerAndPerformHandshake(t, rp.Addr(), cfg, conn.DefaultMConnConfig()) - require.Nil(err) + var ( + info = &ConnInfo{ + RemoteIP: net.IP{127, 0, 0, 1}, + } - err = p.Start() - require.Nil(err) - defer p.Stop() + p = &peer{ + connInfo: info, + } + ) - assert.True(p.IsRunning()) - assert.True(p.IsOutbound()) - assert.False(p.IsPersistent()) - p.persistent = true - assert.True(p.IsPersistent()) - assert.Equal(rp.Addr().DialString(), p.RemoteAddr().String()) - assert.Equal(rp.ID(), p.ID()) -} + assert.Equal(t, info.RemoteIP, 
p.RemoteIP()) + }) -func TestPeerSend(t *testing.T) { - t.Parallel() + t.Run("remote address", func(t *testing.T) { + t.Parallel() - assert, require := assert.New(t), require.New(t) + tcpAddr, err := net.ResolveTCPAddr("tcp", "localhost:8080") + require.NoError(t, err) - config := cfg + var ( + info = &ConnInfo{ + Conn: &mock.Conn{ + RemoteAddrFn: func() net.Addr { + return tcpAddr + }, + }, + } - // simulate remote peer - rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: config} - rp.Start() - defer rp.Stop() + p = &peer{ + connInfo: info, + } + ) - p, err := createOutboundPeerAndPerformHandshake(t, rp.Addr(), config, conn.DefaultMConnConfig()) - require.Nil(err) + assert.Equal(t, tcpAddr.String(), p.RemoteAddr().String()) + }) - err = p.Start() - require.Nil(err) + t.Run("socket address", func(t *testing.T) { + t.Parallel() - defer p.Stop() + tcpAddr, err := net.ResolveTCPAddr("tcp", "localhost:8080") + require.NoError(t, err) - assert.True(p.CanSend(testCh)) - assert.True(p.Send(testCh, []byte("Asylum"))) -} + netAddr, err := types.NewNetAddress(types.GenerateNodeKey().ID(), tcpAddr) + require.NoError(t, err) -func createOutboundPeerAndPerformHandshake( - t *testing.T, - addr *NetAddress, - config *config.P2PConfig, - mConfig conn.MConnConfig, -) (*peer, error) { - t.Helper() - - chDescs := []*conn.ChannelDescriptor{ - {ID: testCh, Priority: 1}, - } - reactorsByCh := map[byte]Reactor{testCh: NewTestReactor(chDescs, true)} - pk := ed25519.GenPrivKey() - pc, err := testOutboundPeerConn(addr, config, false, pk) - if err != nil { - return nil, err - } - timeout := 1 * time.Second - ourNodeInfo := testNodeInfo(addr.ID, "host_peer") - peerNodeInfo, err := handshake(pc.conn, timeout, ourNodeInfo) - if err != nil { - return nil, err - } - - p := newPeer(pc, mConfig, peerNodeInfo, reactorsByCh, chDescs, func(p Peer, r interface{}) {}) - p.SetLogger(log.NewTestingLogger(t).With("peer", addr)) - return p, nil -} + var ( + info = &ConnInfo{ + SocketAddr: netAddr, + } 
+ + p = &peer{ + connInfo: info, + } + ) + + assert.Equal(t, netAddr.String(), p.SocketAddr().String()) + }) + + t.Run("set logger", func(t *testing.T) { + t.Parallel() + + var ( + l = slog.New(slog.NewTextHandler(io.Discard, nil)) + + p = &peer{ + mConn: &mock.MConn{}, + } + ) + + p.SetLogger(l) + + assert.Equal(t, l, p.Logger) + }) + + t.Run("peer start", func(t *testing.T) { + t.Parallel() + + var ( + expectedErr = errors.New("some error") + + mConn = &mock.MConn{ + StartFn: func() error { + return expectedErr + }, + } + + p = &peer{ + mConn: mConn, + } + ) + + assert.ErrorIs(t, p.OnStart(), expectedErr) + }) + + t.Run("peer stop", func(t *testing.T) { + t.Parallel() + + var ( + stopCalled = false + expectedErr = errors.New("some error") + + mConn = &mock.MConn{ + StopFn: func() error { + stopCalled = true + + return expectedErr + }, + } + + p = &peer{ + mConn: mConn, + } + ) + + p.BaseService = *service.NewBaseService(nil, "Peer", p) + + p.OnStop() + + assert.True(t, stopCalled) + }) + + t.Run("flush stop", func(t *testing.T) { + t.Parallel() + + var ( + stopCalled = false + + mConn = &mock.MConn{ + FlushFn: func() { + stopCalled = true + }, + } + + p = &peer{ + mConn: mConn, + } + ) + + p.BaseService = *service.NewBaseService(nil, "Peer", p) + + p.FlushStop() + + assert.True(t, stopCalled) + }) -func testDial(addr *NetAddress, cfg *config.P2PConfig) (net.Conn, error) { - if cfg.TestDialFail { - return nil, fmt.Errorf("dial err (peerConfig.DialFail == true)") - } + t.Run("node info fetch", func(t *testing.T) { + t.Parallel() - conn, err := addr.DialTimeout(cfg.DialTimeout) - if err != nil { - return nil, err - } - return conn, nil + var ( + info = types.NodeInfo{ + Network: "gnoland", + } + + p = &peer{ + nodeInfo: info, + } + ) + + assert.Equal(t, info, p.NodeInfo()) + }) + + t.Run("node status fetch", func(t *testing.T) { + t.Parallel() + + var ( + status = conn.ConnectionStatus{ + Duration: 5 * time.Second, + } + + mConn = &mock.MConn{ + StatusFn: func() 
conn.ConnectionStatus { + return status + }, + } + + p = &peer{ + mConn: mConn, + } + ) + + assert.Equal(t, status, p.Status()) + }) + + t.Run("string representation", func(t *testing.T) { + t.Parallel() + + testTable := []struct { + name string + outbound bool + }{ + { + "outbound", + true, + }, + { + "inbound", + false, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + var ( + id = types.GenerateNodeKey().ID() + mConnStr = "description" + + p = &peer{ + mConn: &mock.MConn{ + StringFn: func() string { + return mConnStr + }, + }, + nodeInfo: types.NodeInfo{ + PeerID: id, + }, + connInfo: &ConnInfo{ + Outbound: testCase.outbound, + }, + } + + direction = "in" + ) + + if testCase.outbound { + direction = "out" + } + + assert.Contains( + t, + p.String(), + fmt.Sprintf( + "Peer{%s %s %s}", + mConnStr, + id, + direction, + ), + ) + }) + } + }) + + t.Run("outbound information", func(t *testing.T) { + t.Parallel() + + p := &peer{ + connInfo: &ConnInfo{ + Outbound: true, + }, + } + + assert.True( + t, + p.IsOutbound(), + ) + }) + + t.Run("persistent information", func(t *testing.T) { + t.Parallel() + + p := &peer{ + connInfo: &ConnInfo{ + Persistent: true, + }, + } + + assert.True(t, p.IsPersistent()) + }) + + t.Run("initial conn close", func(t *testing.T) { + t.Parallel() + + var ( + closeErr = errors.New("close error") + + mockConn = &mock.Conn{ + CloseFn: func() error { + return closeErr + }, + } + + p = &peer{ + connInfo: &ConnInfo{ + Conn: mockConn, + }, + } + ) + + assert.ErrorIs(t, p.CloseConn(), closeErr) + }) + }) } -func testOutboundPeerConn( - addr *NetAddress, - config *config.P2PConfig, - persistent bool, - ourNodePrivKey crypto.PrivKey, -) (peerConn, error) { - var pc peerConn - conn, err := testDial(addr, config) - if err != nil { - return pc, errors.Wrap(err, "Error creating peer") - } - - pc, err = testPeerConn(conn, config, true, persistent, ourNodePrivKey, addr) - if err != nil { - if cerr := 
conn.Close(); cerr != nil { - return pc, errors.Wrap(err, cerr.Error()) - } - return pc, err - } +func TestPeer_GetSet(t *testing.T) { + t.Parallel() - // ensure dialed ID matches connection ID - if addr.ID != pc.ID() { - if cerr := conn.Close(); cerr != nil { - return pc, errors.Wrap(err, cerr.Error()) + var ( + key = "key" + data = []byte("random") + + p = &peer{ + data: cmap.NewCMap(), } - return pc, SwitchAuthenticationFailureError{addr, pc.ID()} - } + ) - return pc, nil -} + assert.Nil(t, p.Get(key)) -type remotePeer struct { - PrivKey crypto.PrivKey - Config *config.P2PConfig - addr *NetAddress - channels []byte - listenAddr string - listener net.Listener -} + // Set the key + p.Set(key, data) -func (rp *remotePeer) Addr() *NetAddress { - return rp.addr + assert.Equal(t, data, p.Get(key)) } -func (rp *remotePeer) ID() ID { - return rp.PrivKey.PubKey().Address().ID() -} +func TestPeer_Send(t *testing.T) { + t.Parallel() -func (rp *remotePeer) Start() { - if rp.listenAddr == "" { - rp.listenAddr = "127.0.0.1:0" - } - - l, e := net.Listen("tcp", rp.listenAddr) // any available address - if e != nil { - golog.Fatalf("net.Listen tcp :0: %+v", e) - } - rp.listener = l - rp.addr = NewNetAddress(rp.PrivKey.PubKey().Address().ID(), l.Addr()) - if rp.channels == nil { - rp.channels = []byte{testCh} - } - go rp.accept() -} + t.Run("peer not running", func(t *testing.T) { + t.Parallel() -func (rp *remotePeer) Stop() { - rp.listener.Close() -} + var ( + chID = byte(10) + data = []byte("random") + + capturedSendID byte + capturedSendData []byte + + mockConn = &mock.MConn{ + SendFn: func(c byte, d []byte) bool { + capturedSendID = c + capturedSendData = d + + return true + }, + } + + p = &peer{ + nodeInfo: types.NodeInfo{ + Channels: []byte{ + chID, + }, + }, + mConn: mockConn, + } + ) + + p.BaseService = *service.NewBaseService(nil, "Peer", p) + + // Make sure the send fails + require.False(t, p.Send(chID, data)) + + assert.Empty(t, capturedSendID) + assert.Nil(t, 
capturedSendData) + }) + + t.Run("peer doesn't have channel", func(t *testing.T) { + t.Parallel() + + var ( + chID = byte(10) + data = []byte("random") + + capturedSendID byte + capturedSendData []byte + + mockConn = &mock.MConn{ + SendFn: func(c byte, d []byte) bool { + capturedSendID = c + capturedSendData = d + + return true + }, + } + + p = &peer{ + nodeInfo: types.NodeInfo{ + Channels: []byte{}, + }, + mConn: mockConn, + } + ) + + p.BaseService = *service.NewBaseService(nil, "Peer", p) + + // Start the peer "multiplexing" + require.NoError(t, p.Start()) + t.Cleanup(func() { + require.NoError(t, p.Stop()) + }) + + // Make sure the send fails + require.False(t, p.Send(chID, data)) + + assert.Empty(t, capturedSendID) + assert.Nil(t, capturedSendData) + }) -func (rp *remotePeer) Dial(addr *NetAddress) (net.Conn, error) { - conn, err := addr.DialTimeout(1 * time.Second) - if err != nil { - return nil, err - } - pc, err := testInboundPeerConn(conn, rp.Config, rp.PrivKey) - if err != nil { - return nil, err - } - _, err = handshake(pc.conn, time.Second, rp.nodeInfo()) - if err != nil { - return nil, err - } - return conn, err + t.Run("valid peer data send", func(t *testing.T) { + t.Parallel() + + var ( + chID = byte(10) + data = []byte("random") + + capturedSendID byte + capturedSendData []byte + + mockConn = &mock.MConn{ + SendFn: func(c byte, d []byte) bool { + capturedSendID = c + capturedSendData = d + + return true + }, + } + + p = &peer{ + nodeInfo: types.NodeInfo{ + Channels: []byte{ + chID, + }, + }, + mConn: mockConn, + } + ) + + p.BaseService = *service.NewBaseService(nil, "Peer", p) + + // Start the peer "multiplexing" + require.NoError(t, p.Start()) + t.Cleanup(func() { + require.NoError(t, p.Stop()) + }) + + // Make sure the send is valid + require.True(t, p.Send(chID, data)) + + assert.Equal(t, chID, capturedSendID) + assert.Equal(t, data, capturedSendData) + }) } -func (rp *remotePeer) accept() { - conns := []net.Conn{} +func TestPeer_TrySend(t 
*testing.T) { + t.Parallel() + + t.Run("peer not running", func(t *testing.T) { + t.Parallel() + + var ( + chID = byte(10) + data = []byte("random") - for { - conn, err := rp.listener.Accept() - if err != nil { - golog.Printf("Failed to accept conn: %+v", err) - for _, conn := range conns { - _ = conn.Close() + capturedSendID byte + capturedSendData []byte + + mockConn = &mock.MConn{ + TrySendFn: func(c byte, d []byte) bool { + capturedSendID = c + capturedSendData = d + + return true + }, } - return - } - pc, err := testInboundPeerConn(conn, rp.Config, rp.PrivKey) - if err != nil { - golog.Fatalf("Failed to create a peer: %+v", err) - } + p = &peer{ + nodeInfo: types.NodeInfo{ + Channels: []byte{ + chID, + }, + }, + mConn: mockConn, + } + ) - _, err = handshake(pc.conn, time.Second, rp.nodeInfo()) - if err != nil { - golog.Fatalf("Failed to perform handshake: %+v", err) - } + p.BaseService = *service.NewBaseService(nil, "Peer", p) + + // Make sure the send fails + require.False(t, p.TrySend(chID, data)) + + assert.Empty(t, capturedSendID) + assert.Nil(t, capturedSendData) + }) + + t.Run("peer doesn't have channel", func(t *testing.T) { + t.Parallel() + + var ( + chID = byte(10) + data = []byte("random") + + capturedSendID byte + capturedSendData []byte + + mockConn = &mock.MConn{ + TrySendFn: func(c byte, d []byte) bool { + capturedSendID = c + capturedSendData = d + + return true + }, + } + + p = &peer{ + nodeInfo: types.NodeInfo{ + Channels: []byte{}, + }, + mConn: mockConn, + } + ) + + p.BaseService = *service.NewBaseService(nil, "Peer", p) + + // Start the peer "multiplexing" + require.NoError(t, p.Start()) + t.Cleanup(func() { + require.NoError(t, p.Stop()) + }) + + // Make sure the send fails + require.False(t, p.TrySend(chID, data)) + + assert.Empty(t, capturedSendID) + assert.Nil(t, capturedSendData) + }) + + t.Run("valid peer data send", func(t *testing.T) { + t.Parallel() + + var ( + chID = byte(10) + data = []byte("random") - conns = append(conns, conn) 
- } + capturedSendID byte + capturedSendData []byte + + mockConn = &mock.MConn{ + TrySendFn: func(c byte, d []byte) bool { + capturedSendID = c + capturedSendData = d + + return true + }, + } + + p = &peer{ + nodeInfo: types.NodeInfo{ + Channels: []byte{ + chID, + }, + }, + mConn: mockConn, + } + ) + + p.BaseService = *service.NewBaseService(nil, "Peer", p) + + // Start the peer "multiplexing" + require.NoError(t, p.Start()) + t.Cleanup(func() { + require.NoError(t, p.Stop()) + }) + + // Make sure the send is valid + require.True(t, p.TrySend(chID, data)) + + assert.Equal(t, chID, capturedSendID) + assert.Equal(t, data, capturedSendData) + }) } -func (rp *remotePeer) nodeInfo() NodeInfo { - return NodeInfo{ - VersionSet: testVersionSet(), - NetAddress: rp.Addr(), - Network: "testing", - Version: "1.2.3-rc0-deadbeef", - Channels: rp.channels, - Moniker: "remote_peer", - } +func TestPeer_NewPeer(t *testing.T) { + t.Parallel() + + tcpAddr, err := net.ResolveTCPAddr("tcp", "localhost:8080") + require.NoError(t, err) + + netAddr, err := types.NewNetAddress(types.GenerateNodeKey().ID(), tcpAddr) + require.NoError(t, err) + + var ( + connInfo = &ConnInfo{ + Outbound: false, + Persistent: true, + Conn: &mock.Conn{}, + RemoteIP: tcpAddr.IP, + SocketAddr: netAddr, + } + + mConfig = &ConnConfig{ + MConfig: conn.MConfigFromP2P(config.DefaultP2PConfig()), + ReactorsByCh: make(map[byte]Reactor), + ChDescs: make([]*conn.ChannelDescriptor, 0), + OnPeerError: nil, + } + ) + + assert.NotPanics(t, func() { + _ = newPeer(connInfo, types.NodeInfo{}, mConfig) + }) } diff --git a/tm2/pkg/p2p/set.go b/tm2/pkg/p2p/set.go new file mode 100644 index 00000000000..b347c480b7e --- /dev/null +++ b/tm2/pkg/p2p/set.go @@ -0,0 +1,121 @@ +package p2p + +import ( + "sync" + + "github.com/gnolang/gno/tm2/pkg/p2p/types" +) + +type set struct { + mux sync.RWMutex + + peers map[types.ID]PeerConn + outbound uint64 + inbound uint64 +} + +// newSet creates an empty peer set +func newSet() *set { + return 
&set{ + peers: make(map[types.ID]PeerConn), + outbound: 0, + inbound: 0, + } +} + +// Add adds the peer to the set +func (s *set) Add(peer PeerConn) { + s.mux.Lock() + defer s.mux.Unlock() + + s.peers[peer.ID()] = peer + + if peer.IsOutbound() { + s.outbound += 1 + + return + } + + s.inbound += 1 +} + +// Has returns true if the set contains the peer referred to by this +// peerKey, otherwise false. +func (s *set) Has(peerKey types.ID) bool { + s.mux.RLock() + defer s.mux.RUnlock() + + _, exists := s.peers[peerKey] + + return exists +} + +// Get looks up a peer by the peer ID. Returns nil if peer is not +// found. +func (s *set) Get(key types.ID) PeerConn { + s.mux.RLock() + defer s.mux.RUnlock() + + p, found := s.peers[key] + if !found { + // TODO change this to an error, it doesn't make + // sense to propagate an implementation detail like this + return nil + } + + return p.(PeerConn) +} + +// Remove discards peer by its Key, if the peer was previously memoized. +// Returns true if the peer was removed, and false if it was not found. +// in the set. 
+func (s *set) Remove(key types.ID) bool { + s.mux.Lock() + defer s.mux.Unlock() + + p, found := s.peers[key] + if !found { + return false + } + + delete(s.peers, key) + + if p.(PeerConn).IsOutbound() { + s.outbound -= 1 + + return true + } + + s.inbound -= 1 + + return true +} + +// NumInbound returns the number of inbound peers +func (s *set) NumInbound() uint64 { + s.mux.RLock() + defer s.mux.RUnlock() + + return s.inbound +} + +// NumOutbound returns the number of outbound peers +func (s *set) NumOutbound() uint64 { + s.mux.RLock() + defer s.mux.RUnlock() + + return s.outbound +} + +// List returns the list of peers +func (s *set) List() []PeerConn { + s.mux.RLock() + defer s.mux.RUnlock() + + peers := make([]PeerConn, 0) + for _, p := range s.peers { + peers = append(peers, p.(PeerConn)) + } + + return peers +} diff --git a/tm2/pkg/p2p/set_test.go b/tm2/pkg/p2p/set_test.go new file mode 100644 index 00000000000..ced35538a9b --- /dev/null +++ b/tm2/pkg/p2p/set_test.go @@ -0,0 +1,146 @@ +package p2p + +import ( + "sort" + "testing" + + "github.com/gnolang/gno/tm2/pkg/p2p/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSet_Add(t *testing.T) { + t.Parallel() + + var ( + numPeers = 100 + peers = mock.GeneratePeers(t, numPeers) + + s = newSet() + ) + + for _, peer := range peers { + // Add the peer + s.Add(peer) + + // Make sure the peer is present + assert.True(t, s.Has(peer.ID())) + } + + assert.EqualValues(t, numPeers, s.NumInbound()+s.NumOutbound()) +} + +func TestSet_Remove(t *testing.T) { + t.Parallel() + + var ( + numPeers = 100 + peers = mock.GeneratePeers(t, numPeers) + + s = newSet() + ) + + // Add the initial peers + for _, peer := range peers { + // Add the peer + s.Add(peer) + + // Make sure the peer is present + require.True(t, s.Has(peer.ID())) + } + + require.EqualValues(t, numPeers, s.NumInbound()+s.NumOutbound()) + + // Remove the peers + // Add the initial peers + for _, peer := range peers { + 
// Add the peer + s.Remove(peer.ID()) + + // Make sure the peer is present + assert.False(t, s.Has(peer.ID())) + } +} + +func TestSet_Get(t *testing.T) { + t.Parallel() + + t.Run("existing peer", func(t *testing.T) { + t.Parallel() + + var ( + peers = mock.GeneratePeers(t, 100) + s = newSet() + ) + + for _, peer := range peers { + id := peer.ID() + s.Add(peer) + + assert.True(t, s.Get(id).ID() == id) + } + }) + + t.Run("missing peer", func(t *testing.T) { + t.Parallel() + + var ( + peers = mock.GeneratePeers(t, 100) + s = newSet() + ) + + for _, peer := range peers { + s.Add(peer) + } + + p := s.Get("random ID") + assert.Nil(t, p) + }) +} + +func TestSet_List(t *testing.T) { + t.Parallel() + + t.Run("empty peer set", func(t *testing.T) { + t.Parallel() + + // Empty set + s := newSet() + + // Linearize the set + assert.Len(t, s.List(), 0) + }) + + t.Run("existing peer set", func(t *testing.T) { + t.Parallel() + + var ( + peers = mock.GeneratePeers(t, 100) + s = newSet() + ) + + for _, peer := range peers { + s.Add(peer) + } + + // Linearize the set + listedPeers := s.List() + + require.Len(t, listedPeers, len(peers)) + + // Make sure the lists are sorted + // for easier comparison + sort.Slice(listedPeers, func(i, j int) bool { + return listedPeers[i].ID() < listedPeers[j].ID() + }) + + sort.Slice(peers, func(i, j int) bool { + return peers[i].ID() < peers[j].ID() + }) + + // Compare the lists + for index, listedPeer := range listedPeers { + assert.Equal(t, listedPeer.ID(), peers[index].ID()) + } + }) +} diff --git a/tm2/pkg/p2p/switch.go b/tm2/pkg/p2p/switch.go index 317f34e496b..5c1c37f7729 100644 --- a/tm2/pkg/p2p/switch.go +++ b/tm2/pkg/p2p/switch.go @@ -2,226 +2,164 @@ package p2p import ( "context" + "crypto/rand" "fmt" "math" + "math/big" "sync" "time" - "github.com/gnolang/gno/tm2/pkg/cmap" - "github.com/gnolang/gno/tm2/pkg/errors" "github.com/gnolang/gno/tm2/pkg/p2p/config" "github.com/gnolang/gno/tm2/pkg/p2p/conn" - "github.com/gnolang/gno/tm2/pkg/random" 
+ "github.com/gnolang/gno/tm2/pkg/p2p/dial" + "github.com/gnolang/gno/tm2/pkg/p2p/events" + "github.com/gnolang/gno/tm2/pkg/p2p/types" "github.com/gnolang/gno/tm2/pkg/service" "github.com/gnolang/gno/tm2/pkg/telemetry" "github.com/gnolang/gno/tm2/pkg/telemetry/metrics" ) -const ( - // wait a random amount of time from this interval - // before dialing peers or reconnecting to help prevent DoS - dialRandomizerIntervalMilliseconds = 3000 +// defaultDialTimeout is the default wait time for a dial to succeed +var defaultDialTimeout = 3 * time.Second - // repeatedly try to reconnect for a few minutes - // ie. 5 * 20 = 100s - reconnectAttempts = 20 - reconnectInterval = 5 * time.Second +type reactorPeerBehavior struct { + chDescs []*conn.ChannelDescriptor + reactorsByCh map[byte]Reactor - // then move into exponential backoff mode for ~1day - // ie. 3**10 = 16hrs - reconnectBackOffAttempts = 10 - reconnectBackOffBaseSeconds = 3 -) + handlePeerErrFn func(PeerConn, error) + isPersistentPeerFn func(types.ID) bool + isPrivatePeerFn func(types.ID) bool +} + +func (r *reactorPeerBehavior) ReactorChDescriptors() []*conn.ChannelDescriptor { + return r.chDescs +} + +func (r *reactorPeerBehavior) Reactors() map[byte]Reactor { + return r.reactorsByCh +} -// MConnConfig returns an MConnConfig with fields updated -// from the P2PConfig. -func MConnConfig(cfg *config.P2PConfig) conn.MConnConfig { - mConfig := conn.DefaultMConnConfig() - mConfig.FlushThrottle = cfg.FlushThrottleTimeout - mConfig.SendRate = cfg.SendRate - mConfig.RecvRate = cfg.RecvRate - mConfig.MaxPacketMsgPayloadSize = cfg.MaxPacketMsgPayloadSize - return mConfig +func (r *reactorPeerBehavior) HandlePeerError(p PeerConn, err error) { + r.handlePeerErrFn(p, err) } -// PeerFilterFunc to be implemented by filter hooks after a new Peer has been -// fully setup. 
-type PeerFilterFunc func(IPeerSet, Peer) error +func (r *reactorPeerBehavior) IsPersistentPeer(id types.ID) bool { + return r.isPersistentPeerFn(id) +} -// ----------------------------------------------------------------------------- +func (r *reactorPeerBehavior) IsPrivatePeer(id types.ID) bool { + return r.isPrivatePeerFn(id) +} -// Switch handles peer connections and exposes an API to receive incoming messages +// MultiplexSwitch handles peer connections and exposes an API to receive incoming messages // on `Reactors`. Each `Reactor` is responsible for handling incoming messages of one // or more `Channels`. So while sending outgoing messages is typically performed on the peer, // incoming messages are received on the reactor. -type Switch struct { +type MultiplexSwitch struct { service.BaseService - config *config.P2PConfig - reactors map[string]Reactor - chDescs []*conn.ChannelDescriptor - reactorsByCh map[byte]Reactor - peers *PeerSet - dialing *cmap.CMap - reconnecting *cmap.CMap - nodeInfo NodeInfo // our node info - nodeKey *NodeKey // our node privkey - // peers addresses with whom we'll maintain constant connection - persistentPeersAddrs []*NetAddress + ctx context.Context + cancelFn context.CancelFunc - transport Transport + maxInboundPeers uint64 + maxOutboundPeers uint64 - filterTimeout time.Duration - peerFilters []PeerFilterFunc + reactors map[string]Reactor + peerBehavior *reactorPeerBehavior - rng *random.Rand // seed for randomizing dial times and orders -} + peers PeerSet // currently active peer set (live connections) + persistentPeers sync.Map // ID -> *NetAddress; peers whose connections are constant + privatePeers sync.Map // ID -> nothing; lookup table of peers who are not shared + transport Transport -// NetAddress returns the address the switch is listening on. 
-func (sw *Switch) NetAddress() *NetAddress { - addr := sw.transport.NetAddress() - return &addr + dialQueue *dial.Queue + events *events.Events } -// SwitchOption sets an optional parameter on the Switch. -type SwitchOption func(*Switch) - -// NewSwitch creates a new Switch with the given config. -func NewSwitch( - cfg *config.P2PConfig, +// NewMultiplexSwitch creates a new MultiplexSwitch with the given config. +func NewMultiplexSwitch( transport Transport, - options ...SwitchOption, -) *Switch { - sw := &Switch{ - config: cfg, - reactors: make(map[string]Reactor), - chDescs: make([]*conn.ChannelDescriptor, 0), - reactorsByCh: make(map[byte]Reactor), - peers: NewPeerSet(), - dialing: cmap.NewCMap(), - reconnecting: cmap.NewCMap(), - transport: transport, - filterTimeout: defaultFilterTimeout, - persistentPeersAddrs: make([]*NetAddress, 0), - } - - // Ensure we have a completely undeterministic PRNG. - sw.rng = random.NewRand() - - sw.BaseService = *service.NewBaseService(nil, "P2P Switch", sw) - - for _, option := range options { - option(sw) - } - - return sw -} + opts ...SwitchOption, +) *MultiplexSwitch { + defaultCfg := config.DefaultP2PConfig() -// SwitchFilterTimeout sets the timeout used for peer filters. -func SwitchFilterTimeout(timeout time.Duration) SwitchOption { - return func(sw *Switch) { sw.filterTimeout = timeout } -} - -// SwitchPeerFilters sets the filters for rejection of new peers. -func SwitchPeerFilters(filters ...PeerFilterFunc) SwitchOption { - return func(sw *Switch) { sw.peerFilters = filters } -} - -// --------------------------------------------------------------------- -// Switch setup - -// AddReactor adds the given reactor to the switch. -// NOTE: Not goroutine safe. -func (sw *Switch) AddReactor(name string, reactor Reactor) Reactor { - for _, chDesc := range reactor.GetChannels() { - chID := chDesc.ID - // No two reactors can share the same channel. 
- if sw.reactorsByCh[chID] != nil { - panic(fmt.Sprintf("Channel %X has multiple reactors %v & %v", chID, sw.reactorsByCh[chID], reactor)) - } - sw.chDescs = append(sw.chDescs, chDesc) - sw.reactorsByCh[chID] = reactor + sw := &MultiplexSwitch{ + reactors: make(map[string]Reactor), + peers: newSet(), + transport: transport, + dialQueue: dial.NewQueue(), + events: events.New(), + maxInboundPeers: defaultCfg.MaxNumInboundPeers, + maxOutboundPeers: defaultCfg.MaxNumOutboundPeers, } - sw.reactors[name] = reactor - reactor.SetSwitch(sw) - return reactor -} -// RemoveReactor removes the given Reactor from the Switch. -// NOTE: Not goroutine safe. -func (sw *Switch) RemoveReactor(name string, reactor Reactor) { - for _, chDesc := range reactor.GetChannels() { - // remove channel description - for i := 0; i < len(sw.chDescs); i++ { - if chDesc.ID == sw.chDescs[i].ID { - sw.chDescs = append(sw.chDescs[:i], sw.chDescs[i+1:]...) - break - } - } - delete(sw.reactorsByCh, chDesc.ID) + // Set up the peer dial behavior + sw.peerBehavior = &reactorPeerBehavior{ + chDescs: make([]*conn.ChannelDescriptor, 0), + reactorsByCh: make(map[byte]Reactor), + handlePeerErrFn: sw.StopPeerForError, + isPersistentPeerFn: func(id types.ID) bool { + return sw.isPersistentPeer(id) + }, + isPrivatePeerFn: func(id types.ID) bool { + return sw.isPrivatePeer(id) + }, } - delete(sw.reactors, name) - reactor.SetSwitch(nil) -} -// Reactors returns a map of reactors registered on the switch. -// NOTE: Not goroutine safe. -func (sw *Switch) Reactors() map[string]Reactor { - return sw.reactors -} + sw.BaseService = *service.NewBaseService(nil, "P2P MultiplexSwitch", sw) -// Reactor returns the reactor with the given name. -// NOTE: Not goroutine safe. 
-func (sw *Switch) Reactor(name string) Reactor { - return sw.reactors[name] -} + // Set up the context + sw.ctx, sw.cancelFn = context.WithCancel(context.Background()) -// SetNodeInfo sets the switch's NodeInfo for checking compatibility and handshaking with other nodes. -// NOTE: Not goroutine safe. -func (sw *Switch) SetNodeInfo(nodeInfo NodeInfo) { - sw.nodeInfo = nodeInfo -} + // Apply the options + for _, opt := range opts { + opt(sw) + } -// NodeInfo returns the switch's NodeInfo. -// NOTE: Not goroutine safe. -func (sw *Switch) NodeInfo() NodeInfo { - return sw.nodeInfo + return sw } -// SetNodeKey sets the switch's private key for authenticated encryption. -// NOTE: Not goroutine safe. -func (sw *Switch) SetNodeKey(nodeKey *NodeKey) { - sw.nodeKey = nodeKey +// Subscribe registers to live events happening on the p2p Switch. +// Returns the notification channel, along with an unsubscribe method +func (sw *MultiplexSwitch) Subscribe(filterFn events.EventFilter) (<-chan events.Event, func()) { + return sw.events.Subscribe(filterFn) } // --------------------------------------------------------------------- // Service start/stop // OnStart implements BaseService. It starts all the reactors and peers. -func (sw *Switch) OnStart() error { +func (sw *MultiplexSwitch) OnStart() error { // Start reactors for _, reactor := range sw.reactors { - err := reactor.Start() - if err != nil { - return errors.Wrapf(err, "failed to start %v", reactor) + if err := reactor.Start(); err != nil { + return fmt.Errorf("unable to start reactor %w", err) } } - // Start accepting Peers. - go sw.acceptRoutine() + // Run the peer accept routine. + // The accept routine asynchronously accepts + // and processes incoming peer connections + go sw.runAcceptLoop(sw.ctx) + + // Run the dial routine. + // The dial routine parses items in the dial queue + // and initiates outbound peer connections + go sw.runDialLoop(sw.ctx) + + // Run the redial routine. 
+ // The redial routine monitors for important + // peer disconnects, and attempts to reconnect + // to them + go sw.runRedialLoop(sw.ctx) return nil } // OnStop implements BaseService. It stops all peers and reactors. -func (sw *Switch) OnStop() { - // Stop transport - if t, ok := sw.transport.(TransportLifecycle); ok { - err := t.Close() - if err != nil { - sw.Logger.Error("Error stopping transport on stop: ", "error", err) - } - } +func (sw *MultiplexSwitch) OnStop() { + // Close all hanging threads + sw.cancelFn() // Stop peers for _, p := range sw.peers.List() { @@ -229,465 +167,504 @@ func (sw *Switch) OnStop() { } // Stop reactors - sw.Logger.Debug("Switch: Stopping reactors") for _, reactor := range sw.reactors { - reactor.Stop() - } -} - -// --------------------------------------------------------------------- -// Peers - -// Broadcast runs a go routine for each attempted send, which will block trying -// to send for defaultSendTimeoutSeconds. Returns a channel which receives -// success values for each attempted send (false if times out). Channel will be -// closed once msg bytes are sent to all peers (or time out). -// -// NOTE: Broadcast uses goroutines, so order of broadcast may not be preserved. -func (sw *Switch) Broadcast(chID byte, msgBytes []byte) chan bool { - startTime := time.Now() - - sw.Logger.Debug( - "Broadcast", - "channel", chID, - "value", fmt.Sprintf("%X", msgBytes), - ) - - peers := sw.peers.List() - var wg sync.WaitGroup - wg.Add(len(peers)) - successChan := make(chan bool, len(peers)) - - for _, peer := range peers { - go func(p Peer) { - defer wg.Done() - success := p.Send(chID, msgBytes) - successChan <- success - }(peer) - } - - go func() { - wg.Wait() - close(successChan) - if telemetry.MetricsEnabled() { - metrics.BroadcastTxTimer.Record(context.Background(), time.Since(startTime).Milliseconds()) - } - }() - - return successChan -} - -// NumPeers returns the count of outbound/inbound and outbound-dialing peers. 
-func (sw *Switch) NumPeers() (outbound, inbound, dialing int) { - peers := sw.peers.List() - for _, peer := range peers { - if peer.IsOutbound() { - outbound++ - } else { - inbound++ + if err := reactor.Stop(); err != nil { + sw.Logger.Error("unable to gracefully stop reactor", "err", err) } } - dialing = sw.dialing.Size() - return } -// MaxNumOutboundPeers returns a maximum number of outbound peers. -func (sw *Switch) MaxNumOutboundPeers() int { - return sw.config.MaxNumOutboundPeers +// Broadcast broadcasts the given data to the given channel, +// across the entire switch peer set, without blocking +func (sw *MultiplexSwitch) Broadcast(chID byte, data []byte) { + for _, p := range sw.peers.List() { + go func() { + // This send context is managed internally + // by the Peer's underlying connection implementation + if !p.Send(chID, data) { + sw.Logger.Error( + "unable to perform broadcast", + "chID", chID, + "peerID", p.ID(), + ) + } + }() + } } // Peers returns the set of peers that are connected to the switch. -func (sw *Switch) Peers() IPeerSet { +func (sw *MultiplexSwitch) Peers() PeerSet { return sw.peers } // StopPeerForError disconnects from a peer due to external error. -// If the peer is persistent, it will attempt to reconnect. -// TODO: make record depending on reason. 
-func (sw *Switch) StopPeerForError(peer Peer, reason interface{}) { - sw.Logger.Error("Stopping peer for error", "peer", peer, "err", reason) - sw.stopAndRemovePeer(peer, reason) - - if peer.IsPersistent() { - var addr *NetAddress - if peer.IsOutbound() { // socket address for outbound peers - addr = peer.SocketAddr() - } else { // self-reported address for inbound peers - addr = peer.NodeInfo().NetAddress - } - go sw.reconnectToPeer(addr) +// If the peer is persistent, it will attempt to reconnect +func (sw *MultiplexSwitch) StopPeerForError(peer PeerConn, err error) { + sw.Logger.Error("Stopping peer for error", "peer", peer, "err", err) + + sw.stopAndRemovePeer(peer, err) + + if !peer.IsPersistent() { + // Peer is not a persistent peer, + // no need to initiate a redial + return } -} -// StopPeerGracefully disconnects from a peer gracefully. -// TODO: handle graceful disconnects. -func (sw *Switch) StopPeerGracefully(peer Peer) { - sw.Logger.Info("Stopping peer gracefully") - sw.stopAndRemovePeer(peer, nil) + // Add the peer to the dial queue + sw.DialPeers(peer.SocketAddr()) } -func (sw *Switch) stopAndRemovePeer(peer Peer, reason interface{}) { - sw.transport.Cleanup(peer) - peer.Stop() +func (sw *MultiplexSwitch) stopAndRemovePeer(peer PeerConn, err error) { + // Remove the peer from the transport + sw.transport.Remove(peer) + + // Close the (original) peer connection + if closeErr := peer.CloseConn(); closeErr != nil { + sw.Logger.Error( + "unable to gracefully close peer connection", + "peer", peer, + "err", closeErr, + ) + } + + // Stop the peer connection multiplexing + if stopErr := peer.Stop(); stopErr != nil { + sw.Logger.Error( + "unable to gracefully stop peer", + "peer", peer, + "err", stopErr, + ) + } + // Alert the reactors of a peer removal for _, reactor := range sw.reactors { - reactor.RemovePeer(peer, reason) + reactor.RemovePeer(peer, err) } // Removing a peer should go last to avoid a situation where a peer // reconnect to our node and the 
switch calls InitPeer before // RemovePeer is finished. // https://github.com/tendermint/classic/issues/3338 - sw.peers.Remove(peer) + sw.peers.Remove(peer.ID()) + + sw.events.Notify(events.PeerDisconnectedEvent{ + Address: peer.RemoteAddr(), + PeerID: peer.ID(), + Reason: err, + }) } -// reconnectToPeer tries to reconnect to the addr, first repeatedly -// with a fixed interval, then with exponential backoff. -// If no success after all that, it stops trying. -// NOTE: this will keep trying even if the handshake or auth fails. -// TODO: be more explicit with error types so we only retry on certain failures -// - ie. if we're getting ErrDuplicatePeer we can stop -func (sw *Switch) reconnectToPeer(addr *NetAddress) { - if sw.reconnecting.Has(addr.ID.String()) { - return - } - sw.reconnecting.Set(addr.ID.String(), addr) - defer sw.reconnecting.Delete(addr.ID.String()) +// --------------------------------------------------------------------- +// Dialing - start := time.Now() - sw.Logger.Info("Reconnecting to peer", "addr", addr) - for i := 0; i < reconnectAttempts; i++ { - if !sw.IsRunning() { - return - } +func (sw *MultiplexSwitch) runDialLoop(ctx context.Context) { + for { + select { + case <-ctx.Done(): + sw.Logger.Debug("dial context canceled") - err := sw.DialPeerWithAddress(addr) - if err == nil { - return // success - } else if _, ok := err.(CurrentlyDialingOrExistingAddressError); ok { return - } + default: + // Grab a dial item + item := sw.dialQueue.Peek() + if item == nil { + // Nothing to dial + continue + } - sw.Logger.Info("Error reconnecting to peer. Trying again", "tries", i, "err", err, "addr", addr) - // sleep a set amount - sw.randomSleep(reconnectInterval) - continue - } + // Check if the dial time is right + // for the item + if time.Now().Before(item.Time) { + // Nothing to dial + continue + } - sw.Logger.Error("Failed to reconnect to peer. 
Beginning exponential backoff", - "addr", addr, "elapsed", time.Since(start)) - for i := 0; i < reconnectBackOffAttempts; i++ { - if !sw.IsRunning() { - return - } + // Pop the item from the dial queue + item = sw.dialQueue.Pop() - // sleep an exponentially increasing amount - sleepIntervalSeconds := math.Pow(reconnectBackOffBaseSeconds, float64(i)) - sw.randomSleep(time.Duration(sleepIntervalSeconds) * time.Second) + // Dial the peer + sw.Logger.Info( + "dialing peer", + "address", item.Address.String(), + ) - err := sw.DialPeerWithAddress(addr) - if err == nil { - return // success - } else if _, ok := err.(CurrentlyDialingOrExistingAddressError); ok { - return - } - sw.Logger.Info("Error reconnecting to peer. Trying again", "tries", i, "err", err, "addr", addr) - } - sw.Logger.Error("Failed to reconnect to peer. Giving up", "addr", addr, "elapsed", time.Since(start)) -} + peerAddr := item.Address -// --------------------------------------------------------------------- -// Dialing + // Check if the peer is already connected + ps := sw.Peers() + if ps.Has(peerAddr.ID) { + sw.Logger.Warn( + "ignoring dial request for existing peer", + "id", peerAddr.ID, + ) -// DialPeersAsync dials a list of peers asynchronously in random order. -// Used to dial peers from config on startup or from unsafe-RPC (trusted sources). -// It ignores NetAddressLookupError. However, if there are other errors, first -// encounter is returned. -// Nop if there are no peers. 
-func (sw *Switch) DialPeersAsync(peers []string) error { - netAddrs, errs := NewNetAddressFromStrings(peers) - // report all the errors - for _, err := range errs { - sw.Logger.Error("Error in peer's address", "err", err) - } - // return first non-NetAddressLookupError error - for _, err := range errs { - if _, ok := err.(NetAddressLookupError); ok { - continue - } - return err - } - sw.dialPeersAsync(netAddrs) - return nil -} + continue + } -func (sw *Switch) dialPeersAsync(netAddrs []*NetAddress) { - ourAddr := sw.NetAddress() + // Create a dial context + dialCtx, cancelFn := context.WithTimeout(ctx, defaultDialTimeout) + defer cancelFn() - // permute the list, dial them in random order. - perm := sw.rng.Perm(len(netAddrs)) - for i := 0; i < len(perm); i++ { - go func(i int) { - j := perm[i] - addr := netAddrs[j] + p, err := sw.transport.Dial(dialCtx, *peerAddr, sw.peerBehavior) + if err != nil { + sw.Logger.Error( + "unable to dial peer", + "peer", peerAddr, + "err", err, + ) - if addr.Same(ourAddr) { - sw.Logger.Debug("Ignore attempt to connect to ourselves", "addr", addr, "ourAddr", ourAddr) - return + continue } - sw.randomSleep(0) + // Register the peer with the switch + if err = sw.addPeer(p); err != nil { + sw.Logger.Error( + "unable to add peer", + "peer", p, + "err", err, + ) - err := sw.DialPeerWithAddress(addr) - if err != nil { - switch err.(type) { - case SwitchConnectToSelfError, SwitchDuplicatePeerIDError, CurrentlyDialingOrExistingAddressError: - sw.Logger.Debug("Error dialing peer", "err", err) - default: - sw.Logger.Error("Error dialing peer", "err", err) + sw.transport.Remove(p) + + if !p.IsRunning() { + continue + } + + if stopErr := p.Stop(); stopErr != nil { + sw.Logger.Error( + "unable to gracefully stop peer", + "peer", p, + "err", stopErr, + ) } } - }(i) + + // Log the telemetry + sw.logTelemetry() + } } } -// DialPeerWithAddress dials the given peer and runs sw.addPeer if it connects -// and authenticates successfully. 
-// If we're currently dialing this address or it belongs to an existing peer, -// CurrentlyDialingOrExistingAddressError is returned. -func (sw *Switch) DialPeerWithAddress(addr *NetAddress) error { - if sw.IsDialingOrExistingAddress(addr) { - return CurrentlyDialingOrExistingAddressError{addr.String()} +// runRedialLoop starts the persistent peer redial loop +func (sw *MultiplexSwitch) runRedialLoop(ctx context.Context) { + ticker := time.NewTicker(time.Second * 5) + defer ticker.Stop() + + type backoffItem struct { + lastDialTime time.Time + attempts int } - sw.dialing.Set(addr.ID.String(), addr) - defer sw.dialing.Delete(addr.ID.String()) + var ( + backoffMap = make(map[types.ID]*backoffItem) - return sw.addOutboundPeerWithConfig(addr, sw.config) -} + mux sync.RWMutex + ) -// sleep for interval plus some random amount of ms on [0, dialRandomizerIntervalMilliseconds] -func (sw *Switch) randomSleep(interval time.Duration) { - r := time.Duration(sw.rng.Int63n(dialRandomizerIntervalMilliseconds)) * time.Millisecond - time.Sleep(r + interval) -} + setBackoffItem := func(id types.ID, item *backoffItem) { + mux.Lock() + defer mux.Unlock() -// IsDialingOrExistingAddress returns true if switch has a peer with the given -// address or dialing it at the moment. -func (sw *Switch) IsDialingOrExistingAddress(addr *NetAddress) bool { - return sw.dialing.Has(addr.ID.String()) || - sw.peers.Has(addr.ID) || - (!sw.config.AllowDuplicateIP && sw.peers.HasIP(addr.IP)) -} + backoffMap[id] = item + } -// AddPersistentPeers allows you to set persistent peers. It ignores -// NetAddressLookupError. However, if there are other errors, first encounter is -// returned. 
-func (sw *Switch) AddPersistentPeers(addrs []string) error { - sw.Logger.Info("Adding persistent peers", "addrs", addrs) - netAddrs, errs := NewNetAddressFromStrings(addrs) - // report all the errors - for _, err := range errs { - sw.Logger.Error("Error in peer's address", "err", err) - } - // return first non-NetAddressLookupError error - for _, err := range errs { - if _, ok := err.(NetAddressLookupError); ok { - continue - } - return err + getBackoffItem := func(id types.ID) *backoffItem { + mux.RLock() + defer mux.RUnlock() + + return backoffMap[id] } - sw.persistentPeersAddrs = netAddrs - return nil -} -func (sw *Switch) isPeerPersistentFn() func(*NetAddress) bool { - return func(na *NetAddress) bool { - for _, pa := range sw.persistentPeersAddrs { - if pa.Equals(na) { + clearBackoffItem := func(id types.ID) { + mux.Lock() + defer mux.Unlock() + + delete(backoffMap, id) + } + + subCh, unsubFn := sw.Subscribe(func(event events.Event) bool { + if event.Type() != events.PeerConnected { + return false + } + + ev := event.(events.PeerConnectedEvent) + + return sw.isPersistentPeer(ev.PeerID) + }) + defer unsubFn() + + // redialFn goes through the persistent peer list + // and dials missing peers + redialFn := func() { + var ( + peers = sw.Peers() + peersToDial = make([]*types.NetAddress, 0) + ) + + sw.persistentPeers.Range(func(key, value any) bool { + var ( + id = key.(types.ID) + addr = value.(*types.NetAddress) + ) + + // Check if the peer is part of the peer set + // or is scheduled for dialing + if peers.Has(id) || sw.dialQueue.Has(addr) { return true } - } - return false - } -} -func (sw *Switch) acceptRoutine() { - for { - p, err := sw.transport.Accept(peerConfig{ - chDescs: sw.chDescs, - onPeerError: sw.StopPeerForError, - reactorsByCh: sw.reactorsByCh, - isPersistent: sw.isPeerPersistentFn(), + peersToDial = append(peersToDial, addr) + + return true }) - if err != nil { - switch err := err.(type) { - case RejectedError: - if err.IsSelf() { - // TODO: warn? 
- } - sw.Logger.Info( - "Inbound Peer rejected", - "err", err, - "numPeers", sw.peers.Size(), - ) + if len(peersToDial) == 0 { + // No persistent peers are missing + return + } - continue - case FilterTimeoutError: - sw.Logger.Error( - "Peer filter timed out", - "err", err, - ) + // Calculate the dial items + dialItems := make([]dial.Item, 0, len(peersToDial)) + for _, p := range peersToDial { + item := getBackoffItem(p.ID) + if item == nil { + dialItem := dial.Item{ + Time: time.Now(), + Address: p, + } + + dialItems = append(dialItems, dialItem) + setBackoffItem(p.ID, &backoffItem{dialItem.Time, 0}) continue - case TransportClosedError: - sw.Logger.Error( - "Stopped accept routine, as transport is closed", - "numPeers", sw.peers.Size(), - ) - default: - sw.Logger.Error( - "Accept on transport errored", - "err", err, - "numPeers", sw.peers.Size(), - ) - // We could instead have a retry loop around the acceptRoutine, - // but that would need to stop and let the node shutdown eventually. - // So might as well panic and let process managers restart the node. - // There's no point in letting the node run without the acceptRoutine, - // since it won't be able to accept new connections. - panic(fmt.Errorf("accept routine exited: %w", err)) } - break + setBackoffItem(p.ID, &backoffItem{ + lastDialTime: time.Now().Add( + calculateBackoff( + item.attempts, + time.Second, + 10*time.Minute, + ), + ), + attempts: item.attempts + 1, + }) } - // Ignore connection if we already have enough peers. - _, in, _ := sw.NumPeers() - if in >= sw.config.MaxNumInboundPeers { - sw.Logger.Info( - "Ignoring inbound connection: already have enough inbound peers", - "address", p.SocketAddr(), - "have", in, - "max", sw.config.MaxNumInboundPeers, - ) + // Add the peers to the dial queue + sw.dialItems(dialItems...) 
+ } + + // Run the initial redial loop on start, + // in case persistent peer connections are not + // active + redialFn() + + for { + select { + case <-ctx.Done(): + sw.Logger.Debug("redial crawl context canceled") + + return + case <-ticker.C: + redialFn() + case event := <-subCh: + // A persistent peer reconnected, + // clear their redial queue + ev := event.(events.PeerConnectedEvent) + + clearBackoffItem(ev.PeerID) + } + } +} + +// calculateBackoff calculates a backoff time, +// based on the number of attempts and range limits +func calculateBackoff( + attempts int, + minTimeout time.Duration, + maxTimeout time.Duration, +) time.Duration { + var ( + minTime = time.Second * 1 + maxTime = time.Second * 60 + multiplier = float64(2) // exponential + ) + + // Check the min limit + if minTimeout > 0 { + minTime = minTimeout + } + + // Check the max limit + if maxTimeout > 0 { + maxTime = maxTimeout + } + + // Sanity check the range + if minTime >= maxTime { + return maxTime + } + + // Calculate the backoff duration + var ( + base = float64(minTime) + calculated = base * math.Pow(multiplier, float64(attempts)) + ) + + // Attempt to calculate the jitter factor + n, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64)) + if err == nil { + jitterFactor := float64(n.Int64()) / float64(math.MaxInt64) // range [0, 1] + + calculated = jitterFactor*(calculated-base) + base + } - sw.transport.Cleanup(p) + // Prevent overflow for int64 (duration) cast + if calculated > float64(math.MaxInt64) { + return maxTime + } + + duration := time.Duration(calculated) + + // Clamp the duration within bounds + if duration < minTime { + return minTime + } + if duration > maxTime { + return maxTime + } + + return duration +} + +// DialPeers adds the peers to the dial queue for async dialing. 
+// To monitor dial progress, subscribe to adequate p2p MultiplexSwitch events +func (sw *MultiplexSwitch) DialPeers(peerAddrs ...*types.NetAddress) { + for _, peerAddr := range peerAddrs { + // Check if this is our address + if peerAddr.Same(sw.transport.NetAddress()) { continue } - if err := sw.addPeer(p); err != nil { - sw.transport.Cleanup(p) - if p.IsRunning() { - _ = p.Stop() - } - sw.Logger.Info( - "Ignoring inbound connection: error while adding peer", - "err", err, - "id", p.ID(), + // Ignore dial if the limit is reached + if out := sw.Peers().NumOutbound(); out >= sw.maxOutboundPeers { + sw.Logger.Warn( + "ignoring dial request: already have max outbound peers", + "have", out, + "max", sw.maxOutboundPeers, ) + + continue } + + item := dial.Item{ + Time: time.Now(), + Address: peerAddr, + } + + sw.dialQueue.Push(item) } } -// dial the peer; make secret connection; authenticate against the dialed ID; -// add the peer. -// if dialing fails, start the reconnect loop. If handshake fails, it's over. -// If peer is started successfully, reconnectLoop will start when -// StopPeerForError is called. -func (sw *Switch) addOutboundPeerWithConfig( - addr *NetAddress, - cfg *config.P2PConfig, -) error { - sw.Logger.Info("Dialing peer", "address", addr) - - // XXX(xla): Remove the leakage of test concerns in implementation. - if cfg.TestDialFail { - go sw.reconnectToPeer(addr) - return fmt.Errorf("dial err (peerConfig.DialFail == true)") - } - - p, err := sw.transport.Dial(*addr, peerConfig{ - chDescs: sw.chDescs, - onPeerError: sw.StopPeerForError, - isPersistent: sw.isPeerPersistentFn(), - reactorsByCh: sw.reactorsByCh, - }) - if err != nil { - if e, ok := err.(RejectedError); ok { - if e.IsSelf() { - // TODO: warn? 
- return err - } +// dialItems adds custom dial items for the multiplex switch +func (sw *MultiplexSwitch) dialItems(dialItems ...dial.Item) { + for _, dialItem := range dialItems { + // Check if this is our address + if dialItem.Address.Same(sw.transport.NetAddress()) { + continue } - // retry persistent peers after - // any dial error besides IsSelf() - if sw.isPeerPersistentFn()(addr) { - go sw.reconnectToPeer(addr) + // Ignore dial if the limit is reached + if out := sw.Peers().NumOutbound(); out >= sw.maxOutboundPeers { + sw.Logger.Warn( + "ignoring dial request: already have max outbound peers", + "have", out, + "max", sw.maxOutboundPeers, + ) + + continue } - return err + sw.dialQueue.Push(dialItem) } +} - if err := sw.addPeer(p); err != nil { - sw.transport.Cleanup(p) - if p.IsRunning() { - _ = p.Stop() - } - return err - } +// isPersistentPeer returns a flag indicating if a peer +// is present in the persistent peer set +func (sw *MultiplexSwitch) isPersistentPeer(id types.ID) bool { + _, persistent := sw.persistentPeers.Load(id) - return nil + return persistent } -func (sw *Switch) filterPeer(p Peer) error { - // Avoid duplicate - if sw.peers.Has(p.ID()) { - return RejectedError{id: p.ID(), isDuplicate: true} - } - - errc := make(chan error, len(sw.peerFilters)) +// isPrivatePeer returns a flag indicating if a peer +// is present in the private peer set +func (sw *MultiplexSwitch) isPrivatePeer(id types.ID) bool { + _, persistent := sw.privatePeers.Load(id) - for _, f := range sw.peerFilters { - go func(f PeerFilterFunc, p Peer, errc chan<- error) { - errc <- f(sw.peers, p) - }(f, p, errc) - } + return persistent +} - for i := 0; i < cap(errc); i++ { +// runAcceptLoop is the main powerhouse method +// for accepting incoming peer connections, filtering them, +// and persisting them +func (sw *MultiplexSwitch) runAcceptLoop(ctx context.Context) { + for { select { - case err := <-errc: + case <-ctx.Done(): + sw.Logger.Debug("switch context close received") + 
+ return + default: + p, err := sw.transport.Accept(ctx, sw.peerBehavior) if err != nil { - return RejectedError{id: p.ID(), err: err, isFiltered: true} + sw.Logger.Error( + "error encountered during peer connection accept", + "err", err, + ) + + continue + } + + // Ignore connection if we already have enough peers. + if in := sw.Peers().NumInbound(); in >= sw.maxInboundPeers { + sw.Logger.Info( + "Ignoring inbound connection: already have enough inbound peers", + "address", p.SocketAddr(), + "have", in, + "max", sw.maxInboundPeers, + ) + + sw.transport.Remove(p) + + continue + } + + // There are open peer slots, add peers + if err := sw.addPeer(p); err != nil { + sw.transport.Remove(p) + + if p.IsRunning() { + _ = p.Stop() + } + + sw.Logger.Info( + "Ignoring inbound connection: error while adding peer", + "err", err, + "id", p.ID(), + ) } - case <-time.After(sw.filterTimeout): - return FilterTimeoutError{} } } - - return nil } -// addPeer starts up the Peer and adds it to the Switch. Error is returned if +// addPeer starts up the Peer and adds it to the MultiplexSwitch. Error is returned if // the peer is filtered out or failed to start or can't be added. -func (sw *Switch) addPeer(p Peer) error { - if err := sw.filterPeer(p); err != nil { - return err - } - +func (sw *MultiplexSwitch) addPeer(p PeerConn) error { p.SetLogger(sw.Logger.With("peer", p.SocketAddr())) - // Handle the shut down case where the switch has stopped but we're - // concurrently trying to add a peer. - if !sw.IsRunning() { - // XXX should this return an error or just log and terminate? - sw.Logger.Error("Won't start a peer - switch is not running", "peer", p) - return nil - } - // Add some data to the peer, which is required by reactors. for _, reactor := range sw.reactors { p = reactor.InitPeer(p) @@ -696,19 +673,15 @@ func (sw *Switch) addPeer(p Peer) error { // Start the peer's send/recv routines. 
// Must start it before adding it to the peer set // to prevent Start and Stop from being called concurrently. - err := p.Start() - if err != nil { - // Should never happen + if err := p.Start(); err != nil { sw.Logger.Error("Error starting peer", "err", err, "peer", p) + return err } - // Add the peer to PeerSet. Do this before starting the reactors + // Add the peer to the peer set. Do this before starting the reactors // so that if Receive errors, we will find the peer and remove it. - // Add should not err since we already checked peers.Has(). - if err := sw.peers.Add(p); err != nil { - return err - } + sw.peers.Add(p) // Start all the reactor protocols on the peer. for _, reactor := range sw.reactors { @@ -717,29 +690,28 @@ func (sw *Switch) addPeer(p Peer) error { sw.Logger.Info("Added peer", "peer", p) - // Update the telemetry data - sw.logTelemetry() + sw.events.Notify(events.PeerConnectedEvent{ + Address: p.RemoteAddr(), + PeerID: p.ID(), + }) return nil } // logTelemetry logs the switch telemetry data // to global metrics funnels -func (sw *Switch) logTelemetry() { +func (sw *MultiplexSwitch) logTelemetry() { // Update the telemetry data if !telemetry.MetricsEnabled() { return } // Fetch the number of peers - outbound, inbound, dialing := sw.NumPeers() + outbound, inbound := sw.peers.NumOutbound(), sw.peers.NumInbound() // Log the outbound peer count metrics.OutboundPeers.Record(context.Background(), int64(outbound)) // Log the inbound peer count metrics.InboundPeers.Record(context.Background(), int64(inbound)) - - // Log the dialing peer count - metrics.DialingPeers.Record(context.Background(), int64(dialing)) } diff --git a/tm2/pkg/p2p/switch_option.go b/tm2/pkg/p2p/switch_option.go new file mode 100644 index 00000000000..83a6920f2cd --- /dev/null +++ b/tm2/pkg/p2p/switch_option.go @@ -0,0 +1,61 @@ +package p2p + +import ( + "github.com/gnolang/gno/tm2/pkg/p2p/types" +) + +// SwitchOption is a callback used for configuring the p2p MultiplexSwitch +type 
SwitchOption func(*MultiplexSwitch) + +// WithReactor sets the p2p switch reactors +func WithReactor(name string, reactor Reactor) SwitchOption { + return func(sw *MultiplexSwitch) { + for _, chDesc := range reactor.GetChannels() { + chID := chDesc.ID + + // No two reactors can share the same channel + if sw.peerBehavior.reactorsByCh[chID] != nil { + continue + } + + sw.peerBehavior.chDescs = append(sw.peerBehavior.chDescs, chDesc) + sw.peerBehavior.reactorsByCh[chID] = reactor + } + + sw.reactors[name] = reactor + + reactor.SetSwitch(sw) + } +} + +// WithPersistentPeers sets the p2p switch's persistent peer set +func WithPersistentPeers(peerAddrs []*types.NetAddress) SwitchOption { + return func(sw *MultiplexSwitch) { + for _, addr := range peerAddrs { + sw.persistentPeers.Store(addr.ID, addr) + } + } +} + +// WithPrivatePeers sets the p2p switch's private peer set +func WithPrivatePeers(peerIDs []types.ID) SwitchOption { + return func(sw *MultiplexSwitch) { + for _, id := range peerIDs { + sw.privatePeers.Store(id, struct{}{}) + } + } +} + +// WithMaxInboundPeers sets the p2p switch's maximum inbound peer limit +func WithMaxInboundPeers(maxInbound uint64) SwitchOption { + return func(sw *MultiplexSwitch) { + sw.maxInboundPeers = maxInbound + } +} + +// WithMaxOutboundPeers sets the p2p switch's maximum outbound peer limit +func WithMaxOutboundPeers(maxOutbound uint64) SwitchOption { + return func(sw *MultiplexSwitch) { + sw.maxOutboundPeers = maxOutbound + } +} diff --git a/tm2/pkg/p2p/switch_test.go b/tm2/pkg/p2p/switch_test.go index a7033b466fe..19a5db2efa5 100644 --- a/tm2/pkg/p2p/switch_test.go +++ b/tm2/pkg/p2p/switch_test.go @@ -1,704 +1,825 @@ package p2p import ( - "bytes" - "errors" - "fmt" - "io" + "context" "net" "sync" - "sync/atomic" "testing" "time" + "github.com/gnolang/gno/tm2/pkg/errors" + "github.com/gnolang/gno/tm2/pkg/p2p/dial" + "github.com/gnolang/gno/tm2/pkg/p2p/mock" + "github.com/gnolang/gno/tm2/pkg/p2p/types" 
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/gnolang/gno/tm2/pkg/crypto/ed25519" - "github.com/gnolang/gno/tm2/pkg/log" - "github.com/gnolang/gno/tm2/pkg/p2p/config" - "github.com/gnolang/gno/tm2/pkg/p2p/conn" - "github.com/gnolang/gno/tm2/pkg/testutils" ) -var cfg *config.P2PConfig +func TestMultiplexSwitch_Options(t *testing.T) { + t.Parallel() -func init() { - cfg = config.DefaultP2PConfig() - cfg.PexReactor = true - cfg.AllowDuplicateIP = true -} + t.Run("custom reactors", func(t *testing.T) { + t.Parallel() -type PeerMessage struct { - PeerID ID - Bytes []byte - Counter int -} + var ( + name = "custom reactor" + mockReactor = &mockReactor{ + setSwitchFn: func(s Switch) { + require.NotNil(t, s) + }, + } + ) -type TestReactor struct { - BaseReactor + sw := NewMultiplexSwitch(nil, WithReactor(name, mockReactor)) - mtx sync.Mutex - channels []*conn.ChannelDescriptor - logMessages bool - msgsCounter int - msgsReceived map[byte][]PeerMessage -} + assert.Equal(t, mockReactor, sw.reactors[name]) + }) -func NewTestReactor(channels []*conn.ChannelDescriptor, logMessages bool) *TestReactor { - tr := &TestReactor{ - channels: channels, - logMessages: logMessages, - msgsReceived: make(map[byte][]PeerMessage), - } - tr.BaseReactor = *NewBaseReactor("TestReactor", tr) - tr.SetLogger(log.NewNoopLogger()) - return tr -} + t.Run("persistent peers", func(t *testing.T) { + t.Parallel() -func (tr *TestReactor) GetChannels() []*conn.ChannelDescriptor { - return tr.channels -} + peers := generateNetAddr(t, 10) -func (tr *TestReactor) AddPeer(peer Peer) {} + sw := NewMultiplexSwitch(nil, WithPersistentPeers(peers)) -func (tr *TestReactor) RemovePeer(peer Peer, reason interface{}) {} + for _, p := range peers { + assert.True(t, sw.isPersistentPeer(p.ID)) + } + }) -func (tr *TestReactor) Receive(chID byte, peer Peer, msgBytes []byte) { - if tr.logMessages { - tr.mtx.Lock() - defer tr.mtx.Unlock() - // fmt.Printf("Received: %X, %X\n", chID, 
msgBytes) - tr.msgsReceived[chID] = append(tr.msgsReceived[chID], PeerMessage{peer.ID(), msgBytes, tr.msgsCounter}) - tr.msgsCounter++ - } -} + t.Run("private peers", func(t *testing.T) { + t.Parallel() -func (tr *TestReactor) getMsgs(chID byte) []PeerMessage { - tr.mtx.Lock() - defer tr.mtx.Unlock() - return tr.msgsReceived[chID] -} + var ( + peers = generateNetAddr(t, 10) + ids = make([]types.ID, 0, len(peers)) + ) -// ----------------------------------------------------------------------------- + for _, p := range peers { + ids = append(ids, p.ID) + } -// convenience method for creating two switches connected to each other. -// XXX: note this uses net.Pipe and not a proper TCP conn -func MakeSwitchPair(_ testing.TB, initSwitch func(int, *Switch) *Switch) (*Switch, *Switch) { - // Create two switches that will be interconnected. - switches := MakeConnectedSwitches(cfg, 2, initSwitch, Connect2Switches) - return switches[0], switches[1] -} + sw := NewMultiplexSwitch(nil, WithPrivatePeers(ids)) -func initSwitchFunc(i int, sw *Switch) *Switch { - // Make two reactors of two channels each - sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{ - {ID: byte(0x00), Priority: 10}, - {ID: byte(0x01), Priority: 10}, - }, true)) - sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{ - {ID: byte(0x02), Priority: 10}, - {ID: byte(0x03), Priority: 10}, - }, true)) - - return sw -} + for _, p := range peers { + assert.True(t, sw.isPrivatePeer(p.ID)) + } + }) -func TestSwitches(t *testing.T) { - t.Parallel() + t.Run("max inbound peers", func(t *testing.T) { + t.Parallel() - s1, s2 := MakeSwitchPair(t, initSwitchFunc) - defer s1.Stop() - defer s2.Stop() + maxInbound := uint64(500) - if s1.Peers().Size() != 1 { - t.Errorf("Expected exactly 1 peer in s1, got %v", s1.Peers().Size()) - } - if s2.Peers().Size() != 1 { - t.Errorf("Expected exactly 1 peer in s2, got %v", s2.Peers().Size()) - } + sw := NewMultiplexSwitch(nil, WithMaxInboundPeers(maxInbound)) - // Lets 
send some messages - ch0Msg := []byte("channel zero") - ch1Msg := []byte("channel foo") - ch2Msg := []byte("channel bar") + assert.Equal(t, maxInbound, sw.maxInboundPeers) + }) + + t.Run("max outbound peers", func(t *testing.T) { + t.Parallel() - s1.Broadcast(byte(0x00), ch0Msg) - s1.Broadcast(byte(0x01), ch1Msg) - s1.Broadcast(byte(0x02), ch2Msg) + maxOutbound := uint64(500) - assertMsgReceivedWithTimeout(t, ch0Msg, byte(0x00), s2.Reactor("foo").(*TestReactor), 10*time.Millisecond, 5*time.Second) - assertMsgReceivedWithTimeout(t, ch1Msg, byte(0x01), s2.Reactor("foo").(*TestReactor), 10*time.Millisecond, 5*time.Second) - assertMsgReceivedWithTimeout(t, ch2Msg, byte(0x02), s2.Reactor("bar").(*TestReactor), 10*time.Millisecond, 5*time.Second) + sw := NewMultiplexSwitch(nil, WithMaxOutboundPeers(maxOutbound)) + + assert.Equal(t, maxOutbound, sw.maxOutboundPeers) + }) } -func assertMsgReceivedWithTimeout(t *testing.T, msgBytes []byte, channel byte, reactor *TestReactor, checkPeriod, timeout time.Duration) { - t.Helper() +func TestMultiplexSwitch_Broadcast(t *testing.T) { + t.Parallel() - ticker := time.NewTicker(checkPeriod) - for { - select { - case <-ticker.C: - msgs := reactor.getMsgs(channel) - if len(msgs) > 0 { - if !bytes.Equal(msgs[0].Bytes, msgBytes) { - t.Fatalf("Unexpected message bytes. 
Wanted: %X, Got: %X", msgBytes, msgs[0].Bytes) - } - return - } + var ( + wg sync.WaitGroup + + expectedChID = byte(10) + expectedData = []byte("broadcast data") - case <-time.After(timeout): - t.Fatalf("Expected to have received 1 message in channel #%v, got zero", channel) + mockTransport = &mockTransport{ + acceptFn: func(_ context.Context, _ PeerBehavior) (PeerConn, error) { + return nil, errors.New("constant error") + }, } - } -} -func TestSwitchFiltersOutItself(t *testing.T) { - t.Parallel() + peers = mock.GeneratePeers(t, 10) + sw = NewMultiplexSwitch(mockTransport) + ) - s1 := MakeSwitch(cfg, 1, "127.0.0.1", "123.123.123", initSwitchFunc) + require.NoError(t, sw.OnStart()) + t.Cleanup(sw.OnStop) - // simulate s1 having a public IP by creating a remote peer with the same ID - rp := &remotePeer{PrivKey: s1.nodeKey.PrivKey, Config: cfg} - rp.Start() + // Create a new peer set + sw.peers = newSet() - // addr should be rejected in addPeer based on the same ID - err := s1.DialPeerWithAddress(rp.Addr()) - if assert.Error(t, err) { - if err, ok := err.(RejectedError); ok { - if !err.IsSelf() { - t.Errorf("expected self to be rejected") - } - } else { - t.Errorf("expected RejectedError") + for _, p := range peers { + wg.Add(1) + + p.SendFn = func(chID byte, data []byte) bool { + wg.Done() + + require.Equal(t, expectedChID, chID) + assert.Equal(t, expectedData, data) + + return false } + + // Load it up with peers + sw.peers.Add(p) } - rp.Stop() + // Broadcast the data + sw.Broadcast(expectedChID, expectedData) - assertNoPeersAfterTimeout(t, s1, 100*time.Millisecond) + wg.Wait() } -func TestSwitchPeerFilter(t *testing.T) { +func TestMultiplexSwitch_Peers(t *testing.T) { t.Parallel() var ( - filters = []PeerFilterFunc{ - func(_ IPeerSet, _ Peer) error { return nil }, - func(_ IPeerSet, _ Peer) error { return fmt.Errorf("denied!") }, - func(_ IPeerSet, _ Peer) error { return nil }, - } - sw = MakeSwitch( - cfg, - 1, - "testing", - "123.123.123", - initSwitchFunc, - 
SwitchPeerFilters(filters...), - ) + peers = mock.GeneratePeers(t, 10) + sw = NewMultiplexSwitch(nil) ) - defer sw.Stop() - - // simulate remote peer - rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} - rp.Start() - defer rp.Stop() - - p, err := sw.transport.Dial(*rp.Addr(), peerConfig{ - chDescs: sw.chDescs, - onPeerError: sw.StopPeerForError, - isPersistent: sw.isPeerPersistentFn(), - reactorsByCh: sw.reactorsByCh, - }) - if err != nil { - t.Fatal(err) - } - err = sw.addPeer(p) - if err, ok := err.(RejectedError); ok { - if !err.IsFiltered() { - t.Errorf("expected peer to be filtered") - } - } else { - t.Errorf("expected RejectedError") + // Create a new peer set + sw.peers = newSet() + + for _, p := range peers { + // Load it up with peers + sw.peers.Add(p) } -} -func TestSwitchPeerFilterTimeout(t *testing.T) { - t.Parallel() + // Broadcast the data + ps := sw.Peers() - var ( - filters = []PeerFilterFunc{ - func(_ IPeerSet, _ Peer) error { - time.Sleep(10 * time.Millisecond) - return nil - }, - } - sw = MakeSwitch( - cfg, - 1, - "testing", - "123.123.123", - initSwitchFunc, - SwitchFilterTimeout(5*time.Millisecond), - SwitchPeerFilters(filters...), - ) + require.EqualValues( + t, + len(peers), + ps.NumInbound()+ps.NumOutbound(), ) - defer sw.Stop() - - // simulate remote peer - rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} - rp.Start() - defer rp.Stop() - - p, err := sw.transport.Dial(*rp.Addr(), peerConfig{ - chDescs: sw.chDescs, - onPeerError: sw.StopPeerForError, - isPersistent: sw.isPeerPersistentFn(), - reactorsByCh: sw.reactorsByCh, - }) - if err != nil { - t.Fatal(err) - } - err = sw.addPeer(p) - if _, ok := err.(FilterTimeoutError); !ok { - t.Errorf("expected FilterTimeoutError") + for _, p := range peers { + assert.True(t, ps.Has(p.ID())) } } -func TestSwitchPeerFilterDuplicate(t *testing.T) { +func TestMultiplexSwitch_StopPeer(t *testing.T) { t.Parallel() - sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc) 
- sw.Start() - defer sw.Stop() + t.Run("peer not persistent", func(t *testing.T) { + t.Parallel() + + var ( + p = mock.GeneratePeers(t, 1)[0] + mockTransport = &mockTransport{ + removeFn: func(removedPeer PeerConn) { + assert.Equal(t, p.ID(), removedPeer.ID()) + }, + } + + sw = NewMultiplexSwitch(mockTransport) + ) + + // Create a new peer set + sw.peers = newSet() - // simulate remote peer - rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} - rp.Start() - defer rp.Stop() + // Save the single peer + sw.peers.Add(p) - p, err := sw.transport.Dial(*rp.Addr(), peerConfig{ - chDescs: sw.chDescs, - onPeerError: sw.StopPeerForError, - isPersistent: sw.isPeerPersistentFn(), - reactorsByCh: sw.reactorsByCh, + // Stop and remove the peer + sw.StopPeerForError(p, nil) + + // Make sure the peer is removed + assert.False(t, sw.peers.Has(p.ID())) }) - if err != nil { - t.Fatal(err) - } - if err := sw.addPeer(p); err != nil { - t.Fatal(err) - } + t.Run("persistent peer", func(t *testing.T) { + t.Parallel() + + var ( + p = mock.GeneratePeers(t, 1)[0] + mockTransport = &mockTransport{ + removeFn: func(removedPeer PeerConn) { + assert.Equal(t, p.ID(), removedPeer.ID()) + }, + netAddressFn: func() types.NetAddress { + return types.NetAddress{} + }, + } - err = sw.addPeer(p) - if errRej, ok := err.(RejectedError); ok { - if !errRej.IsDuplicate() { - t.Errorf("expected peer to be duplicate. 
got %v", errRej) + sw = NewMultiplexSwitch(mockTransport) + ) + + // Make sure the peer is persistent + p.IsPersistentFn = func() bool { + return true + } + + p.IsOutboundFn = func() bool { + return false } - } else { - t.Errorf("expected RejectedError, got %v", err) - } -} -func assertNoPeersAfterTimeout(t *testing.T, sw *Switch, timeout time.Duration) { - t.Helper() + // Create a new peer set + sw.peers = newSet() - time.Sleep(timeout) - if sw.Peers().Size() != 0 { - t.Fatalf("Expected %v to not connect to some peers, got %d", sw, sw.Peers().Size()) - } + // Save the single peer + sw.peers.Add(p) + + // Stop and remove the peer + sw.StopPeerForError(p, nil) + + // Make sure the peer is removed + assert.False(t, sw.peers.Has(p.ID())) + + // Make sure the peer is in the dial queue + sw.dialQueue.Has(p.SocketAddr()) + }) } -func TestSwitchStopsNonPersistentPeerOnError(t *testing.T) { +func TestMultiplexSwitch_DialLoop(t *testing.T) { t.Parallel() - assert, require := assert.New(t), require.New(t) + t.Run("peer already connected", func(t *testing.T) { + t.Parallel() - sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc) - err := sw.Start() - if err != nil { - t.Error(err) - } - defer sw.Stop() - - // simulate remote peer - rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} - rp.Start() - defer rp.Stop() - - p, err := sw.transport.Dial(*rp.Addr(), peerConfig{ - chDescs: sw.chDescs, - onPeerError: sw.StopPeerForError, - isPersistent: sw.isPeerPersistentFn(), - reactorsByCh: sw.reactorsByCh, + ctx, cancelFn := context.WithTimeout( + context.Background(), + 5*time.Second, + ) + defer cancelFn() + + var ( + ch = make(chan struct{}, 1) + + peerDialed bool + + p = mock.GeneratePeers(t, 1)[0] + dialTime = time.Now().Add(-5 * time.Second) // in the past + + mockSet = &mockSet{ + hasFn: func(id types.ID) bool { + require.Equal(t, p.ID(), id) + + cancelFn() + + ch <- struct{}{} + + return true + }, + } + + mockTransport = &mockTransport{ + dialFn: func( 
+ _ context.Context, + _ types.NetAddress, + _ PeerBehavior, + ) (PeerConn, error) { + peerDialed = true + + return nil, nil + }, + } + + sw = NewMultiplexSwitch(mockTransport) + ) + + sw.peers = mockSet + + // Prepare the dial queue + sw.dialQueue.Push(dial.Item{ + Time: dialTime, + Address: p.SocketAddr(), + }) + + // Run the dial loop + go sw.runDialLoop(ctx) + + select { + case <-ch: + case <-time.After(5 * time.Second): + } + + assert.False(t, peerDialed) }) - require.Nil(err) - err = sw.addPeer(p) - require.Nil(err) + t.Run("peer undialable", func(t *testing.T) { + t.Parallel() - require.NotNil(sw.Peers().Get(rp.ID())) + ctx, cancelFn := context.WithTimeout( + context.Background(), + 5*time.Second, + ) + defer cancelFn() - // simulate failure by closing connection - p.(*peer).CloseConn() + var ( + ch = make(chan struct{}, 1) - assertNoPeersAfterTimeout(t, sw, 100*time.Millisecond) - assert.False(p.IsRunning()) -} + peerDialed bool -func TestSwitchStopPeerForError(t *testing.T) { - t.Parallel() + p = mock.GeneratePeers(t, 1)[0] + dialTime = time.Now().Add(-5 * time.Second) // in the past + + mockSet = &mockSet{ + hasFn: func(id types.ID) bool { + require.Equal(t, p.ID(), id) + + return false + }, + } + + mockTransport = &mockTransport{ + dialFn: func( + _ context.Context, + _ types.NetAddress, + _ PeerBehavior, + ) (PeerConn, error) { + peerDialed = true + + cancelFn() + + ch <- struct{}{} + + return nil, errors.New("invalid dial") + }, + } + + sw = NewMultiplexSwitch(mockTransport) + ) + + sw.peers = mockSet - // make two connected switches - sw1, sw2 := MakeSwitchPair(t, func(i int, sw *Switch) *Switch { - return initSwitchFunc(i, sw) + // Prepare the dial queue + sw.dialQueue.Push(dial.Item{ + Time: dialTime, + Address: p.SocketAddr(), + }) + + // Run the dial loop + go sw.runDialLoop(ctx) + + select { + case <-ch: + case <-time.After(5 * time.Second): + } + + assert.True(t, peerDialed) }) - assert.Equal(t, len(sw1.Peers().List()), 1) + t.Run("peer dialed 
and added", func(t *testing.T) { + t.Parallel() - // send messages to the peer from sw1 - p := sw1.Peers().List()[0] - p.Send(0x1, []byte("here's a message to send")) + ctx, cancelFn := context.WithTimeout( + context.Background(), + 5*time.Second, + ) + defer cancelFn() - // stop sw2. this should cause the p to fail, - // which results in calling StopPeerForError internally - sw2.Stop() + var ( + ch = make(chan struct{}, 1) - // now call StopPeerForError explicitly, eg. from a reactor - sw1.StopPeerForError(p, fmt.Errorf("some err")) + peerDialed bool - assert.Equal(t, len(sw1.Peers().List()), 0) -} + p = mock.GeneratePeers(t, 1)[0] + dialTime = time.Now().Add(-5 * time.Second) // in the past -func TestSwitchReconnectsToOutboundPersistentPeer(t *testing.T) { - t.Parallel() + mockTransport = &mockTransport{ + dialFn: func( + _ context.Context, + _ types.NetAddress, + _ PeerBehavior, + ) (PeerConn, error) { + peerDialed = true - sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc) - err := sw.Start() - require.NoError(t, err) - defer sw.Stop() - - // 1. simulate failure by closing connection - rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} - rp.Start() - defer rp.Stop() - - err = sw.AddPersistentPeers([]string{rp.Addr().String()}) - require.NoError(t, err) - - err = sw.DialPeerWithAddress(rp.Addr()) - require.Nil(t, err) - require.NotNil(t, sw.Peers().Get(rp.ID())) - - p := sw.Peers().List()[0] - p.(*peer).CloseConn() - - waitUntilSwitchHasAtLeastNPeers(sw, 1) - assert.False(t, p.IsRunning()) // old peer instance - assert.Equal(t, 1, sw.Peers().Size()) // new peer instance - - // 2. simulate first time dial failure - rp = &remotePeer{ - PrivKey: ed25519.GenPrivKey(), - Config: cfg, - // Use different interface to prevent duplicate IP filter, this will break - // beyond two peers. 
- listenAddr: "127.0.0.1:0", - } - rp.Start() - defer rp.Stop() - - conf := config.DefaultP2PConfig() - conf.TestDialFail = true // will trigger a reconnect - err = sw.addOutboundPeerWithConfig(rp.Addr(), conf) - require.NotNil(t, err) - // DialPeerWithAddres - sw.peerConfig resets the dialer - waitUntilSwitchHasAtLeastNPeers(sw, 2) - assert.Equal(t, 2, sw.Peers().Size()) -} + cancelFn() -func TestSwitchReconnectsToInboundPersistentPeer(t *testing.T) { - t.Parallel() + ch <- struct{}{} + + return p, nil + }, + } - sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc) - err := sw.Start() - require.NoError(t, err) - defer sw.Stop() + sw = NewMultiplexSwitch(mockTransport) + ) - // 1. simulate failure by closing the connection - rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} - rp.Start() - defer rp.Stop() + // Prepare the dial queue + sw.dialQueue.Push(dial.Item{ + Time: dialTime, + Address: p.SocketAddr(), + }) - err = sw.AddPersistentPeers([]string{rp.Addr().String()}) - require.NoError(t, err) + // Run the dial loop + go sw.runDialLoop(ctx) - conn, err := rp.Dial(sw.NetAddress()) - require.NoError(t, err) - time.Sleep(100 * time.Millisecond) - require.NotNil(t, sw.Peers().Get(rp.ID())) + select { + case <-ch: + case <-time.After(5 * time.Second): + } - conn.Close() + require.True(t, sw.Peers().Has(p.ID())) - waitUntilSwitchHasAtLeastNPeers(sw, 1) - assert.Equal(t, 1, sw.Peers().Size()) + assert.True(t, peerDialed) + }) } -func TestSwitchDialPeersAsync(t *testing.T) { +func TestMultiplexSwitch_AcceptLoop(t *testing.T) { t.Parallel() - if testing.Short() { - return - } + t.Run("inbound limit reached", func(t *testing.T) { + t.Parallel() - sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc) - err := sw.Start() - require.NoError(t, err) - defer sw.Stop() + ctx, cancelFn := context.WithTimeout( + context.Background(), + 5*time.Second, + ) + defer cancelFn() - rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} - 
rp.Start() - defer rp.Stop() + var ( + ch = make(chan struct{}, 1) + maxInbound = uint64(10) - err = sw.DialPeersAsync([]string{rp.Addr().String()}) - require.NoError(t, err) - time.Sleep(dialRandomizerIntervalMilliseconds * time.Millisecond) - require.NotNil(t, sw.Peers().Get(rp.ID())) -} + peerRemoved bool + + p = mock.GeneratePeers(t, 1)[0] + + mockTransport = &mockTransport{ + acceptFn: func(_ context.Context, _ PeerBehavior) (PeerConn, error) { + return p, nil + }, + removeFn: func(removedPeer PeerConn) { + require.Equal(t, p.ID(), removedPeer.ID()) + + peerRemoved = true + + ch <- struct{}{} + }, + } -func waitUntilSwitchHasAtLeastNPeers(sw *Switch, n int) { - for i := 0; i < 20; i++ { - time.Sleep(250 * time.Millisecond) - has := sw.Peers().Size() - if has >= n { - break + ps = &mockSet{ + numInboundFn: func() uint64 { + return maxInbound + }, + } + + sw = NewMultiplexSwitch( + mockTransport, + WithMaxInboundPeers(maxInbound), + ) + ) + + // Set the peer set + sw.peers = ps + + // Run the accept loop + go sw.runAcceptLoop(ctx) + + select { + case <-ch: + case <-time.After(5 * time.Second): } - } + + assert.True(t, peerRemoved) + }) + + t.Run("peer accepted", func(t *testing.T) { + t.Parallel() + + ctx, cancelFn := context.WithTimeout( + context.Background(), + 5*time.Second, + ) + defer cancelFn() + + var ( + ch = make(chan struct{}, 1) + maxInbound = uint64(10) + + peerAdded bool + + p = mock.GeneratePeers(t, 1)[0] + + mockTransport = &mockTransport{ + acceptFn: func(_ context.Context, _ PeerBehavior) (PeerConn, error) { + return p, nil + }, + } + + ps = &mockSet{ + numInboundFn: func() uint64 { + return maxInbound - 1 // available slot + }, + addFn: func(peer PeerConn) { + require.Equal(t, p.ID(), peer.ID()) + + peerAdded = true + + ch <- struct{}{} + }, + } + + sw = NewMultiplexSwitch( + mockTransport, + WithMaxInboundPeers(maxInbound), + ) + ) + + // Set the peer set + sw.peers = ps + + // Run the accept loop + go sw.runAcceptLoop(ctx) + + select { + 
case <-ch: + case <-time.After(5 * time.Second): + } + + assert.True(t, peerAdded) + }) } -func TestSwitchFullConnectivity(t *testing.T) { +func TestMultiplexSwitch_RedialLoop(t *testing.T) { t.Parallel() - switches := MakeConnectedSwitches(cfg, 3, initSwitchFunc, Connect2Switches) - defer func() { - for _, sw := range switches { - sw.Stop() + t.Run("no peers to dial", func(t *testing.T) { + t.Parallel() + + var ( + ch = make(chan struct{}, 1) + + peersChecked = 0 + peers = mock.GeneratePeers(t, 10) + + ps = &mockSet{ + hasFn: func(id types.ID) bool { + exists := false + for _, p := range peers { + if p.ID() == id { + exists = true + + break + } + } + + require.True(t, exists) + + peersChecked++ + + if peersChecked == len(peers) { + ch <- struct{}{} + } + + return true + }, + } + ) + + // Make sure the peers are the + // switch persistent peers + addrs := make([]*types.NetAddress, 0, len(peers)) + + for _, p := range peers { + addrs = append(addrs, p.SocketAddr()) } - }() - for i, sw := range switches { - if sw.Peers().Size() != 2 { - t.Fatalf("Expected each switch to be connected to 2 other, but %d switch only connected to %d", sw.Peers().Size(), i) + // Create the switch + sw := NewMultiplexSwitch( + nil, + WithPersistentPeers(addrs), + ) + + // Set the peer set + sw.peers = ps + + // Run the redial loop + ctx, cancelFn := context.WithTimeout( + context.Background(), + 5*time.Second, + ) + defer cancelFn() + + go sw.runRedialLoop(ctx) + + select { + case <-ch: + case <-time.After(5 * time.Second): } - } -} -func TestSwitchAcceptRoutine(t *testing.T) { - t.Parallel() + assert.Equal(t, len(peers), peersChecked) + }) + + t.Run("missing peers dialed", func(t *testing.T) { + t.Parallel() + + var ( + peers = mock.GeneratePeers(t, 10) + missingPeer = peers[0] + missingAddr = missingPeer.SocketAddr() + + peersDialed []types.NetAddress + + mockTransport = &mockTransport{ + dialFn: func( + _ context.Context, + address types.NetAddress, + _ PeerBehavior, + ) (PeerConn, 
error) { + peersDialed = append(peersDialed, address) + + if address.Equals(*missingPeer.SocketAddr()) { + return missingPeer, nil + } + + return nil, errors.New("invalid dial") + }, + } + ps = &mockSet{ + hasFn: func(id types.ID) bool { + return id != missingPeer.ID() + }, + } + ) + + // Make sure the peers are the + // switch persistent peers + addrs := make([]*types.NetAddress, 0, len(peers)) + + for _, p := range peers { + addrs = append(addrs, p.SocketAddr()) + } + + // Create the switch + sw := NewMultiplexSwitch( + mockTransport, + WithPersistentPeers(addrs), + ) + + // Set the peer set + sw.peers = ps + + // Run the redial loop + ctx, cancelFn := context.WithTimeout( + context.Background(), + 5*time.Second, + ) + defer cancelFn() + + var wg sync.WaitGroup + + wg.Add(2) + + go func() { + defer wg.Done() + + sw.runRedialLoop(ctx) + }() + + go func() { + defer wg.Done() + + deadline := time.After(5 * time.Second) - cfg.MaxNumInboundPeers = 5 - - // make switch - sw := MakeSwitch(cfg, 1, "testing", "123.123.123", initSwitchFunc) - err := sw.Start() - require.NoError(t, err) - defer sw.Stop() - - remotePeers := make([]*remotePeer, 0) - assert.Equal(t, 0, sw.Peers().Size()) - - // 1. check we connect up to MaxNumInboundPeers - for i := 0; i < cfg.MaxNumInboundPeers; i++ { - rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} - remotePeers = append(remotePeers, rp) - rp.Start() - c, err := rp.Dial(sw.NetAddress()) - require.NoError(t, err) - // spawn a reading routine to prevent connection from closing - go func(c net.Conn) { for { - one := make([]byte, 1) - _, err := c.Read(one) - if err != nil { + select { + case <-deadline: + return + default: + if !sw.dialQueue.Has(missingAddr) { + continue + } + + cancelFn() + return } } - }(c) - } - time.Sleep(100 * time.Millisecond) - assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size()) - - // 2. 
check we close new connections if we already have MaxNumInboundPeers peers - rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} - rp.Start() - conn, err := rp.Dial(sw.NetAddress()) - require.NoError(t, err) - // check conn is closed - one := make([]byte, 1) - conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - _, err = conn.Read(one) - assert.Equal(t, io.EOF, err) - assert.Equal(t, cfg.MaxNumInboundPeers, sw.Peers().Size()) - rp.Stop() - - // stop remote peers - for _, rp := range remotePeers { - rp.Stop() - } -} + }() -type errorTransport struct { - acceptErr error -} + wg.Wait() -func (et errorTransport) NetAddress() NetAddress { - panic("not implemented") + require.True(t, sw.dialQueue.Has(missingAddr)) + assert.Equal(t, missingAddr, sw.dialQueue.Peek().Address) + }) } -func (et errorTransport) Accept(c peerConfig) (Peer, error) { - return nil, et.acceptErr -} +func TestMultiplexSwitch_DialPeers(t *testing.T) { + t.Parallel() -func (errorTransport) Dial(NetAddress, peerConfig) (Peer, error) { - panic("not implemented") -} + t.Run("self dial request", func(t *testing.T) { + t.Parallel() -func (errorTransport) Cleanup(Peer) { - panic("not implemented") -} + var ( + p = mock.GeneratePeers(t, 1)[0] + addr = types.NetAddress{ + ID: "id", + IP: p.SocketAddr().IP, + Port: p.SocketAddr().Port, + } -func TestSwitchAcceptRoutineErrorCases(t *testing.T) { - t.Parallel() + mockTransport = &mockTransport{ + netAddressFn: func() types.NetAddress { + return addr + }, + } + ) - sw := NewSwitch(cfg, errorTransport{FilterTimeoutError{}}) - assert.NotPanics(t, func() { - err := sw.Start() - assert.NoError(t, err) - sw.Stop() - }) + // Make sure the "peer" has the same address + // as the transport (node) + p.NodeInfoFn = func() types.NodeInfo { + return types.NodeInfo{ + PeerID: addr.ID, + } + } - sw = NewSwitch(cfg, errorTransport{RejectedError{conn: nil, err: errors.New("filtered"), isFiltered: true}}) - assert.NotPanics(t, func() { - err := sw.Start() - 
assert.NoError(t, err) - sw.Stop() - }) + sw := NewMultiplexSwitch(mockTransport) + + // Dial the peers + sw.DialPeers(p.SocketAddr()) - sw = NewSwitch(cfg, errorTransport{TransportClosedError{}}) - assert.NotPanics(t, func() { - err := sw.Start() - assert.NoError(t, err) - sw.Stop() + // Make sure the peer wasn't actually dialed + assert.False(t, sw.dialQueue.Has(p.SocketAddr())) }) -} -// mockReactor checks that InitPeer never called before RemovePeer. If that's -// not true, InitCalledBeforeRemoveFinished will return true. -type mockReactor struct { - *BaseReactor + t.Run("outbound peer limit reached", func(t *testing.T) { + t.Parallel() - // atomic - removePeerInProgress uint32 - initCalledBeforeRemoveFinished uint32 -} + var ( + maxOutbound = uint64(10) + peers = mock.GeneratePeers(t, 10) -func (r *mockReactor) RemovePeer(peer Peer, reason interface{}) { - atomic.StoreUint32(&r.removePeerInProgress, 1) - defer atomic.StoreUint32(&r.removePeerInProgress, 0) - time.Sleep(100 * time.Millisecond) -} + mockTransport = &mockTransport{ + netAddressFn: func() types.NetAddress { + return types.NetAddress{ + ID: "id", + IP: net.IP{}, + } + }, + } -func (r *mockReactor) InitPeer(peer Peer) Peer { - if atomic.LoadUint32(&r.removePeerInProgress) == 1 { - atomic.StoreUint32(&r.initCalledBeforeRemoveFinished, 1) - } + ps = &mockSet{ + numOutboundFn: func() uint64 { + return maxOutbound + }, + } + ) - return peer -} + sw := NewMultiplexSwitch( + mockTransport, + WithMaxOutboundPeers(maxOutbound), + ) -func (r *mockReactor) InitCalledBeforeRemoveFinished() bool { - return atomic.LoadUint32(&r.initCalledBeforeRemoveFinished) == 1 -} + // Set the peer set + sw.peers = ps -// see stopAndRemovePeer -func TestFlappySwitchInitPeerIsNotCalledBeforeRemovePeer(t *testing.T) { - t.Parallel() + // Dial the peers + addrs := make([]*types.NetAddress, 0, len(peers)) - testutils.FilterStability(t, testutils.Flappy) + for _, p := range peers { + addrs = append(addrs, p.SocketAddr()) + } - // 
make reactor - reactor := &mockReactor{} - reactor.BaseReactor = NewBaseReactor("mockReactor", reactor) + sw.DialPeers(addrs...) - // make switch - sw := MakeSwitch(cfg, 1, "testing", "123.123.123", func(i int, sw *Switch) *Switch { - sw.AddReactor("mock", reactor) - return sw + // Make sure no peers were dialed + for _, p := range peers { + assert.False(t, sw.dialQueue.Has(p.SocketAddr())) + } }) - err := sw.Start() - require.NoError(t, err) - defer sw.Stop() - - // add peer - rp := &remotePeer{PrivKey: ed25519.GenPrivKey(), Config: cfg} - rp.Start() - defer rp.Stop() - _, err = rp.Dial(sw.NetAddress()) - require.NoError(t, err) - // wait till the switch adds rp to the peer set - time.Sleep(100 * time.Millisecond) - - // stop peer asynchronously - go sw.StopPeerForError(sw.Peers().Get(rp.ID()), "test") - - // simulate peer reconnecting to us - _, err = rp.Dial(sw.NetAddress()) - require.NoError(t, err) - // wait till the switch adds rp to the peer set - time.Sleep(100 * time.Millisecond) - - // make sure reactor.RemovePeer is finished before InitPeer is called - assert.False(t, reactor.InitCalledBeforeRemoveFinished()) -} -func BenchmarkSwitchBroadcast(b *testing.B) { - s1, s2 := MakeSwitchPair(b, func(i int, sw *Switch) *Switch { - // Make bar reactors of bar channels each - sw.AddReactor("foo", NewTestReactor([]*conn.ChannelDescriptor{ - {ID: byte(0x00), Priority: 10}, - {ID: byte(0x01), Priority: 10}, - }, false)) - sw.AddReactor("bar", NewTestReactor([]*conn.ChannelDescriptor{ - {ID: byte(0x02), Priority: 10}, - {ID: byte(0x03), Priority: 10}, - }, false)) - return sw - }) - defer s1.Stop() - defer s2.Stop() + t.Run("peers dialed", func(t *testing.T) { + t.Parallel() + + var ( + maxOutbound = uint64(1000) + peers = mock.GeneratePeers(t, int(maxOutbound/2)) - // Allow time for goroutines to boot up - time.Sleep(1 * time.Second) + mockTransport = &mockTransport{ + netAddressFn: func() types.NetAddress { + return types.NetAddress{ + ID: "id", + IP: net.IP{}, + } 
+ }, + } + ) - b.ResetTimer() + sw := NewMultiplexSwitch( + mockTransport, + WithMaxOutboundPeers(10), + ) - numSuccess, numFailure := 0, 0 + // Dial the peers + addrs := make([]*types.NetAddress, 0, len(peers)) - // Send random message from foo channel to another - for i := 0; i < b.N; i++ { - chID := byte(i % 4) - successChan := s1.Broadcast(chID, []byte("test data")) - for s := range successChan { - if s { - numSuccess++ - } else { - numFailure++ - } + for _, p := range peers { + addrs = append(addrs, p.SocketAddr()) } - } - b.Logf("success: %v, failure: %v", numSuccess, numFailure) + sw.DialPeers(addrs...) + + // Make sure peers were dialed + for _, p := range peers { + assert.True(t, sw.dialQueue.Has(p.SocketAddr())) + } + }) } diff --git a/tm2/pkg/p2p/test_util.go b/tm2/pkg/p2p/test_util.go deleted file mode 100644 index dd0d9cd6bc7..00000000000 --- a/tm2/pkg/p2p/test_util.go +++ /dev/null @@ -1,238 +0,0 @@ -package p2p - -import ( - "fmt" - "net" - "time" - - "github.com/gnolang/gno/tm2/pkg/crypto" - "github.com/gnolang/gno/tm2/pkg/crypto/ed25519" - "github.com/gnolang/gno/tm2/pkg/errors" - "github.com/gnolang/gno/tm2/pkg/log" - "github.com/gnolang/gno/tm2/pkg/p2p/config" - "github.com/gnolang/gno/tm2/pkg/p2p/conn" - "github.com/gnolang/gno/tm2/pkg/random" - "github.com/gnolang/gno/tm2/pkg/versionset" -) - -const testCh = 0x01 - -// ------------------------------------------------ - -func CreateRoutableAddr() (addr string, netAddr *NetAddress) { - for { - id := ed25519.GenPrivKey().PubKey().Address().ID() - var err error - addr = fmt.Sprintf("%s@%v.%v.%v.%v:26656", id, random.RandInt()%256, random.RandInt()%256, random.RandInt()%256, random.RandInt()%256) - netAddr, err = NewNetAddressFromString(addr) - if err != nil { - panic(err) - } - if netAddr.Routable() { - break - } - } - return -} - -// ------------------------------------------------------------------ -// Connects switches via arbitrary net.Conn. Used for testing. 
- -const TEST_HOST = "localhost" - -// MakeConnectedSwitches returns n switches, connected according to the connect func. -// If connect==Connect2Switches, the switches will be fully connected. -// initSwitch defines how the i'th switch should be initialized (ie. with what reactors). -// NOTE: panics if any switch fails to start. -func MakeConnectedSwitches(cfg *config.P2PConfig, n int, initSwitch func(int, *Switch) *Switch, connect func([]*Switch, int, int)) []*Switch { - switches := make([]*Switch, n) - for i := 0; i < n; i++ { - switches[i] = MakeSwitch(cfg, i, TEST_HOST, "123.123.123", initSwitch) - } - - if err := StartSwitches(switches); err != nil { - panic(err) - } - - for i := 0; i < n; i++ { - for j := i + 1; j < n; j++ { - connect(switches, i, j) - } - } - - return switches -} - -// Connect2Switches will connect switches i and j via net.Pipe(). -// Blocks until a connection is established. -// NOTE: caller ensures i and j are within bounds. -func Connect2Switches(switches []*Switch, i, j int) { - switchI := switches[i] - switchJ := switches[j] - - c1, c2 := conn.NetPipe() - - doneCh := make(chan struct{}) - go func() { - err := switchI.addPeerWithConnection(c1) - if err != nil { - panic(err) - } - doneCh <- struct{}{} - }() - go func() { - err := switchJ.addPeerWithConnection(c2) - if err != nil { - panic(err) - } - doneCh <- struct{}{} - }() - <-doneCh - <-doneCh -} - -func (sw *Switch) addPeerWithConnection(conn net.Conn) error { - pc, err := testInboundPeerConn(conn, sw.config, sw.nodeKey.PrivKey) - if err != nil { - if err := conn.Close(); err != nil { - sw.Logger.Error("Error closing connection", "err", err) - } - return err - } - - ni, err := handshake(conn, time.Second, sw.nodeInfo) - if err != nil { - if err := conn.Close(); err != nil { - sw.Logger.Error("Error closing connection", "err", err) - } - return err - } - - p := newPeer( - pc, - MConnConfig(sw.config), - ni, - sw.reactorsByCh, - sw.chDescs, - sw.StopPeerForError, - ) - - if err = 
sw.addPeer(p); err != nil { - pc.CloseConn() - return err - } - - return nil -} - -// StartSwitches calls sw.Start() for each given switch. -// It returns the first encountered error. -func StartSwitches(switches []*Switch) error { - for _, s := range switches { - err := s.Start() // start switch and reactors - if err != nil { - return err - } - } - return nil -} - -func MakeSwitch( - cfg *config.P2PConfig, - i int, - network, version string, - initSwitch func(int, *Switch) *Switch, - opts ...SwitchOption, -) *Switch { - nodeKey := NodeKey{ - PrivKey: ed25519.GenPrivKey(), - } - nodeInfo := testNodeInfo(nodeKey.ID(), fmt.Sprintf("node%d", i)) - - t := NewMultiplexTransport(nodeInfo, nodeKey, MConnConfig(cfg)) - - if err := t.Listen(*nodeInfo.NetAddress); err != nil { - panic(err) - } - - // TODO: let the config be passed in? - sw := initSwitch(i, NewSwitch(cfg, t, opts...)) - sw.SetLogger(log.NewNoopLogger().With("switch", i)) - sw.SetNodeKey(&nodeKey) - - for ch := range sw.reactorsByCh { - nodeInfo.Channels = append(nodeInfo.Channels, ch) - } - - // TODO: We need to setup reactors ahead of time so the NodeInfo is properly - // populated and we don't have to do those awkward overrides and setters. 
- t.nodeInfo = nodeInfo - sw.SetNodeInfo(nodeInfo) - - return sw -} - -func testInboundPeerConn( - conn net.Conn, - config *config.P2PConfig, - ourNodePrivKey crypto.PrivKey, -) (peerConn, error) { - return testPeerConn(conn, config, false, false, ourNodePrivKey, nil) -} - -func testPeerConn( - rawConn net.Conn, - cfg *config.P2PConfig, - outbound, persistent bool, - ourNodePrivKey crypto.PrivKey, - socketAddr *NetAddress, -) (pc peerConn, err error) { - conn := rawConn - - // Fuzz connection - if cfg.TestFuzz { - // so we have time to do peer handshakes and get set up - conn = FuzzConnAfterFromConfig(conn, 10*time.Second, cfg.TestFuzzConfig) - } - - // Encrypt connection - conn, err = upgradeSecretConn(conn, cfg.HandshakeTimeout, ourNodePrivKey) - if err != nil { - return pc, errors.Wrap(err, "Error creating peer") - } - - // Only the information we already have - return newPeerConn(outbound, persistent, conn, socketAddr), nil -} - -// ---------------------------------------------------------------- -// rand node info - -func testNodeInfo(id ID, name string) NodeInfo { - return testNodeInfoWithNetwork(id, name, "testing") -} - -func testVersionSet() versionset.VersionSet { - return versionset.VersionSet{ - versionset.VersionInfo{ - Name: "p2p", - Version: "v0.0.0", // dontcare - }, - } -} - -func testNodeInfoWithNetwork(id ID, name, network string) NodeInfo { - return NodeInfo{ - VersionSet: testVersionSet(), - NetAddress: NewNetAddressFromIPPort(id, net.ParseIP("127.0.0.1"), 0), - Network: network, - Software: "p2ptest", - Version: "v1.2.3-rc.0-deadbeef", - Channels: []byte{testCh}, - Moniker: name, - Other: NodeInfoOther{ - TxIndex: "on", - RPCAddress: fmt.Sprintf("127.0.0.1:%d", 0), - }, - } -} diff --git a/tm2/pkg/p2p/transport.go b/tm2/pkg/p2p/transport.go index 5bfae9e52b8..9edef9a15e5 100644 --- a/tm2/pkg/p2p/transport.go +++ b/tm2/pkg/p2p/transport.go @@ -3,144 +3,64 @@ package p2p import ( "context" "fmt" + "io" + "log/slog" "net" - "strconv" + "sync" 
"time" "github.com/gnolang/gno/tm2/pkg/amino" "github.com/gnolang/gno/tm2/pkg/crypto" "github.com/gnolang/gno/tm2/pkg/errors" "github.com/gnolang/gno/tm2/pkg/p2p/conn" + "github.com/gnolang/gno/tm2/pkg/p2p/types" + "golang.org/x/sync/errgroup" ) -const ( - defaultDialTimeout = time.Second - defaultFilterTimeout = 5 * time.Second - defaultHandshakeTimeout = 3 * time.Second -) - -// IPResolver is a behaviour subset of net.Resolver. -type IPResolver interface { - LookupIPAddr(context.Context, string) ([]net.IPAddr, error) -} - -// accept is the container to carry the upgraded connection and NodeInfo from an -// asynchronously running routine to the Accept method. -type accept struct { - netAddr *NetAddress - conn net.Conn - nodeInfo NodeInfo - err error -} - -// peerConfig is used to bundle data we need to fully setup a Peer with an -// MConn, provided by the caller of Accept and Dial (currently the Switch). This -// a temporary measure until reactor setup is less dynamic and we introduce the -// concept of PeerBehaviour to communicate about significant Peer lifecycle -// events. -// TODO(xla): Refactor out with more static Reactor setup and PeerBehaviour. -type peerConfig struct { - chDescs []*conn.ChannelDescriptor - onPeerError func(Peer, interface{}) - outbound bool - // isPersistent allows you to set a function, which, given socket address - // (for outbound peers) OR self-reported address (for inbound peers), tells - // if the peer is persistent or not. - isPersistent func(*NetAddress) bool - reactorsByCh map[byte]Reactor -} - -// Transport emits and connects to Peers. The implementation of Peer is left to -// the transport. Each transport is also responsible to filter establishing -// peers specific to its domain. -type Transport interface { - // Listening address. - NetAddress() NetAddress - - // Accept returns a newly connected Peer. - Accept(peerConfig) (Peer, error) - - // Dial connects to the Peer for the address. 
- Dial(NetAddress, peerConfig) (Peer, error) - - // Cleanup any resources associated with Peer. - Cleanup(Peer) -} - -// TransportLifecycle bundles the methods for callers to control start and stop -// behaviour. -type TransportLifecycle interface { - Close() error - Listen(NetAddress) error -} - -// ConnFilterFunc to be implemented by filter hooks after a new connection has -// been established. The set of existing connections is passed along together -// with all resolved IPs for the new connection. -type ConnFilterFunc func(ConnSet, net.Conn, []net.IP) error - -// ConnDuplicateIPFilter resolves and keeps all ips for an incoming connection -// and refuses new ones if they come from a known ip. -func ConnDuplicateIPFilter() ConnFilterFunc { - return func(cs ConnSet, c net.Conn, ips []net.IP) error { - for _, ip := range ips { - if cs.HasIP(ip) { - return RejectedError{ - conn: c, - err: fmt.Errorf("IP<%v> already connected", ip), - isDuplicate: true, - } - } - } +// defaultHandshakeTimeout is the timeout for the STS handshaking protocol +const defaultHandshakeTimeout = 3 * time.Second - return nil - } -} +var ( + errTransportClosed = errors.New("transport is closed") + errTransportInactive = errors.New("transport is inactive") + errDuplicateConnection = errors.New("duplicate peer connection") + errPeerIDNodeInfoMismatch = errors.New("connection ID does not match node info ID") + errPeerIDDialMismatch = errors.New("connection ID does not match dialed ID") + errIncompatibleNodeInfo = errors.New("incompatible node info") +) -// MultiplexTransportOption sets an optional parameter on the -// MultiplexTransport. -type MultiplexTransportOption func(*MultiplexTransport) +type connUpgradeFn func(io.ReadWriteCloser, crypto.PrivKey) (*conn.SecretConnection, error) -// MultiplexTransportConnFilters sets the filters for rejection new connections. 
-func MultiplexTransportConnFilters( - filters ...ConnFilterFunc, -) MultiplexTransportOption { - return func(mt *MultiplexTransport) { mt.connFilters = filters } -} +type secretConn interface { + net.Conn -// MultiplexTransportFilterTimeout sets the timeout waited for filter calls to -// return. -func MultiplexTransportFilterTimeout( - timeout time.Duration, -) MultiplexTransportOption { - return func(mt *MultiplexTransport) { mt.filterTimeout = timeout } + RemotePubKey() crypto.PubKey } -// MultiplexTransportResolver sets the Resolver used for ip lookups, defaults to -// net.DefaultResolver. -func MultiplexTransportResolver(resolver IPResolver) MultiplexTransportOption { - return func(mt *MultiplexTransport) { mt.resolver = resolver } +// peerInfo is a wrapper for an unverified peer connection +type peerInfo struct { + addr *types.NetAddress // the dial address of the peer + conn net.Conn // the connection associated with the peer + nodeInfo types.NodeInfo // the relevant peer node info } // MultiplexTransport accepts and dials tcp connections and upgrades them to // multiplexed peers. type MultiplexTransport struct { - netAddr NetAddress - listener net.Listener + ctx context.Context + cancelFn context.CancelFunc - acceptc chan accept - closec chan struct{} + logger *slog.Logger - // Lookup table for duplicate ip and id checks. 
- conns ConnSet - connFilters []ConnFilterFunc + netAddr types.NetAddress // the node's P2P dial address, used for handshaking + nodeInfo types.NodeInfo // the node's P2P info, used for handshaking + nodeKey types.NodeKey // the node's private P2P key, used for handshaking - dialTimeout time.Duration - filterTimeout time.Duration - handshakeTimeout time.Duration - nodeInfo NodeInfo - nodeKey NodeKey - resolver IPResolver + listener net.Listener // listener for inbound peer connections + peerCh chan peerInfo // pipe for inbound peer connections + activeConns sync.Map // active peer connections (remote address -> nothing) + + connUpgradeFn connUpgradeFn // Upgrades the connection to a secret connection // TODO(xla): This config is still needed as we parameterize peerConn and // peer currently. All relevant configuration should be refactored into options @@ -148,439 +68,376 @@ type MultiplexTransport struct { mConfig conn.MConnConfig } -// Test multiplexTransport for interface completeness. -var ( - _ Transport = (*MultiplexTransport)(nil) - _ TransportLifecycle = (*MultiplexTransport)(nil) -) - // NewMultiplexTransport returns a tcp connected multiplexed peer. func NewMultiplexTransport( - nodeInfo NodeInfo, - nodeKey NodeKey, + nodeInfo types.NodeInfo, + nodeKey types.NodeKey, mConfig conn.MConnConfig, + logger *slog.Logger, ) *MultiplexTransport { return &MultiplexTransport{ - acceptc: make(chan accept), - closec: make(chan struct{}), - dialTimeout: defaultDialTimeout, - filterTimeout: defaultFilterTimeout, - handshakeTimeout: defaultHandshakeTimeout, - mConfig: mConfig, - nodeInfo: nodeInfo, - nodeKey: nodeKey, - conns: NewConnSet(), - resolver: net.DefaultResolver, + peerCh: make(chan peerInfo, 1), + mConfig: mConfig, + nodeInfo: nodeInfo, + nodeKey: nodeKey, + logger: logger, + connUpgradeFn: conn.MakeSecretConnection, } } -// NetAddress implements Transport. 
-func (mt *MultiplexTransport) NetAddress() NetAddress { +// NetAddress returns the transport's listen address (for p2p connections) +func (mt *MultiplexTransport) NetAddress() types.NetAddress { return mt.netAddr } -// Accept implements Transport. -func (mt *MultiplexTransport) Accept(cfg peerConfig) (Peer, error) { +// Accept waits for a verified inbound Peer to connect, and returns it [BLOCKING] +func (mt *MultiplexTransport) Accept(ctx context.Context, behavior PeerBehavior) (PeerConn, error) { + // Sanity check, no need to wait + // on an inactive transport + if mt.listener == nil { + return nil, errTransportInactive + } + select { - // This case should never have any side-effectful/blocking operations to - // ensure that quality peers are ready to be used. - case a := <-mt.acceptc: - if a.err != nil { - return nil, a.err + case <-ctx.Done(): + return nil, ctx.Err() + case info, ok := <-mt.peerCh: + if !ok { + return nil, errTransportClosed } - cfg.outbound = false - - return mt.wrapPeer(a.conn, a.nodeInfo, cfg, a.netAddr), nil - case <-mt.closec: - return nil, TransportClosedError{} + return mt.newMultiplexPeer(info, behavior, false) } } -// Dial implements Transport. +// Dial creates an outbound Peer connection, and +// verifies it (performs handshaking) [BLOCKING] func (mt *MultiplexTransport) Dial( - addr NetAddress, - cfg peerConfig, -) (Peer, error) { - c, err := addr.DialTimeout(mt.dialTimeout) + ctx context.Context, + addr types.NetAddress, + behavior PeerBehavior, +) (PeerConn, error) { + // Set a dial timeout for the connection + c, err := addr.DialContext(ctx) if err != nil { return nil, err } - // TODO(xla): Evaluate if we should apply filters if we explicitly dial. 
- if err := mt.filterConn(c); err != nil { - return nil, err - } - - secretConn, nodeInfo, err := mt.upgrade(c, &addr) + // Process the connection with expected ID + info, err := mt.processConn(c, addr.ID) if err != nil { - return nil, err - } - - cfg.outbound = true + // Close the net peer connection + _ = c.Close() - p := mt.wrapPeer(secretConn, nodeInfo, cfg, &addr) + return nil, fmt.Errorf("unable to process connection, %w", err) + } - return p, nil + return mt.newMultiplexPeer(info, behavior, true) } -// Close implements TransportLifecycle. +// Close stops the multiplex transport func (mt *MultiplexTransport) Close() error { - close(mt.closec) - - if mt.listener != nil { - return mt.listener.Close() + if mt.listener == nil { + return nil } - return nil + mt.cancelFn() + + return mt.listener.Close() } -// Listen implements TransportLifecycle. -func (mt *MultiplexTransport) Listen(addr NetAddress) error { +// Listen starts an active process of listening for incoming connections [NON-BLOCKING] +func (mt *MultiplexTransport) Listen(addr types.NetAddress) error { + // Reserve a port, and start listening ln, err := net.Listen("tcp", addr.DialString()) if err != nil { - return err + return fmt.Errorf("unable to listen on address, %w", err) } if addr.Port == 0 { // net.Listen on port 0 means the kernel will auto-allocate a port // - find out which one has been given to us. 
- _, p, err := net.SplitHostPort(ln.Addr().String()) - if err != nil { + tcpAddr, ok := ln.Addr().(*net.TCPAddr) + if !ok { return fmt.Errorf("error finding port (after listening on port 0): %w", err) } - pInt, _ := strconv.Atoi(p) - addr.Port = uint16(pInt) + + addr.Port = uint16(tcpAddr.Port) } + // Set up the context + mt.ctx, mt.cancelFn = context.WithCancel(context.Background()) + mt.netAddr = addr mt.listener = ln - go mt.acceptPeers() + // Run the routine for accepting + // incoming peer connections + go mt.runAcceptLoop() return nil } -func (mt *MultiplexTransport) acceptPeers() { +// runAcceptLoop runs the loop where incoming peers are: +// +// 1. accepted by the transport +// 2. filtered +// 3. upgraded (handshaked + verified) +func (mt *MultiplexTransport) runAcceptLoop() { + var wg sync.WaitGroup + + defer func() { + wg.Wait() // Wait for all process routines + + close(mt.peerCh) + }() + for { - c, err := mt.listener.Accept() - if err != nil { - // If Close() has been called, silently exit. - select { - case _, ok := <-mt.closec: - if !ok { - return - } - default: - // Transport is not closed - } + select { + case <-mt.ctx.Done(): + mt.logger.Debug("transport accept context closed") - mt.acceptc <- accept{err: err} return - } + default: + // Accept an incoming peer connection + c, err := mt.listener.Accept() + if err != nil { + mt.logger.Error( + "unable to accept p2p connection", + "err", err, + ) - // Connection upgrade and filtering should be asynchronous to avoid - // Head-of-line blocking[0]. - // Reference: https://github.com/tendermint/classic/issues/2047 - // - // [0] https://en.wikipedia.org/wiki/Head-of-line_blocking - go func(c net.Conn) { - defer func() { - if r := recover(); r != nil { - err := RejectedError{ - conn: c, - err: errors.New("recovered from panic: %v", r), - isAuthFailure: true, - } - select { - case mt.acceptc <- accept{err: err}: - case <-mt.closec: - // Give up if the transport was closed. 
- _ = c.Close() - return - } - } - }() - - var ( - nodeInfo NodeInfo - secretConn *conn.SecretConnection - netAddr *NetAddress - ) - - err := mt.filterConn(c) - if err == nil { - secretConn, nodeInfo, err = mt.upgrade(c, nil) - if err == nil { - addr := c.RemoteAddr() - id := secretConn.RemotePubKey().Address().ID() - netAddr = NewNetAddress(id, addr) - } + continue } - select { - case mt.acceptc <- accept{netAddr, secretConn, nodeInfo, err}: - // Make the upgraded peer available. - case <-mt.closec: - // Give up if the transport was closed. - _ = c.Close() - return - } - }(c) - } -} + // Process the new connection asynchronously + wg.Add(1) -// Cleanup removes the given address from the connections set and -// closes the connection. -func (mt *MultiplexTransport) Cleanup(p Peer) { - mt.conns.RemoveAddr(p.RemoteAddr()) - _ = p.CloseConn() -} + go func(c net.Conn) { + defer wg.Done() -func (mt *MultiplexTransport) cleanup(c net.Conn) error { - mt.conns.Remove(c) + info, err := mt.processConn(c, "") + if err != nil { + mt.logger.Error( + "unable to process p2p connection", + "err", err, + ) - return c.Close() -} + // Close the connection + _ = c.Close() -func (mt *MultiplexTransport) filterConn(c net.Conn) (err error) { - defer func() { - if err != nil { - _ = c.Close() - } - }() + return + } - // Reject if connection is already present. - if mt.conns.Has(c) { - return RejectedError{conn: c, isDuplicate: true} + select { + case mt.peerCh <- info: + case <-mt.ctx.Done(): + // Give up if the transport was closed. + _ = c.Close() + } + }(c) + } } +} - // Resolve ips for incoming conn. 
- ips, err := resolveIPs(mt.resolver, c) - if err != nil { - return err +// processConn handles the raw connection by upgrading it and verifying it +func (mt *MultiplexTransport) processConn(c net.Conn, expectedID types.ID) (peerInfo, error) { + dialAddr := c.RemoteAddr().String() + + // Check if the connection is a duplicate one + if _, exists := mt.activeConns.LoadOrStore(dialAddr, struct{}{}); exists { + return peerInfo{}, errDuplicateConnection } - errc := make(chan error, len(mt.connFilters)) + // Handshake with the peer, through STS + secretConn, nodeInfo, err := mt.upgradeAndVerifyConn(c) + if err != nil { + mt.activeConns.Delete(dialAddr) - for _, f := range mt.connFilters { - go func(f ConnFilterFunc, c net.Conn, ips []net.IP, errc chan<- error) { - errc <- f(mt.conns, c, ips) - }(f, c, ips, errc) + return peerInfo{}, fmt.Errorf("unable to upgrade connection, %w", err) } - for i := 0; i < cap(errc); i++ { - select { - case err := <-errc: - if err != nil { - return RejectedError{conn: c, err: err, isFiltered: true} - } - case <-time.After(mt.filterTimeout): - return FilterTimeoutError{} - } + // Grab the connection ID. + // At this point, the connection and information shared + // with the peer is considered valid, since full handshaking + // and verification took place + id := secretConn.RemotePubKey().Address().ID() + + // The reason the dial ID needs to be verified is because + // for outbound peers (peers the node dials), there is an expected peer ID + // when initializing the outbound connection, that can differ from the exchanged one. 
+ // For inbound peers, the ID is whatever the peer exchanges during the + // handshaking process, and is verified separately + if !expectedID.IsZero() && id.String() != expectedID.String() { + mt.activeConns.Delete(dialAddr) + + return peerInfo{}, fmt.Errorf( + "%w (expected %q got %q)", + errPeerIDDialMismatch, + expectedID, + id, + ) } - mt.conns.Set(c, ips) + netAddr, _ := types.NewNetAddress(id, c.RemoteAddr()) - return nil + return peerInfo{ + addr: netAddr, + conn: secretConn, + nodeInfo: nodeInfo, + }, nil } -func (mt *MultiplexTransport) upgrade( - c net.Conn, - dialedAddr *NetAddress, -) (secretConn *conn.SecretConnection, nodeInfo NodeInfo, err error) { - defer func() { - if err != nil { - _ = mt.cleanup(c) - } - }() +// Remove removes the peer resources from the transport +func (mt *MultiplexTransport) Remove(p PeerConn) { + mt.activeConns.Delete(p.RemoteAddr().String()) +} - secretConn, err = upgradeSecretConn(c, mt.handshakeTimeout, mt.nodeKey.PrivKey) +// upgradeAndVerifyConn upgrades the connections (performs the handshaking process) +// and verifies that the connecting peer is valid +func (mt *MultiplexTransport) upgradeAndVerifyConn(c net.Conn) (secretConn, types.NodeInfo, error) { + // Upgrade to a secret connection. + // A secret connection is a connection that has passed + // an initial handshaking process, as defined by the STS + // protocol, and is considered to be secure and authentic + sc, err := mt.upgradeToSecretConn( + c, + defaultHandshakeTimeout, + mt.nodeKey.PrivKey, + ) if err != nil { - return nil, NodeInfo{}, RejectedError{ - conn: c, - err: fmt.Errorf("secret conn failed: %w", err), - isAuthFailure: true, - } + return nil, types.NodeInfo{}, fmt.Errorf("unable to upgrade p2p connection, %w", err) } - // For outgoing conns, ensure connection key matches dialed key. 
- connID := secretConn.RemotePubKey().Address().ID() - if dialedAddr != nil { - if dialedID := dialedAddr.ID; connID.String() != dialedID.String() { - return nil, NodeInfo{}, RejectedError{ - conn: c, - id: connID, - err: fmt.Errorf( - "conn.ID (%v) dialed ID (%v) mismatch", - connID, - dialedID, - ), - isAuthFailure: true, - } - } - } - - nodeInfo, err = handshake(secretConn, mt.handshakeTimeout, mt.nodeInfo) + // Exchange node information + nodeInfo, err := exchangeNodeInfo(sc, defaultHandshakeTimeout, mt.nodeInfo) if err != nil { - return nil, NodeInfo{}, RejectedError{ - conn: c, - err: fmt.Errorf("handshake failed: %w", err), - isAuthFailure: true, - } + return nil, types.NodeInfo{}, fmt.Errorf("unable to exchange node information, %w", err) } - if err := nodeInfo.Validate(); err != nil { - return nil, NodeInfo{}, RejectedError{ - conn: c, - err: err, - isNodeInfoInvalid: true, - } - } + // Ensure the connection ID matches the node's reported ID + connID := sc.RemotePubKey().Address().ID() - // Ensure connection key matches self reported key. if connID != nodeInfo.ID() { - return nil, NodeInfo{}, RejectedError{ - conn: c, - id: connID, - err: fmt.Errorf( - "conn.ID (%v) NodeInfo.ID (%v) mismatch", - connID, - nodeInfo.ID(), - ), - isAuthFailure: true, - } - } - - // Reject self. 
- if mt.nodeInfo.ID() == nodeInfo.ID() { - return nil, NodeInfo{}, RejectedError{ - addr: *NewNetAddress(nodeInfo.ID(), c.RemoteAddr()), - conn: c, - id: nodeInfo.ID(), - isSelf: true, - } + return nil, types.NodeInfo{}, fmt.Errorf( + "%w (expected %q got %q)", + errPeerIDNodeInfoMismatch, + connID.String(), + nodeInfo.ID().String(), + ) } - if err := mt.nodeInfo.CompatibleWith(nodeInfo); err != nil { - return nil, NodeInfo{}, RejectedError{ - conn: c, - err: err, - id: nodeInfo.ID(), - isIncompatible: true, - } + // Check compatibility with the node + if err = mt.nodeInfo.CompatibleWith(nodeInfo); err != nil { + return nil, types.NodeInfo{}, fmt.Errorf("%w, %w", errIncompatibleNodeInfo, err) } - return secretConn, nodeInfo, nil + return sc, nodeInfo, nil } -func (mt *MultiplexTransport) wrapPeer( - c net.Conn, - ni NodeInfo, - cfg peerConfig, - socketAddr *NetAddress, -) Peer { - persistent := false - if cfg.isPersistent != nil { - if cfg.outbound { - persistent = cfg.isPersistent(socketAddr) - } else { - selfReportedAddr := ni.NetAddress - persistent = cfg.isPersistent(selfReportedAddr) - } +// newMultiplexPeer creates a new multiplex Peer, using +// the provided Peer behavior and info +func (mt *MultiplexTransport) newMultiplexPeer( + info peerInfo, + behavior PeerBehavior, + isOutbound bool, +) (PeerConn, error) { + // Extract the host + host, _, err := net.SplitHostPort(info.conn.RemoteAddr().String()) + if err != nil { + return nil, fmt.Errorf("unable to extract peer host, %w", err) } - peerConn := newPeerConn( - cfg.outbound, - persistent, - c, - socketAddr, - ) + // Look up the IPs + ips, err := net.LookupIP(host) + if err != nil { + return nil, fmt.Errorf("unable to lookup peer IPs, %w", err) + } - p := newPeer( - peerConn, - mt.mConfig, - ni, - cfg.reactorsByCh, - cfg.chDescs, - cfg.onPeerError, - ) + // Wrap the info related to the connection + peerConn := &ConnInfo{ + Outbound: isOutbound, + Persistent: behavior.IsPersistentPeer(info.addr.ID), + 
Private: behavior.IsPrivatePeer(info.nodeInfo.ID()), + Conn: info.conn, + RemoteIP: ips[0], // IPv4 + SocketAddr: info.addr, + } + + // Create the info related to the multiplex connection + mConfig := &ConnConfig{ + MConfig: mt.mConfig, + ReactorsByCh: behavior.Reactors(), + ChDescs: behavior.ReactorChDescriptors(), + OnPeerError: behavior.HandlePeerError, + } - return p + return newPeer(peerConn, info.nodeInfo, mConfig), nil } -func handshake( - c net.Conn, +// exchangeNodeInfo performs a data swap, where node +// info is exchanged between the current node and a peer async +func exchangeNodeInfo( + c secretConn, timeout time.Duration, - nodeInfo NodeInfo, -) (NodeInfo, error) { + nodeInfo types.NodeInfo, +) (types.NodeInfo, error) { if err := c.SetDeadline(time.Now().Add(timeout)); err != nil { - return NodeInfo{}, err + return types.NodeInfo{}, err } var ( - errc = make(chan error, 2) - - peerNodeInfo NodeInfo + peerNodeInfo types.NodeInfo ourNodeInfo = nodeInfo ) - go func(errc chan<- error, c net.Conn) { + g, _ := errgroup.WithContext(context.Background()) + + g.Go(func() error { _, err := amino.MarshalSizedWriter(c, ourNodeInfo) - errc <- err - }(errc, c) - go func(errc chan<- error, c net.Conn) { + + return err + }) + + g.Go(func() error { _, err := amino.UnmarshalSizedReader( c, &peerNodeInfo, - int64(MaxNodeInfoSize()), + types.MaxNodeInfoSize, ) - errc <- err - }(errc, c) - for i := 0; i < cap(errc); i++ { - err := <-errc - if err != nil { - return NodeInfo{}, err - } + return err + }) + + if err := g.Wait(); err != nil { + return types.NodeInfo{}, err + } + + // Validate the received node information + if err := nodeInfo.Validate(); err != nil { + return types.NodeInfo{}, fmt.Errorf("unable to validate node info, %w", err) } return peerNodeInfo, c.SetDeadline(time.Time{}) } -func upgradeSecretConn( +// upgradeToSecretConn takes an active TCP connection, +// and upgrades it to a verified, handshaked connection through +// the STS protocol +func (mt 
*MultiplexTransport) upgradeToSecretConn( c net.Conn, timeout time.Duration, privKey crypto.PrivKey, -) (*conn.SecretConnection, error) { +) (secretConn, error) { if err := c.SetDeadline(time.Now().Add(timeout)); err != nil { return nil, err } - sc, err := conn.MakeSecretConnection(c, privKey) + // Handshake (STS) + sc, err := mt.connUpgradeFn(c, privKey) if err != nil { return nil, err } return sc, sc.SetDeadline(time.Time{}) } - -func resolveIPs(resolver IPResolver, c net.Conn) ([]net.IP, error) { - host, _, err := net.SplitHostPort(c.RemoteAddr().String()) - if err != nil { - return nil, err - } - - addrs, err := resolver.LookupIPAddr(context.Background(), host) - if err != nil { - return nil, err - } - - ips := []net.IP{} - - for _, addr := range addrs { - ips = append(ips, addr.IP) - } - - return ips, nil -} diff --git a/tm2/pkg/p2p/transport_test.go b/tm2/pkg/p2p/transport_test.go index 63b1c26e666..3eb3264ec2b 100644 --- a/tm2/pkg/p2p/transport_test.go +++ b/tm2/pkg/p2p/transport_test.go @@ -1,650 +1,519 @@ package p2p import ( + "context" "fmt" - "math/rand" "net" - "reflect" "testing" "time" - "github.com/gnolang/gno/tm2/pkg/amino" - "github.com/gnolang/gno/tm2/pkg/crypto/ed25519" + "github.com/gnolang/gno/tm2/pkg/log" "github.com/gnolang/gno/tm2/pkg/p2p/conn" - "github.com/gnolang/gno/tm2/pkg/testutils" + "github.com/gnolang/gno/tm2/pkg/p2p/types" + "github.com/gnolang/gno/tm2/pkg/versionset" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -var defaultNodeName = "host_peer" - -func emptyNodeInfo() NodeInfo { - return NodeInfo{} -} - -// newMultiplexTransport returns a tcp connected multiplexed peer -// using the default MConnConfig. It's a convenience function used -// for testing. 
-func newMultiplexTransport( - nodeInfo NodeInfo, - nodeKey NodeKey, -) *MultiplexTransport { - return NewMultiplexTransport( - nodeInfo, nodeKey, conn.DefaultMConnConfig(), - ) -} - -func TestTransportMultiplexConnFilter(t *testing.T) { - t.Parallel() - - mt := newMultiplexTransport( - emptyNodeInfo(), - NodeKey{ - PrivKey: ed25519.GenPrivKey(), - }, - ) - id := mt.nodeKey.ID() - - MultiplexTransportConnFilters( - func(_ ConnSet, _ net.Conn, _ []net.IP) error { return nil }, - func(_ ConnSet, _ net.Conn, _ []net.IP) error { return nil }, - func(_ ConnSet, _ net.Conn, _ []net.IP) error { - return fmt.Errorf("rejected") - }, - )(mt) - - addr, err := NewNetAddressFromString(NetAddressString(id, "127.0.0.1:0")) - if err != nil { - t.Fatal(err) - } - - if err := mt.Listen(*addr); err != nil { - t.Fatal(err) - } +// generateNetAddr generates dummy net addresses +func generateNetAddr(t *testing.T, count int) []*types.NetAddress { + t.Helper() - errc := make(chan error) + addrs := make([]*types.NetAddress, 0, count) - go func() { - addr := NewNetAddress(id, mt.listener.Addr()) + for i := 0; i < count; i++ { + key := types.GenerateNodeKey() - _, err := addr.Dial() - if err != nil { - errc <- err - return - } + // Grab a random port + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) - close(errc) - }() + addr, err := types.NewNetAddress(key.ID(), ln.Addr()) + require.NoError(t, err) - if err := <-errc; err != nil { - t.Errorf("connection failed: %v", err) + addrs = append(addrs, addr) } - _, err = mt.Accept(peerConfig{}) - if err, ok := err.(RejectedError); ok { - if !err.IsFiltered() { - t.Errorf("expected peer to be filtered") - } - } else { - t.Errorf("expected RejectedError") - } + return addrs } -func TestTransportMultiplexConnFilterTimeout(t *testing.T) { +func TestMultiplexTransport_NetAddress(t *testing.T) { t.Parallel() - mt := newMultiplexTransport( - emptyNodeInfo(), - NodeKey{ - PrivKey: ed25519.GenPrivKey(), - }, - ) - id := mt.nodeKey.ID() 
- - MultiplexTransportFilterTimeout(5 * time.Millisecond)(mt) - MultiplexTransportConnFilters( - func(_ ConnSet, _ net.Conn, _ []net.IP) error { - time.Sleep(100 * time.Millisecond) - return nil - }, - )(mt) - - addr, err := NewNetAddressFromString(NetAddressString(id, "127.0.0.1:0")) - if err != nil { - t.Fatal(err) - } - - if err := mt.Listen(*addr); err != nil { - t.Fatal(err) - } - - errc := make(chan error) + t.Run("transport not active", func(t *testing.T) { + t.Parallel() - go func() { - addr := NewNetAddress(id, mt.listener.Addr()) - - _, err := addr.Dial() - if err != nil { - errc <- err - return - } - - close(errc) - }() - - if err := <-errc; err != nil { - t.Errorf("connection failed: %v", err) - } - - _, err = mt.Accept(peerConfig{}) - if _, ok := err.(FilterTimeoutError); !ok { - t.Errorf("expected FilterTimeoutError") - } -} - -func TestTransportMultiplexAcceptMultiple(t *testing.T) { - t.Parallel() + var ( + ni = types.NodeInfo{} + nk = types.NodeKey{} + mCfg = conn.DefaultMConnConfig() + logger = log.NewNoopLogger() + ) - mt := testSetupMultiplexTransport(t) - laddr := NewNetAddress(mt.nodeKey.ID(), mt.listener.Addr()) + transport := NewMultiplexTransport(ni, nk, mCfg, logger) + addr := transport.NetAddress() - var ( - seed = rand.New(rand.NewSource(time.Now().UnixNano())) - nDialers = seed.Intn(64) + 64 - errc = make(chan error, nDialers) - ) + assert.Error(t, addr.Validate()) + }) - // Setup dialers. - for i := 0; i < nDialers; i++ { - go testDialer(*laddr, errc) - } + t.Run("active transport on random port", func(t *testing.T) { + t.Parallel() - // Catch connection errors. - for i := 0; i < nDialers; i++ { - if err := <-errc; err != nil { - t.Fatal(err) - } - } + var ( + ni = types.NodeInfo{} + nk = types.NodeKey{} + mCfg = conn.DefaultMConnConfig() + logger = log.NewNoopLogger() + addr = generateNetAddr(t, 1)[0] + ) - ps := []Peer{} + addr.Port = 0 // random port - // Accept all peers. 
- for i := 0; i < cap(errc); i++ { - p, err := mt.Accept(peerConfig{}) - if err != nil { - t.Fatal(err) - } + transport := NewMultiplexTransport(ni, nk, mCfg, logger) - if err := p.Start(); err != nil { - t.Fatal(err) - } + require.NoError(t, transport.Listen(*addr)) + defer func() { + require.NoError(t, transport.Close()) + }() - ps = append(ps, p) - } + netAddr := transport.NetAddress() + assert.False(t, netAddr.Equals(*addr)) + assert.NoError(t, netAddr.Validate()) + }) - if have, want := len(ps), cap(errc); have != want { - t.Errorf("have %v, want %v", have, want) - } + t.Run("active transport on specific port", func(t *testing.T) { + t.Parallel() - // Stop all peers. - for _, p := range ps { - if err := p.Stop(); err != nil { - t.Fatal(err) - } - } + var ( + ni = types.NodeInfo{} + nk = types.NodeKey{} + mCfg = conn.DefaultMConnConfig() + logger = log.NewNoopLogger() + addr = generateNetAddr(t, 1)[0] + ) - if err := mt.Close(); err != nil { - t.Errorf("close errored: %v", err) - } -} + addr.Port = 4123 // specific port -func testDialer(dialAddr NetAddress, errc chan error) { - var ( - pv = ed25519.GenPrivKey() - dialer = newMultiplexTransport( - testNodeInfo(pv.PubKey().Address().ID(), defaultNodeName), - NodeKey{ - PrivKey: pv, - }, - ) - ) + transport := NewMultiplexTransport(ni, nk, mCfg, logger) - _, err := dialer.Dial(dialAddr, peerConfig{}) - if err != nil { - errc <- err - return - } + require.NoError(t, transport.Listen(*addr)) + defer func() { + require.NoError(t, transport.Close()) + }() - // Signal that the connection was established. 
- errc <- nil + netAddr := transport.NetAddress() + assert.True(t, netAddr.Equals(*addr)) + assert.NoError(t, netAddr.Validate()) + }) } -func TestFlappyTransportMultiplexAcceptNonBlocking(t *testing.T) { +func TestMultiplexTransport_Accept(t *testing.T) { t.Parallel() - testutils.FilterStability(t, testutils.Flappy) + t.Run("inactive transport", func(t *testing.T) { + t.Parallel() - mt := testSetupMultiplexTransport(t) - - var ( - fastNodePV = ed25519.GenPrivKey() - fastNodeInfo = testNodeInfo(fastNodePV.PubKey().Address().ID(), "fastnode") - errc = make(chan error) - fastc = make(chan struct{}) - slowc = make(chan struct{}) - ) - - // Simulate slow Peer. - go func() { - addr := NewNetAddress(mt.nodeKey.ID(), mt.listener.Addr()) - - c, err := addr.Dial() - if err != nil { - errc <- err - return - } - - close(slowc) + var ( + ni = types.NodeInfo{} + nk = types.NodeKey{} + mCfg = conn.DefaultMConnConfig() + logger = log.NewNoopLogger() + ) - select { - case <-fastc: - // Fast peer connected. - case <-time.After(100 * time.Millisecond): - // We error if the fast peer didn't succeed. - errc <- fmt.Errorf("Fast peer timed out") - } + transport := NewMultiplexTransport(ni, nk, mCfg, logger) - sc, err := upgradeSecretConn(c, 100*time.Millisecond, ed25519.GenPrivKey()) - if err != nil { - errc <- err - return - } + p, err := transport.Accept(context.Background(), nil) - _, err = handshake(sc, 100*time.Millisecond, - testNodeInfo( - ed25519.GenPrivKey().PubKey().Address().ID(), - "slow_peer", - )) - if err != nil { - errc <- err - return - } - }() + assert.Nil(t, p) + assert.ErrorIs( + t, + err, + errTransportInactive, + ) + }) - // Simulate fast Peer. 
- go func() { - <-slowc + t.Run("transport closed", func(t *testing.T) { + t.Parallel() - dialer := newMultiplexTransport( - fastNodeInfo, - NodeKey{ - PrivKey: fastNodePV, - }, + var ( + ni = types.NodeInfo{} + nk = types.NodeKey{} + mCfg = conn.DefaultMConnConfig() + logger = log.NewNoopLogger() + addr = generateNetAddr(t, 1)[0] ) - addr := NewNetAddress(mt.nodeKey.ID(), mt.listener.Addr()) - _, err := dialer.Dial(*addr, peerConfig{}) - if err != nil { - errc <- err - return - } + addr.Port = 0 - close(errc) - close(fastc) - }() + transport := NewMultiplexTransport(ni, nk, mCfg, logger) - if err := <-errc; err != nil { - t.Errorf("connection failed: %v", err) - } + // Start the transport + require.NoError(t, transport.Listen(*addr)) - p, err := mt.Accept(peerConfig{}) - if err != nil { - t.Fatal(err) - } + // Stop the transport + require.NoError(t, transport.Close()) - if have, want := p.NodeInfo(), fastNodeInfo; !reflect.DeepEqual(have, want) { - t.Errorf("have %v, want %v", have, want) - } -} - -func TestTransportMultiplexValidateNodeInfo(t *testing.T) { - t.Parallel() + p, err := transport.Accept(context.Background(), nil) - mt := testSetupMultiplexTransport(t) + assert.Nil(t, p) + assert.ErrorIs( + t, + err, + errTransportClosed, + ) + }) - errc := make(chan error) + t.Run("context canceled", func(t *testing.T) { + t.Parallel() - go func() { var ( - pv = ed25519.GenPrivKey() - dialer = newMultiplexTransport( - testNodeInfo(pv.PubKey().Address().ID(), ""), // Should not be empty - NodeKey{ - PrivKey: pv, - }, - ) + ni = types.NodeInfo{} + nk = types.NodeKey{} + mCfg = conn.DefaultMConnConfig() + logger = log.NewNoopLogger() + addr = generateNetAddr(t, 1)[0] ) - addr := NewNetAddress(mt.nodeKey.ID(), mt.listener.Addr()) + addr.Port = 0 - _, err := dialer.Dial(*addr, peerConfig{}) - if err != nil { - errc <- err - return - } + transport := NewMultiplexTransport(ni, nk, mCfg, logger) - close(errc) - }() + // Start the transport + require.NoError(t, 
transport.Listen(*addr)) - if err := <-errc; err != nil { - t.Errorf("connection failed: %v", err) - } + ctx, cancelFn := context.WithCancel(context.Background()) + cancelFn() - _, err := mt.Accept(peerConfig{}) - if err, ok := err.(RejectedError); ok { - if !err.IsNodeInfoInvalid() { - t.Errorf("expected NodeInfo to be invalid") - } - } else { - t.Errorf("expected RejectedError") - } -} + p, err := transport.Accept(ctx, nil) -func TestTransportMultiplexRejectMismatchID(t *testing.T) { - t.Parallel() + assert.Nil(t, p) + assert.ErrorIs( + t, + err, + context.Canceled, + ) + }) - mt := testSetupMultiplexTransport(t) + t.Run("peer ID mismatch", func(t *testing.T) { + t.Parallel() - errc := make(chan error) + var ( + network = "dev" + mCfg = conn.DefaultMConnConfig() + logger = log.NewNoopLogger() + keys = []*types.NodeKey{ + types.GenerateNodeKey(), + types.GenerateNodeKey(), + } - go func() { - dialer := newMultiplexTransport( - testNodeInfo( - ed25519.GenPrivKey().PubKey().Address().ID(), "dialer", - ), - NodeKey{ - PrivKey: ed25519.GenPrivKey(), - }, + peerBehavior = &reactorPeerBehavior{ + chDescs: make([]*conn.ChannelDescriptor, 0), + reactorsByCh: make(map[byte]Reactor), + handlePeerErrFn: func(_ PeerConn, err error) { + require.NoError(t, err) + }, + isPersistentPeerFn: func(_ types.ID) bool { + return false + }, + isPrivatePeerFn: func(_ types.ID) bool { + return false + }, + } ) - addr := NewNetAddress(mt.nodeKey.ID(), mt.listener.Addr()) - _, err := dialer.Dial(*addr, peerConfig{}) - if err != nil { - errc <- err - return - } + peers := make([]*MultiplexTransport, 0, len(keys)) - close(errc) - }() + for index, key := range keys { + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + require.NoError(t, err) - if err := <-errc; err != nil { - t.Errorf("connection failed: %v", err) - } + id := key.ID() - _, err := mt.Accept(peerConfig{}) - if err, ok := err.(RejectedError); ok { - if !err.IsAuthFailure() { - t.Errorf("expected auth failure") - } - } else { 
- t.Errorf("expected RejectedError") - } -} + if index%1 == 0 { + // Hijack the key value + id = types.GenerateNodeKey().ID() + } -func TestTransportMultiplexDialRejectWrongID(t *testing.T) { - t.Parallel() + na, err := types.NewNetAddress(id, addr) + require.NoError(t, err) - mt := testSetupMultiplexTransport(t) + ni := types.NodeInfo{ + Network: network, // common network + PeerID: id, + Version: "v1.0.0-rc.0", + Moniker: fmt.Sprintf("node-%d", index), + VersionSet: make(versionset.VersionSet, 0), // compatible version set + Channels: []byte{42}, // common channel + } - var ( - pv = ed25519.GenPrivKey() - dialer = newMultiplexTransport( - testNodeInfo(pv.PubKey().Address().ID(), ""), // Should not be empty - NodeKey{ - PrivKey: pv, - }, - ) - ) + // Create a fresh transport + tr := NewMultiplexTransport(ni, *key, mCfg, logger) - wrongID := ed25519.GenPrivKey().PubKey().Address().ID() - addr := NewNetAddress(wrongID, mt.listener.Addr()) + // Start the transport + require.NoError(t, tr.Listen(*na)) - _, err := dialer.Dial(*addr, peerConfig{}) - if err != nil { - t.Logf("connection failed: %v", err) - if err, ok := err.(RejectedError); ok { - if !err.IsAuthFailure() { - t.Errorf("expected auth failure") - } - } else { - t.Errorf("expected RejectedError") + t.Cleanup(func() { + assert.NoError(t, tr.Close()) + }) + + peers = append( + peers, + tr, + ) } - } -} -func TestTransportMultiplexRejectIncompatible(t *testing.T) { - t.Parallel() + // Make peer 1 --dial--> peer 2, and handshake. 
+ // This "upgrade" should fail because the peer shared a different + // peer ID than what they actually used for the secret connection + ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFn() - mt := testSetupMultiplexTransport(t) + p, err := peers[0].Dial(ctx, peers[1].netAddr, peerBehavior) + assert.ErrorIs(t, err, errPeerIDNodeInfoMismatch) + require.Nil(t, p) + }) - errc := make(chan error) + t.Run("incompatible peers", func(t *testing.T) { + t.Parallel() - go func() { var ( - pv = ed25519.GenPrivKey() - dialer = newMultiplexTransport( - testNodeInfoWithNetwork(pv.PubKey().Address().ID(), "dialer", "incompatible-network"), - NodeKey{ - PrivKey: pv, + network = "dev" + mCfg = conn.DefaultMConnConfig() + logger = log.NewNoopLogger() + keys = []*types.NodeKey{ + types.GenerateNodeKey(), + types.GenerateNodeKey(), + } + + peerBehavior = &reactorPeerBehavior{ + chDescs: make([]*conn.ChannelDescriptor, 0), + reactorsByCh: make(map[byte]Reactor), + handlePeerErrFn: func(_ PeerConn, err error) { + require.NoError(t, err) }, - ) + isPersistentPeerFn: func(_ types.ID) bool { + return false + }, + isPrivatePeerFn: func(_ types.ID) bool { + return false + }, + } ) - addr := NewNetAddress(mt.nodeKey.ID(), mt.listener.Addr()) - _, err := dialer.Dial(*addr, peerConfig{}) - if err != nil { - errc <- err - return - } + peers := make([]*MultiplexTransport, 0, len(keys)) - close(errc) - }() + for index, key := range keys { + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + require.NoError(t, err) - _, err := mt.Accept(peerConfig{}) - if err, ok := err.(RejectedError); ok { - if !err.IsIncompatible() { - t.Errorf("expected to reject incompatible") - } - } else { - t.Errorf("expected RejectedError") - } -} + id := key.ID() -func TestTransportMultiplexRejectSelf(t *testing.T) { - t.Parallel() + na, err := types.NewNetAddress(id, addr) + require.NoError(t, err) - mt := testSetupMultiplexTransport(t) + chainID := network - errc := 
make(chan error) + if index%2 == 0 { + chainID = "totally-random-network" + } - go func() { - addr := NewNetAddress(mt.nodeKey.ID(), mt.listener.Addr()) + ni := types.NodeInfo{ + Network: chainID, + PeerID: id, + Version: "v1.0.0-rc.0", + Moniker: fmt.Sprintf("node-%d", index), + VersionSet: make(versionset.VersionSet, 0), // compatible version set + Channels: []byte{42}, // common channel + } - _, err := mt.Dial(*addr, peerConfig{}) - if err != nil { - errc <- err - return - } + // Create a fresh transport + tr := NewMultiplexTransport(ni, *key, mCfg, logger) - close(errc) - }() + // Start the transport + require.NoError(t, tr.Listen(*na)) - if err := <-errc; err != nil { - if err, ok := err.(RejectedError); ok { - if !err.IsSelf() { - t.Errorf("expected to reject self, got: %v", err) - } - } else { - t.Errorf("expected RejectedError") - } - } else { - t.Errorf("expected connection failure") - } + t.Cleanup(func() { + assert.NoError(t, tr.Close()) + }) - _, err := mt.Accept(peerConfig{}) - if err, ok := err.(RejectedError); ok { - if !err.IsSelf() { - t.Errorf("expected to reject self, got: %v", err) + peers = append( + peers, + tr, + ) } - } else { - t.Errorf("expected RejectedError") - } -} -func TestTransportConnDuplicateIPFilter(t *testing.T) { - t.Parallel() + // Make peer 1 --dial--> peer 2, and handshake. 
+ // This "upgrade" should fail because the peer shared a different + // peer ID than what they actually used for the secret connection + ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFn() - filter := ConnDuplicateIPFilter() + p, err := peers[0].Dial(ctx, peers[1].netAddr, peerBehavior) + assert.ErrorIs(t, err, errIncompatibleNodeInfo) + require.Nil(t, p) + }) - if err := filter(nil, &testTransportConn{}, nil); err != nil { - t.Fatal(err) - } + t.Run("dialed peer ID mismatch", func(t *testing.T) { + t.Parallel() - var ( - c = &testTransportConn{} - cs = NewConnSet() - ) + var ( + network = "dev" + mCfg = conn.DefaultMConnConfig() + logger = log.NewNoopLogger() + keys = []*types.NodeKey{ + types.GenerateNodeKey(), + types.GenerateNodeKey(), + } - cs.Set(c, []net.IP{ - {10, 0, 10, 1}, - {10, 0, 10, 2}, - {10, 0, 10, 3}, - }) + peerBehavior = &reactorPeerBehavior{ + chDescs: make([]*conn.ChannelDescriptor, 0), + reactorsByCh: make(map[byte]Reactor), + handlePeerErrFn: func(_ PeerConn, err error) { + require.NoError(t, err) + }, + isPersistentPeerFn: func(_ types.ID) bool { + return false + }, + isPrivatePeerFn: func(_ types.ID) bool { + return false + }, + } + ) - if err := filter(cs, c, []net.IP{ - {10, 0, 10, 2}, - }); err == nil { - t.Errorf("expected Peer to be rejected as duplicate") - } -} + peers := make([]*MultiplexTransport, 0, len(keys)) -func TestTransportHandshake(t *testing.T) { - t.Parallel() + for index, key := range keys { + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + require.NoError(t, err) - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } + na, err := types.NewNetAddress(key.ID(), addr) + require.NoError(t, err) - var ( - peerPV = ed25519.GenPrivKey() - peerNodeInfo = testNodeInfo(peerPV.PubKey().Address().ID(), defaultNodeName) - ) + ni := types.NodeInfo{ + Network: network, // common network + PeerID: key.ID(), + Version: "v1.0.0-rc.0", + Moniker: 
fmt.Sprintf("node-%d", index), + VersionSet: make(versionset.VersionSet, 0), // compatible version set + Channels: []byte{42}, // common channel + } - go func() { - c, err := net.Dial(ln.Addr().Network(), ln.Addr().String()) - if err != nil { - t.Error(err) - return - } + // Create a fresh transport + tr := NewMultiplexTransport(ni, *key, mCfg, logger) - go func(c net.Conn) { - _, err := amino.MarshalSizedWriter(c, peerNodeInfo) - if err != nil { - t.Error(err) - } - }(c) - go func(c net.Conn) { - var ni NodeInfo - - _, err := amino.UnmarshalSizedReader( - c, - &ni, - int64(MaxNodeInfoSize()), + // Start the transport + require.NoError(t, tr.Listen(*na)) + + t.Cleanup(func() { + assert.NoError(t, tr.Close()) + }) + + peers = append( + peers, + tr, ) - if err != nil { - t.Error(err) - } - }(c) - }() + } - c, err := ln.Accept() - if err != nil { - t.Fatal(err) - } + // Make peer 1 --dial--> peer 2, and handshake + ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFn() - ni, err := handshake(c, 100*time.Millisecond, emptyNodeInfo()) - if err != nil { - t.Fatal(err) - } + p, err := peers[0].Dial( + ctx, + types.NetAddress{ + ID: types.GenerateNodeKey().ID(), // mismatched ID + IP: peers[1].netAddr.IP, + Port: peers[1].netAddr.Port, + }, + peerBehavior, + ) + assert.ErrorIs(t, err, errPeerIDDialMismatch) + assert.Nil(t, p) + }) - if have, want := ni, peerNodeInfo; !reflect.DeepEqual(have, want) { - t.Errorf("have %v, want %v", have, want) - } -} + t.Run("valid peer accepted", func(t *testing.T) { + t.Parallel() -// create listener -func testSetupMultiplexTransport(t *testing.T) *MultiplexTransport { - t.Helper() + var ( + network = "dev" + mCfg = conn.DefaultMConnConfig() + logger = log.NewNoopLogger() + keys = []*types.NodeKey{ + types.GenerateNodeKey(), + types.GenerateNodeKey(), + } - var ( - pv = ed25519.GenPrivKey() - id = pv.PubKey().Address().ID() - mt = newMultiplexTransport( - testNodeInfo( - id, "transport", - ), - 
NodeKey{ - PrivKey: pv, - }, + peerBehavior = &reactorPeerBehavior{ + chDescs: make([]*conn.ChannelDescriptor, 0), + reactorsByCh: make(map[byte]Reactor), + handlePeerErrFn: func(_ PeerConn, err error) { + require.NoError(t, err) + }, + isPersistentPeerFn: func(_ types.ID) bool { + return false + }, + isPrivatePeerFn: func(_ types.ID) bool { + return false + }, + } ) - ) - addr, err := NewNetAddressFromString(NetAddressString(id, "127.0.0.1:0")) - if err != nil { - t.Fatal(err) - } + peers := make([]*MultiplexTransport, 0, len(keys)) - if err := mt.Listen(*addr); err != nil { - t.Fatal(err) - } + for index, key := range keys { + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + require.NoError(t, err) - return mt -} + na, err := types.NewNetAddress(key.ID(), addr) + require.NoError(t, err) -type testTransportAddr struct{} + ni := types.NodeInfo{ + Network: network, // common network + PeerID: key.ID(), + Version: "v1.0.0-rc.0", + Moniker: fmt.Sprintf("node-%d", index), + VersionSet: make(versionset.VersionSet, 0), // compatible version set + Channels: []byte{42}, // common channel + } -func (a *testTransportAddr) Network() string { return "tcp" } -func (a *testTransportAddr) String() string { return "test.local:1234" } + // Create a fresh transport + tr := NewMultiplexTransport(ni, *key, mCfg, logger) -type testTransportConn struct{} + // Start the transport + require.NoError(t, tr.Listen(*na)) -func (c *testTransportConn) Close() error { - return fmt.Errorf("Close() not implemented") -} + t.Cleanup(func() { + assert.NoError(t, tr.Close()) + }) -func (c *testTransportConn) LocalAddr() net.Addr { - return &testTransportAddr{} -} + peers = append( + peers, + tr, + ) + } -func (c *testTransportConn) RemoteAddr() net.Addr { - return &testTransportAddr{} -} + // Make peer 1 --dial--> peer 2, and handshake + ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFn() -func (c *testTransportConn) Read(_ []byte) (int, error) { - 
return -1, fmt.Errorf("Read() not implemented") -} + p, err := peers[0].Dial(ctx, peers[1].netAddr, peerBehavior) + require.NoError(t, err) + require.NotNil(t, p) -func (c *testTransportConn) SetDeadline(_ time.Time) error { - return fmt.Errorf("SetDeadline() not implemented") -} + // Make sure the new peer info is valid + assert.Equal(t, peers[1].netAddr.ID, p.ID()) -func (c *testTransportConn) SetReadDeadline(_ time.Time) error { - return fmt.Errorf("SetReadDeadline() not implemented") -} + assert.Equal(t, peers[1].nodeInfo.Channels, p.NodeInfo().Channels) + assert.Equal(t, peers[1].nodeInfo.Moniker, p.NodeInfo().Moniker) + assert.Equal(t, peers[1].nodeInfo.Network, p.NodeInfo().Network) -func (c *testTransportConn) SetWriteDeadline(_ time.Time) error { - return fmt.Errorf("SetWriteDeadline() not implemented") -} + // Attempt to dial again, expect the dial to fail + // because the connection is already active + dialedPeer, err := peers[0].Dial(ctx, peers[1].netAddr, peerBehavior) + require.ErrorIs(t, err, errDuplicateConnection) + assert.Nil(t, dialedPeer) -func (c *testTransportConn) Write(_ []byte) (int, error) { - return -1, fmt.Errorf("Write() not implemented") + // Remove the peer + peers[0].Remove(p) + }) } diff --git a/tm2/pkg/p2p/types.go b/tm2/pkg/p2p/types.go index 150325f52bb..d206a5af662 100644 --- a/tm2/pkg/p2p/types.go +++ b/tm2/pkg/p2p/types.go @@ -1,10 +1,115 @@ package p2p import ( + "context" + "net" + "github.com/gnolang/gno/tm2/pkg/p2p/conn" + "github.com/gnolang/gno/tm2/pkg/p2p/events" + "github.com/gnolang/gno/tm2/pkg/p2p/types" + "github.com/gnolang/gno/tm2/pkg/service" ) type ( ChannelDescriptor = conn.ChannelDescriptor ConnectionStatus = conn.ConnectionStatus ) + +// PeerConn is a wrapper for a connected peer +type PeerConn interface { + service.Service + + FlushStop() + + ID() types.ID // peer's cryptographic ID + RemoteIP() net.IP // remote IP of the connection + RemoteAddr() net.Addr // remote address of the connection + + IsOutbound() 
bool // did we dial the peer + IsPersistent() bool // do we redial this peer when we disconnect + IsPrivate() bool // do we share the peer + + CloseConn() error // close original connection + + NodeInfo() types.NodeInfo // peer's info + Status() ConnectionStatus + SocketAddr() *types.NetAddress // actual address of the socket + + Send(byte, []byte) bool + TrySend(byte, []byte) bool + + Set(string, any) + Get(string) any +} + +// PeerSet has a (immutable) subset of the methods of PeerSet. +type PeerSet interface { + Add(peer PeerConn) + Remove(key types.ID) bool + Has(key types.ID) bool + Get(key types.ID) PeerConn + List() []PeerConn + + NumInbound() uint64 // returns the number of connected inbound nodes + NumOutbound() uint64 // returns the number of connected outbound nodes +} + +// Transport handles peer dialing and connection acceptance. Additionally, +// it is also responsible for any custom connection mechanisms (like handshaking). +// Peers returned by the transport are considered to be verified and sound +type Transport interface { + // NetAddress returns the Transport's dial address + NetAddress() types.NetAddress + + // Accept returns a newly connected inbound peer + Accept(context.Context, PeerBehavior) (PeerConn, error) + + // Dial dials a peer, and returns it + Dial(context.Context, types.NetAddress, PeerBehavior) (PeerConn, error) + + // Remove drops any resources associated + // with the PeerConn in the transport + Remove(PeerConn) +} + +// Switch is the abstraction in the p2p module that handles +// and manages peer connections thorough a Transport +type Switch interface { + // Broadcast publishes data on the given channel, to all peers + Broadcast(chID byte, data []byte) + + // Peers returns the latest peer set + Peers() PeerSet + + // Subscribe subscribes to active switch events + Subscribe(filterFn events.EventFilter) (<-chan events.Event, func()) + + // StopPeerForError stops the peer with the given reason + StopPeerForError(peer PeerConn, err 
error) + + // DialPeers marks the given peers as ready for async dialing + DialPeers(peerAddrs ...*types.NetAddress) +} + +// PeerBehavior wraps the Reactor and MultiplexSwitch information a Transport would need when +// dialing or accepting new Peer connections. +// It is worth noting that the only reason why this information is required in the first place, +// is because Peers expose an API through which different TM modules can interact with them. +// In the future™, modules should not directly "Send" anything to Peers, but instead communicate through +// other mediums, such as the P2P module +type PeerBehavior interface { + // ReactorChDescriptors returns the Reactor channel descriptors + ReactorChDescriptors() []*conn.ChannelDescriptor + + // Reactors returns the node's active p2p Reactors (modules) + Reactors() map[byte]Reactor + + // HandlePeerError propagates a peer connection error for further processing + HandlePeerError(PeerConn, error) + + // IsPersistentPeer returns a flag indicating if the given peer is persistent + IsPersistentPeer(types.ID) bool + + // IsPrivatePeer returns a flag indicating if the given peer is private + IsPrivatePeer(types.ID) bool +} diff --git a/tm2/pkg/p2p/types/key.go b/tm2/pkg/p2p/types/key.go new file mode 100644 index 00000000000..bc45de709d8 --- /dev/null +++ b/tm2/pkg/p2p/types/key.go @@ -0,0 +1,113 @@ +package types + +import ( + "fmt" + "os" + + "github.com/gnolang/gno/tm2/pkg/amino" + "github.com/gnolang/gno/tm2/pkg/crypto" + "github.com/gnolang/gno/tm2/pkg/crypto/ed25519" + osm "github.com/gnolang/gno/tm2/pkg/os" +) + +// ID represents the cryptographically unique Peer ID +type ID = crypto.ID + +// NewIDFromStrings returns an array of ID's build using +// the provided strings +func NewIDFromStrings(idStrs []string) ([]ID, []error) { + var ( + ids = make([]ID, 0, len(idStrs)) + errs = make([]error, 0, len(idStrs)) + ) + + for _, idStr := range idStrs { + id := ID(idStr) + if err := id.Validate(); err != nil { + errs = 
append(errs, err) + + continue + } + + ids = append(ids, id) + } + + return ids, errs +} + +// NodeKey is the persistent peer key. +// It contains the nodes private key for authentication. +// NOTE: keep in sync with gno.land/cmd/gnoland/secrets.go +type NodeKey struct { + crypto.PrivKey `json:"priv_key"` // our priv key +} + +// ID returns the bech32 representation +// of the node's public p2p key, with +// the bech32 prefix +func (k NodeKey) ID() ID { + return k.PubKey().Address().ID() +} + +// LoadOrGenNodeKey attempts to load the NodeKey from the given filePath. +// If the file does not exist, it generates and saves a new NodeKey. +func LoadOrGenNodeKey(path string) (*NodeKey, error) { + // Check if the key exists + if osm.FileExists(path) { + // Load the node key + return LoadNodeKey(path) + } + + // Key is not present on path, + // generate a fresh one + nodeKey := GenerateNodeKey() + if err := saveNodeKey(path, nodeKey); err != nil { + return nil, fmt.Errorf("unable to save node key, %w", err) + } + + return nodeKey, nil +} + +// LoadNodeKey loads the node key from the given path +func LoadNodeKey(path string) (*NodeKey, error) { + // Load the key + jsonBytes, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("unable to read key, %w", err) + } + + var nodeKey NodeKey + + // Parse the key + if err = amino.UnmarshalJSON(jsonBytes, &nodeKey); err != nil { + return nil, fmt.Errorf("unable to JSON unmarshal node key, %w", err) + } + + return &nodeKey, nil +} + +// GenerateNodeKey generates a random +// node P2P key, based on ed25519 +func GenerateNodeKey() *NodeKey { + privKey := ed25519.GenPrivKey() + + return &NodeKey{ + PrivKey: privKey, + } +} + +// saveNodeKey saves the node key +func saveNodeKey(path string, nodeKey *NodeKey) error { + // Get Amino JSON + marshalledData, err := amino.MarshalJSONIndent(nodeKey, "", "\t") + if err != nil { + return fmt.Errorf("unable to marshal node key into JSON, %w", err) + } + + // Save the data to disk + 
if err := os.WriteFile(path, marshalledData, 0o644); err != nil { + return fmt.Errorf("unable to save node key to path, %w", err) + } + + return nil +} diff --git a/tm2/pkg/p2p/types/key_test.go b/tm2/pkg/p2p/types/key_test.go new file mode 100644 index 00000000000..5dc153b08c0 --- /dev/null +++ b/tm2/pkg/p2p/types/key_test.go @@ -0,0 +1,158 @@ +package types + +import ( + "encoding/json" + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// generateKeys generates random node p2p keys +func generateKeys(t *testing.T, count int) []*NodeKey { + t.Helper() + + keys := make([]*NodeKey, count) + + for i := 0; i < count; i++ { + keys[i] = GenerateNodeKey() + } + + return keys +} + +func TestNodeKey_Generate(t *testing.T) { + t.Parallel() + + keys := generateKeys(t, 10) + + for _, key := range keys { + require.NotNil(t, key) + assert.NotNil(t, key.PrivKey) + + // Make sure all keys are unique + for _, keyInner := range keys { + if key.ID() == keyInner.ID() { + continue + } + + assert.False(t, key.Equals(keyInner)) + } + } +} + +func TestNodeKey_Load(t *testing.T) { + t.Parallel() + + t.Run("non-existing key", func(t *testing.T) { + t.Parallel() + + key, err := LoadNodeKey("definitely valid path") + + require.Nil(t, key) + assert.ErrorIs(t, err, os.ErrNotExist) + }) + + t.Run("invalid key format", func(t *testing.T) { + t.Parallel() + + // Generate a random path + path := fmt.Sprintf("%s/key.json", t.TempDir()) + + type random struct { + field string + } + + data, err := json.Marshal(&random{ + field: "random data", + }) + require.NoError(t, err) + + // Save the invalid data format + require.NoError(t, os.WriteFile(path, data, 0o644)) + + // Load the key, that's invalid + key, err := LoadNodeKey(path) + + require.NoError(t, err) + assert.Nil(t, key.PrivKey) + }) + + t.Run("valid key loaded", func(t *testing.T) { + t.Parallel() + + var ( + path = fmt.Sprintf("%s/key.json", t.TempDir()) + key = GenerateNodeKey() 
+ ) + + // Save the key + require.NoError(t, saveNodeKey(path, key)) + + // Load the key, that's valid + loadedKey, err := LoadNodeKey(path) + require.NoError(t, err) + + assert.True(t, key.PrivKey.Equals(loadedKey.PrivKey)) + assert.Equal(t, key.ID(), loadedKey.ID()) + }) +} + +func TestNodeKey_ID(t *testing.T) { + t.Parallel() + + keys := generateKeys(t, 10) + + for _, key := range keys { + // Make sure the ID is valid + id := key.ID() + require.NotNil(t, id) + + assert.NoError(t, id.Validate()) + } +} + +func TestNodeKey_LoadOrGenNodeKey(t *testing.T) { + t.Parallel() + + t.Run("existing key loaded", func(t *testing.T) { + t.Parallel() + + var ( + path = fmt.Sprintf("%s/key.json", t.TempDir()) + key = GenerateNodeKey() + ) + + // Save the key + require.NoError(t, saveNodeKey(path, key)) + + loadedKey, err := LoadOrGenNodeKey(path) + require.NoError(t, err) + + // Make sure the key was not generated + assert.True(t, key.PrivKey.Equals(loadedKey.PrivKey)) + }) + + t.Run("fresh key generated", func(t *testing.T) { + t.Parallel() + + path := fmt.Sprintf("%s/key.json", t.TempDir()) + + // Make sure there is no key at the path + _, err := os.Stat(path) + require.ErrorIs(t, err, os.ErrNotExist) + + // Generate the fresh key + key, err := LoadOrGenNodeKey(path) + require.NoError(t, err) + + // Load the saved key + loadedKey, err := LoadOrGenNodeKey(path) + require.NoError(t, err) + + // Make sure the keys are the same + assert.True(t, key.PrivKey.Equals(loadedKey.PrivKey)) + }) +} diff --git a/tm2/pkg/p2p/netaddress.go b/tm2/pkg/p2p/types/netaddress.go similarity index 52% rename from tm2/pkg/p2p/netaddress.go rename to tm2/pkg/p2p/types/netaddress.go index 77f89b2a4b3..a43f90454ea 100644 --- a/tm2/pkg/p2p/netaddress.go +++ b/tm2/pkg/p2p/types/netaddress.go @@ -2,143 +2,156 @@ // Originally Copyright (c) 2013-2014 Conformal Systems LLC. 
// https://github.com/conformal/btcd/blob/master/LICENSE -package p2p +package types import ( - "flag" + "context" "fmt" "net" "strconv" "strings" - "time" "github.com/gnolang/gno/tm2/pkg/crypto" "github.com/gnolang/gno/tm2/pkg/errors" ) -type ID = crypto.ID +const ( + nilNetAddress = "" + badNetAddress = "" +) + +var ( + ErrInvalidTCPAddress = errors.New("invalid TCP address") + ErrUnsetIPAddress = errors.New("unset IP address") + ErrInvalidIP = errors.New("invalid IP address") + ErrUnspecifiedIP = errors.New("unspecified IP address") + ErrInvalidNetAddress = errors.New("invalid net address") + ErrEmptyHost = errors.New("empty host address") +) // NetAddress defines information about a peer on the network -// including its Address, IP address, and port. -// NOTE: NetAddress is not meant to be mutated due to memoization. -// @amino2: immutable XXX +// including its ID, IP address, and port type NetAddress struct { - ID ID `json:"id"` // authenticated identifier (TODO) - IP net.IP `json:"ip"` // part of "addr" - Port uint16 `json:"port"` // part of "addr" - - // TODO: - // Name string `json:"name"` // optional DNS name - - // memoize .String() - str string + ID ID `json:"id"` // unique peer identifier (public key address) + IP net.IP `json:"ip"` // the IP part of the dial address + Port uint16 `json:"port"` // the port part of the dial address } // NetAddressString returns id@addr. It strips the leading // protocol from protocolHostPort if it exists. func NetAddressString(id ID, protocolHostPort string) string { - addr := removeProtocolIfDefined(protocolHostPort) - return fmt.Sprintf("%s@%s", id, addr) + return fmt.Sprintf( + "%s@%s", + id, + removeProtocolIfDefined(protocolHostPort), + ) } // NewNetAddress returns a new NetAddress using the provided TCP -// address. When testing, other net.Addr (except TCP) will result in -// using 0.0.0.0:0. When normal run, other net.Addr (except TCP) will -// panic. Panics if ID is invalid. -// TODO: socks proxies? 
-func NewNetAddress(id ID, addr net.Addr) *NetAddress { +// address +func NewNetAddress(id ID, addr net.Addr) (*NetAddress, error) { + // Make sure the address is valid tcpAddr, ok := addr.(*net.TCPAddr) if !ok { - if flag.Lookup("test.v") == nil { // normal run - panic(fmt.Sprintf("Only TCPAddrs are supported. Got: %v", addr)) - } else { // in testing - netAddr := NewNetAddressFromIPPort("", net.IP("0.0.0.0"), 0) - netAddr.ID = id - return netAddr - } + return nil, ErrInvalidTCPAddress } + // Validate the ID if err := id.Validate(); err != nil { - panic(fmt.Sprintf("Invalid ID %v: %v (addr: %v)", id, err, addr)) + return nil, fmt.Errorf("unable to verify ID, %w", err) } - ip := tcpAddr.IP - port := uint16(tcpAddr.Port) - na := NewNetAddressFromIPPort("", ip, port) + na := NewNetAddressFromIPPort( + tcpAddr.IP, + uint16(tcpAddr.Port), + ) + + // Set the ID na.ID = id - return na + + return na, nil } // NewNetAddressFromString returns a new NetAddress using the provided address in // the form of "ID@IP:Port". // Also resolves the host if host is not an IP. 
-// Errors are of type ErrNetAddressXxx where Xxx is in (NoID, Invalid, Lookup) func NewNetAddressFromString(idaddr string) (*NetAddress, error) { - idaddr = removeProtocolIfDefined(idaddr) - spl := strings.Split(idaddr, "@") + var ( + prunedAddr = removeProtocolIfDefined(idaddr) + spl = strings.Split(prunedAddr, "@") + ) + if len(spl) != 2 { - return nil, NetAddressNoIDError{idaddr} + return nil, ErrInvalidNetAddress } - // get ID - id := crypto.ID(spl[0]) + var ( + id = crypto.ID(spl[0]) + addr = spl[1] + ) + + // Validate the ID if err := id.Validate(); err != nil { - return nil, NetAddressInvalidError{idaddr, err} + return nil, fmt.Errorf("unable to verify address ID, %w", err) } - addr := spl[1] - // get host and port + // Extract the host and port host, portStr, err := net.SplitHostPort(addr) if err != nil { - return nil, NetAddressInvalidError{addr, err} + return nil, fmt.Errorf("unable to split host and port, %w", err) } - if len(host) == 0 { - return nil, NetAddressInvalidError{ - addr, - errors.New("host is empty"), - } + + if host == "" { + return nil, ErrEmptyHost } ip := net.ParseIP(host) if ip == nil { ips, err := net.LookupIP(host) if err != nil { - return nil, NetAddressLookupError{host, err} + return nil, fmt.Errorf("unable to look up IP, %w", err) } + ip = ips[0] } port, err := strconv.ParseUint(portStr, 10, 16) if err != nil { - return nil, NetAddressInvalidError{portStr, err} + return nil, fmt.Errorf("unable to parse port %s, %w", portStr, err) } - na := NewNetAddressFromIPPort("", ip, uint16(port)) + na := NewNetAddressFromIPPort(ip, uint16(port)) na.ID = id + return na, nil } // NewNetAddressFromStrings returns an array of NetAddress'es build using // the provided strings. 
func NewNetAddressFromStrings(idaddrs []string) ([]*NetAddress, []error) { - netAddrs := make([]*NetAddress, 0) - errs := make([]error, 0) + var ( + netAddrs = make([]*NetAddress, 0, len(idaddrs)) + errs = make([]error, 0, len(idaddrs)) + ) + for _, addr := range idaddrs { netAddr, err := NewNetAddressFromString(addr) if err != nil { errs = append(errs, err) - } else { - netAddrs = append(netAddrs, netAddr) + + continue } + + netAddrs = append(netAddrs, netAddr) } + return netAddrs, errs } // NewNetAddressFromIPPort returns a new NetAddress using the provided IP // and port number. -func NewNetAddressFromIPPort(id ID, ip net.IP, port uint16) *NetAddress { +func NewNetAddressFromIPPort(ip net.IP, port uint16) *NetAddress { return &NetAddress{ - ID: id, IP: ip, Port: port, } @@ -146,88 +159,78 @@ func NewNetAddressFromIPPort(id ID, ip net.IP, port uint16) *NetAddress { // Equals reports whether na and other are the same addresses, // including their ID, IP, and Port. -func (na *NetAddress) Equals(other interface{}) bool { - if o, ok := other.(*NetAddress); ok { - return na.String() == o.String() - } - return false +func (na *NetAddress) Equals(other NetAddress) bool { + return na.String() == other.String() } // Same returns true is na has the same non-empty ID or DialString as other. 
-func (na *NetAddress) Same(other interface{}) bool { - if o, ok := other.(*NetAddress); ok { - if na.DialString() == o.DialString() { - return true - } - if na.ID != "" && na.ID == o.ID { - return true - } - } - return false +func (na *NetAddress) Same(other NetAddress) bool { + var ( + dialsSame = na.DialString() == other.DialString() + IDsSame = na.ID != "" && na.ID == other.ID + ) + + return dialsSame || IDsSame } // String representation: @: func (na *NetAddress) String() string { if na == nil { - return "" - } - if na.str != "" { - return na.str + return nilNetAddress } + str, err := na.MarshalAmino() if err != nil { - return "" + return badNetAddress } + return str } +// MarshalAmino stringifies a NetAddress. // Needed because (a) IP doesn't encode, and (b) the intend of this type is to // serialize to a string anyways. func (na NetAddress) MarshalAmino() (string, error) { - if na.str == "" { - addrStr := na.DialString() - if na.ID != "" { - addrStr = NetAddressString(na.ID, addrStr) - } - na.str = addrStr + addrStr := na.DialString() + + if na.ID != "" { + return NetAddressString(na.ID, addrStr), nil } - return na.str, nil + + return addrStr, nil } -func (na *NetAddress) UnmarshalAmino(str string) (err error) { - na2, err := NewNetAddressFromString(str) +func (na *NetAddress) UnmarshalAmino(raw string) (err error) { + netAddress, err := NewNetAddressFromString(raw) if err != nil { return err } - *na = *na2 + + *na = *netAddress + return nil } func (na *NetAddress) DialString() string { if na == nil { - return "" + return nilNetAddress } + return net.JoinHostPort( na.IP.String(), strconv.FormatUint(uint64(na.Port), 10), ) } -// Dial calls net.Dial on the address. 
-func (na *NetAddress) Dial() (net.Conn, error) { - conn, err := net.Dial("tcp", na.DialString()) - if err != nil { - return nil, err - } - return conn, nil -} +// DialContext dials the given NetAddress with a context +func (na *NetAddress) DialContext(ctx context.Context) (net.Conn, error) { + var d net.Dialer -// DialTimeout calls net.DialTimeout on the address. -func (na *NetAddress) DialTimeout(timeout time.Duration) (net.Conn, error) { - conn, err := net.DialTimeout("tcp", na.DialString(), timeout) + conn, err := d.DialContext(ctx, "tcp", na.DialString()) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to dial address, %w", err) } + return conn, nil } @@ -236,47 +239,45 @@ func (na *NetAddress) Routable() bool { if err := na.Validate(); err != nil { return false } + // TODO(oga) bitcoind doesn't include RFC3849 here, but should we? - return !(na.RFC1918() || na.RFC3927() || na.RFC4862() || - na.RFC4193() || na.RFC4843() || na.Local()) + return !(na.RFC1918() || + na.RFC3927() || + na.RFC4862() || + na.RFC4193() || + na.RFC4843() || + na.Local()) } -func (na *NetAddress) ValidateLocal() error { +// Validate validates the NetAddress params +func (na *NetAddress) Validate() error { + // Validate the ID if err := na.ID.Validate(); err != nil { - return err + return fmt.Errorf("unable to validate ID, %w", err) } + + // Make sure the IP is set if na.IP == nil { - return errors.New("no IP") - } - if len(na.IP) != 4 && len(na.IP) != 16 { - return fmt.Errorf("invalid IP bytes: %v", len(na.IP)) - } - if na.RFC3849() || na.IP.Equal(net.IPv4bcast) { - return errors.New("invalid IP", na.IP.IsUnspecified()) + return ErrUnsetIPAddress } - return nil -} -func (na *NetAddress) Validate() error { - if err := na.ID.Validate(); err != nil { - return err - } - if na.IP == nil { - return errors.New("no IP") + // Make sure the IP is valid + ipLen := len(na.IP) + if ipLen != 4 && ipLen != 16 { + return ErrInvalidIP } - if len(na.IP) != 4 && len(na.IP) != 16 { - 
return fmt.Errorf("invalid IP bytes: %v", len(na.IP)) + + // Check if the IP is unspecified + if na.IP.IsUnspecified() { + return ErrUnspecifiedIP } - if na.IP.IsUnspecified() || na.RFC3849() || na.IP.Equal(net.IPv4bcast) { - return errors.New("invalid IP", na.IP.IsUnspecified()) + + // Check if the IP conforms to standards, or is a broadcast + if na.RFC3849() || na.IP.Equal(net.IPv4bcast) { + return ErrInvalidIP } - return nil -} -// HasID returns true if the address has an ID. -// NOTE: It does not check whether the ID is valid or not. -func (na *NetAddress) HasID() bool { - return !na.ID.IsZero() + return nil } // Local returns true if it is a local address. @@ -284,56 +285,6 @@ func (na *NetAddress) Local() bool { return na.IP.IsLoopback() || zero4.Contains(na.IP) } -// ReachabilityTo checks whenever o can be reached from na. -func (na *NetAddress) ReachabilityTo(o *NetAddress) int { - const ( - Unreachable = 0 - Default = iota - Teredo - Ipv6_weak - Ipv4 - Ipv6_strong - ) - switch { - case !na.Routable(): - return Unreachable - case na.RFC4380(): - switch { - case !o.Routable(): - return Default - case o.RFC4380(): - return Teredo - case o.IP.To4() != nil: - return Ipv4 - default: // ipv6 - return Ipv6_weak - } - case na.IP.To4() != nil: - if o.Routable() && o.IP.To4() != nil { - return Ipv4 - } - return Default - default: /* ipv6 */ - var tunnelled bool - // Is our v6 is tunnelled? - if o.RFC3964() || o.RFC6052() || o.RFC6145() { - tunnelled = true - } - switch { - case !o.Routable(): - return Default - case o.RFC4380(): - return Teredo - case o.IP.To4() != nil: - return Ipv4 - case tunnelled: - // only prioritise ipv6 if we aren't tunnelling it. 
- return Ipv6_weak - } - return Ipv6_strong - } -} - // RFC1918: IPv4 Private networks (10.0.0.0/8, 192.168.0.0/16, 172.16.0.0/12) // RFC3849: IPv6 Documentation address (2001:0DB8::/32) // RFC3927: IPv4 Autoconfig (169.254.0.0/16) @@ -376,9 +327,12 @@ func (na *NetAddress) RFC4862() bool { return rfc4862.Contains(na.IP) } func (na *NetAddress) RFC6052() bool { return rfc6052.Contains(na.IP) } func (na *NetAddress) RFC6145() bool { return rfc6145.Contains(na.IP) } +// removeProtocolIfDefined removes the protocol part of the given address func removeProtocolIfDefined(addr string) string { - if strings.Contains(addr, "://") { - return strings.Split(addr, "://")[1] + if !strings.Contains(addr, "://") { + // No protocol part + return addr } - return addr + + return strings.Split(addr, "://")[1] } diff --git a/tm2/pkg/p2p/types/netaddress_test.go b/tm2/pkg/p2p/types/netaddress_test.go new file mode 100644 index 00000000000..1f8f0229b99 --- /dev/null +++ b/tm2/pkg/p2p/types/netaddress_test.go @@ -0,0 +1,323 @@ +package types + +import ( + "fmt" + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func BenchmarkNetAddress_String(b *testing.B) { + key := GenerateNodeKey() + + na, err := NewNetAddressFromString(NetAddressString(key.ID(), "127.0.0.1:0")) + require.NoError(b, err) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _ = na.String() + } +} + +func TestNewNetAddress(t *testing.T) { + t.Parallel() + + t.Run("invalid TCP address", func(t *testing.T) { + t.Parallel() + + var ( + key = GenerateNodeKey() + address = "127.0.0.1:8080" + ) + + udpAddr, err := net.ResolveUDPAddr("udp", address) + require.NoError(t, err) + + _, err = NewNetAddress(key.ID(), udpAddr) + require.Error(t, err) + }) + + t.Run("invalid ID", func(t *testing.T) { + t.Parallel() + + var ( + id = "" // zero ID + address = "127.0.0.1:8080" + ) + + tcpAddr, err := net.ResolveTCPAddr("tcp", address) + require.NoError(t, err) + + _, err = 
NewNetAddress(ID(id), tcpAddr) + require.Error(t, err) + }) + + t.Run("valid net address", func(t *testing.T) { + t.Parallel() + + var ( + key = GenerateNodeKey() + address = "127.0.0.1:8080" + ) + + tcpAddr, err := net.ResolveTCPAddr("tcp", address) + require.NoError(t, err) + + addr, err := NewNetAddress(key.ID(), tcpAddr) + require.NoError(t, err) + + assert.Equal(t, fmt.Sprintf("%s@%s", key.ID(), address), addr.String()) + }) +} + +func TestNewNetAddressFromString(t *testing.T) { + t.Parallel() + + t.Run("valid net address", func(t *testing.T) { + t.Parallel() + + testTable := []struct { + name string + addr string + expected string + }{ + {"no protocol", "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080"}, + {"tcp input", "tcp://g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080"}, + {"udp input", "udp://g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080"}, + {"no protocol", "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080"}, + {"tcp input", "tcp://g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080"}, + {"udp input", "udp://g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080"}, + {"correct nodeId w/tcp", "tcp://g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080", "g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080"}, + } + + for _, testCase := range testTable { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + addr, err := NewNetAddressFromString(testCase.addr) + require.NoError(t, err) + + assert.Equal(t, testCase.expected, addr.String()) + }) + } + }) + + t.Run("invalid net address", func(t *testing.T) { + t.Parallel() + + testTable := []struct { + name 
string + addr string + }{ + {"no node id and no protocol", "127.0.0.1:8080"}, + {"no node id w/ tcp input", "tcp://127.0.0.1:8080"}, + {"no node id w/ udp input", "udp://127.0.0.1:8080"}, + + {"malformed tcp input", "tcp//g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080"}, + {"malformed udp input", "udp//g1m6kmam774klwlh4dhmhaatd7al02m0h0jwnyc6@127.0.0.1:8080"}, + + {"invalid host", "notahost"}, + {"invalid port", "127.0.0.1:notapath"}, + {"invalid host w/ port", "notahost:8080"}, + {"just a port", "8082"}, + {"non-existent port", "127.0.0:8080000"}, + + {"too short nodeId", "deadbeef@127.0.0.1:8080"}, + {"too short, not hex nodeId", "this-isnot-hex@127.0.0.1:8080"}, + {"not bech32 nodeId", "xxxm6kmam774klwlh4dhmhaatd7al02m0h0hdap9l@127.0.0.1:8080"}, + + {"too short nodeId w/tcp", "tcp://deadbeef@127.0.0.1:8080"}, + {"too short notHex nodeId w/tcp", "tcp://this-isnot-hex@127.0.0.1:8080"}, + {"not bech32 nodeId w/tcp", "tcp://xxxxm6kmam774klwlh4dhmhaatd7al02m0h0hdap9l@127.0.0.1:8080"}, + + {"no node id", "tcp://@127.0.0.1:8080"}, + {"no node id or IP", "tcp://@"}, + {"tcp no host, w/ port", "tcp://:26656"}, + {"empty", ""}, + {"node id delimiter 1", "@"}, + {"node id delimiter 2", " @"}, + {"node id delimiter 3", " @ "}, + } + + for _, testCase := range testTable { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + addr, err := NewNetAddressFromString(testCase.addr) + + assert.Nil(t, addr) + assert.Error(t, err) + }) + } + }) +} + +func TestNewNetAddressFromStrings(t *testing.T) { + t.Parallel() + + t.Run("invalid addresses", func(t *testing.T) { + t.Parallel() + + var ( + keys = generateKeys(t, 10) + strs = make([]string, 0, len(keys)) + ) + + for index, key := range keys { + if index%2 != 0 { + strs = append( + strs, + fmt.Sprintf("%s@:8080", key.ID()), // missing host + ) + + continue + } + + strs = append( + strs, + fmt.Sprintf("%s@127.0.0.1:8080", key.ID()), + ) + } + + // Convert the strings + addrs, errs := NewNetAddressFromStrings(strs) + 
+ assert.Len(t, errs, len(keys)/2) + assert.Equal(t, len(keys)/2, len(addrs)) + + for index, addr := range addrs { + assert.Contains(t, addr.String(), keys[index*2].ID()) + } + }) + + t.Run("valid addresses", func(t *testing.T) { + t.Parallel() + + var ( + keys = generateKeys(t, 10) + strs = make([]string, 0, len(keys)) + ) + + for _, key := range keys { + strs = append( + strs, + fmt.Sprintf("%s@127.0.0.1:8080", key.ID()), + ) + } + + // Convert the strings + addrs, errs := NewNetAddressFromStrings(strs) + + assert.Len(t, errs, 0) + assert.Equal(t, len(keys), len(addrs)) + + for index, addr := range addrs { + assert.Contains(t, addr.String(), keys[index].ID()) + } + }) +} + +func TestNewNetAddressFromIPPort(t *testing.T) { + t.Parallel() + + var ( + host = "127.0.0.1" + port = uint16(8080) + ) + + addr := NewNetAddressFromIPPort(net.ParseIP(host), port) + + assert.Equal( + t, + fmt.Sprintf("%s:%d", host, port), + addr.String(), + ) +} + +func TestNetAddress_Local(t *testing.T) { + t.Parallel() + + testTable := []struct { + name string + addr string + isLocal bool + }{ + { + "local loopback", + "127.0.0.1:8080", + true, + }, + { + "local loopback, zero", + "0.0.0.0:8080", + true, + }, + { + "non-local address", + "200.100.200.100:8080", + false, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + key := GenerateNodeKey() + + addr, err := NewNetAddressFromString( + fmt.Sprintf( + "%s@%s", + key.ID(), + testCase.addr, + ), + ) + require.NoError(t, err) + + assert.Equal(t, testCase.isLocal, addr.Local()) + }) + } +} + +func TestNetAddress_Routable(t *testing.T) { + t.Parallel() + + testTable := []struct { + name string + addr string + isRoutable bool + }{ + { + "local loopback", + "127.0.0.1:8080", + false, + }, + { + "routable address", + "gno.land:80", + true, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + key := GenerateNodeKey() + + addr, 
err := NewNetAddressFromString( + fmt.Sprintf( + "%s@%s", + key.ID(), + testCase.addr, + ), + ) + require.NoError(t, err) + + assert.Equal(t, testCase.isRoutable, addr.Routable()) + }) + } +} diff --git a/tm2/pkg/p2p/types/node_info.go b/tm2/pkg/p2p/types/node_info.go new file mode 100644 index 00000000000..8452cb43cb8 --- /dev/null +++ b/tm2/pkg/p2p/types/node_info.go @@ -0,0 +1,141 @@ +package types + +import ( + "errors" + "fmt" + + "github.com/gnolang/gno/tm2/pkg/strings" + "github.com/gnolang/gno/tm2/pkg/versionset" +) + +const ( + MaxNodeInfoSize = int64(10240) // 10KB + maxNumChannels = 16 // plenty of room for upgrades, for now +) + +var ( + ErrInvalidPeerID = errors.New("invalid peer ID") + ErrInvalidVersion = errors.New("invalid node version") + ErrInvalidMoniker = errors.New("invalid node moniker") + ErrInvalidRPCAddress = errors.New("invalid node RPC address") + ErrExcessiveChannels = errors.New("excessive node channels") + ErrDuplicateChannels = errors.New("duplicate node channels") + ErrIncompatibleNetworks = errors.New("incompatible networks") + ErrNoCommonChannels = errors.New("no common channels") +) + +// NodeInfo is the basic node information exchanged +// between two peers during the Tendermint P2P handshake. +type NodeInfo struct { + // Set of protocol versions + VersionSet versionset.VersionSet `json:"version_set"` + + // Unique peer identifier + PeerID ID `json:"id"` + + // Check compatibility. + // Channels are HexBytes so easier to read as JSON + Network string `json:"network"` // network/chain ID + Software string `json:"software"` // name of immediate software + Version string `json:"version"` // software major.minor.revision + Channels []byte `json:"channels"` // channels this node knows about + + // ASCIIText fields + Moniker string `json:"moniker"` // arbitrary moniker + Other NodeInfoOther `json:"other"` // other application specific data +} + +// NodeInfoOther is the misc. 
application specific data +type NodeInfoOther struct { + TxIndex string `json:"tx_index"` + RPCAddress string `json:"rpc_address"` +} + +// Validate checks the self-reported NodeInfo is safe. +// It returns an error if there +// are too many Channels, if there are any duplicate Channels, +// if the ListenAddr is malformed, or if the ListenAddr is a host name +// that can not be resolved to some IP +func (info NodeInfo) Validate() error { + // Validate the ID + if err := info.PeerID.Validate(); err != nil { + return fmt.Errorf("%w, %w", ErrInvalidPeerID, err) + } + + // Validate Version + if len(info.Version) > 0 && + (!strings.IsASCIIText(info.Version) || + strings.ASCIITrim(info.Version) == "") { + return ErrInvalidVersion + } + + // Validate Channels - ensure max and check for duplicates. + if len(info.Channels) > maxNumChannels { + return ErrExcessiveChannels + } + + channelMap := make(map[byte]struct{}, len(info.Channels)) + for _, ch := range info.Channels { + if _, ok := channelMap[ch]; ok { + return ErrDuplicateChannels + } + + // Mark the channel as present + channelMap[ch] = struct{}{} + } + + // Validate Moniker. + if !strings.IsASCIIText(info.Moniker) || strings.ASCIITrim(info.Moniker) == "" { + return ErrInvalidMoniker + } + + // XXX: Should we be more strict about address formats? + rpcAddr := info.Other.RPCAddress + if len(rpcAddr) > 0 && (!strings.IsASCIIText(rpcAddr) || strings.ASCIITrim(rpcAddr) == "") { + return ErrInvalidRPCAddress + } + + return nil +} + +// ID returns the local node ID +func (info NodeInfo) ID() ID { + return info.PeerID +} + +// CompatibleWith checks if two NodeInfo are compatible with each other. 
+// CONTRACT: two nodes are compatible if the Block version and networks match, +// and they have at least one channel in common +func (info NodeInfo) CompatibleWith(other NodeInfo) error { + // Validate the protocol versions + if _, err := info.VersionSet.CompatibleWith(other.VersionSet); err != nil { + return fmt.Errorf("incompatible version sets, %w", err) + } + + // Make sure nodes are on the same network + if info.Network != other.Network { + return ErrIncompatibleNetworks + } + + // Make sure there is at least 1 channel in common + commonFound := false + for _, ch1 := range info.Channels { + for _, ch2 := range other.Channels { + if ch1 == ch2 { + commonFound = true + + break + } + } + + if commonFound { + break + } + } + + if !commonFound { + return ErrNoCommonChannels + } + + return nil +} diff --git a/tm2/pkg/p2p/types/node_info_test.go b/tm2/pkg/p2p/types/node_info_test.go new file mode 100644 index 00000000000..d03d77e608f --- /dev/null +++ b/tm2/pkg/p2p/types/node_info_test.go @@ -0,0 +1,321 @@ +package types + +import ( + "fmt" + "testing" + + "github.com/gnolang/gno/tm2/pkg/versionset" + "github.com/stretchr/testify/assert" +) + +func TestNodeInfo_Validate(t *testing.T) { + t.Parallel() + + t.Run("invalid peer ID", func(t *testing.T) { + t.Parallel() + + info := &NodeInfo{ + PeerID: "", // zero + } + + assert.ErrorIs(t, info.Validate(), ErrInvalidPeerID) + }) + + t.Run("invalid version", func(t *testing.T) { + t.Parallel() + + testTable := []struct { + name string + version string + }{ + { + "non-ascii version", + "¢§µ", + }, + { + "empty tab version", + fmt.Sprintf("\t"), + }, + { + "empty space version", + fmt.Sprintf(" "), + }, + } + + for _, testCase := range testTable { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + info := &NodeInfo{ + PeerID: GenerateNodeKey().ID(), + Version: testCase.version, + } + + assert.ErrorIs(t, info.Validate(), ErrInvalidVersion) + }) + } + }) + + t.Run("invalid moniker", func(t *testing.T) { + 
t.Parallel() + + testTable := []struct { + name string + moniker string + }{ + { + "empty moniker", + "", + }, + { + "non-ascii moniker", + "¢§µ", + }, + { + "empty tab moniker", + fmt.Sprintf("\t"), + }, + { + "empty space moniker", + fmt.Sprintf(" "), + }, + } + + for _, testCase := range testTable { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + info := &NodeInfo{ + PeerID: GenerateNodeKey().ID(), + Moniker: testCase.moniker, + } + + assert.ErrorIs(t, info.Validate(), ErrInvalidMoniker) + }) + } + }) + + t.Run("invalid RPC Address", func(t *testing.T) { + t.Parallel() + + testTable := []struct { + name string + rpcAddress string + }{ + { + "non-ascii moniker", + "¢§µ", + }, + { + "empty tab RPC address", + fmt.Sprintf("\t"), + }, + { + "empty space RPC address", + fmt.Sprintf(" "), + }, + } + + for _, testCase := range testTable { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + info := &NodeInfo{ + PeerID: GenerateNodeKey().ID(), + Moniker: "valid moniker", + Other: NodeInfoOther{ + RPCAddress: testCase.rpcAddress, + }, + } + + assert.ErrorIs(t, info.Validate(), ErrInvalidRPCAddress) + }) + } + }) + + t.Run("invalid channels", func(t *testing.T) { + t.Parallel() + + testTable := []struct { + name string + channels []byte + expectedErr error + }{ + { + "too many channels", + make([]byte, maxNumChannels+1), + ErrExcessiveChannels, + }, + { + "duplicate channels", + []byte{ + byte(10), + byte(20), + byte(10), + }, + ErrDuplicateChannels, + }, + } + + for _, testCase := range testTable { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + info := &NodeInfo{ + PeerID: GenerateNodeKey().ID(), + Moniker: "valid moniker", + Channels: testCase.channels, + } + + assert.ErrorIs(t, info.Validate(), testCase.expectedErr) + }) + } + }) + + t.Run("valid node info", func(t *testing.T) { + t.Parallel() + + info := &NodeInfo{ + PeerID: GenerateNodeKey().ID(), + Moniker: "valid moniker", + Channels: []byte{10, 20, 30}, + Other: 
NodeInfoOther{ + RPCAddress: "0.0.0.0:26657", + }, + } + + assert.NoError(t, info.Validate()) + }) +} + +func TestNodeInfo_CompatibleWith(t *testing.T) { + t.Parallel() + + t.Run("incompatible version sets", func(t *testing.T) { + t.Parallel() + + var ( + name = "Block" + + infoOne = &NodeInfo{ + VersionSet: []versionset.VersionInfo{ + { + Name: name, + Version: "badversion", + }, + }, + } + + infoTwo = &NodeInfo{ + VersionSet: []versionset.VersionInfo{ + { + Name: name, + Version: "v0.0.0", + }, + }, + } + ) + + assert.Error(t, infoTwo.CompatibleWith(*infoOne)) + }) + + t.Run("incompatible networks", func(t *testing.T) { + t.Parallel() + + var ( + name = "Block" + version = "v0.0.0" + + infoOne = &NodeInfo{ + Network: "+wrong", + VersionSet: []versionset.VersionInfo{ + { + Name: name, + Version: version, + }, + }, + } + + infoTwo = &NodeInfo{ + Network: "gno", + VersionSet: []versionset.VersionInfo{ + { + Name: name, + Version: version, + }, + }, + } + ) + + assert.ErrorIs(t, infoTwo.CompatibleWith(*infoOne), ErrIncompatibleNetworks) + }) + + t.Run("no common channels", func(t *testing.T) { + t.Parallel() + + var ( + name = "Block" + version = "v0.0.0" + network = "gno" + + infoOne = &NodeInfo{ + Network: network, + VersionSet: []versionset.VersionInfo{ + { + Name: name, + Version: version, + }, + }, + Channels: []byte{10}, + } + + infoTwo = &NodeInfo{ + Network: network, + VersionSet: []versionset.VersionInfo{ + { + Name: name, + Version: version, + }, + }, + Channels: []byte{20}, + } + ) + + assert.ErrorIs(t, infoTwo.CompatibleWith(*infoOne), ErrNoCommonChannels) + }) + + t.Run("fully compatible node infos", func(t *testing.T) { + t.Parallel() + + var ( + name = "Block" + version = "v0.0.0" + network = "gno" + channels = []byte{10, 20, 30} + + infoOne = &NodeInfo{ + Network: network, + VersionSet: []versionset.VersionInfo{ + { + Name: name, + Version: version, + }, + }, + Channels: channels, + } + + infoTwo = &NodeInfo{ + Network: network, + VersionSet: 
[]versionset.VersionInfo{ + { + Name: name, + Version: version, + }, + }, + Channels: channels[1:], + } + ) + + assert.NoError(t, infoTwo.CompatibleWith(*infoOne)) + }) +} diff --git a/tm2/pkg/p2p/upnp/probe.go b/tm2/pkg/p2p/upnp/probe.go deleted file mode 100644 index 29480e7cecc..00000000000 --- a/tm2/pkg/p2p/upnp/probe.go +++ /dev/null @@ -1,110 +0,0 @@ -package upnp - -import ( - "fmt" - "log/slog" - "net" - "time" -) - -type UPNPCapabilities struct { - PortMapping bool - Hairpin bool -} - -func makeUPNPListener(intPort int, extPort int, logger *slog.Logger) (NAT, net.Listener, net.IP, error) { - nat, err := Discover() - if err != nil { - return nil, nil, nil, fmt.Errorf("NAT upnp could not be discovered: %w", err) - } - logger.Info(fmt.Sprintf("ourIP: %v", nat.(*upnpNAT).ourIP)) - - ext, err := nat.GetExternalAddress() - if err != nil { - return nat, nil, nil, fmt.Errorf("external address error: %w", err) - } - logger.Info(fmt.Sprintf("External address: %v", ext)) - - port, err := nat.AddPortMapping("tcp", extPort, intPort, "Tendermint UPnP Probe", 0) - if err != nil { - return nat, nil, ext, fmt.Errorf("port mapping error: %w", err) - } - logger.Info(fmt.Sprintf("Port mapping mapped: %v", port)) - - // also run the listener, open for all remote addresses. 
- listener, err := net.Listen("tcp", fmt.Sprintf(":%v", intPort)) - if err != nil { - return nat, nil, ext, fmt.Errorf("error establishing listener: %w", err) - } - return nat, listener, ext, nil -} - -func testHairpin(listener net.Listener, extAddr string, logger *slog.Logger) (supportsHairpin bool) { - // Listener - go func() { - inConn, err := listener.Accept() - if err != nil { - logger.Info(fmt.Sprintf("Listener.Accept() error: %v", err)) - return - } - logger.Info(fmt.Sprintf("Accepted incoming connection: %v -> %v", inConn.LocalAddr(), inConn.RemoteAddr())) - buf := make([]byte, 1024) - n, err := inConn.Read(buf) - if err != nil { - logger.Info(fmt.Sprintf("Incoming connection read error: %v", err)) - return - } - logger.Info(fmt.Sprintf("Incoming connection read %v bytes: %X", n, buf)) - if string(buf) == "test data" { - supportsHairpin = true - return - } - }() - - // Establish outgoing - outConn, err := net.Dial("tcp", extAddr) - if err != nil { - logger.Info(fmt.Sprintf("Outgoing connection dial error: %v", err)) - return - } - - n, err := outConn.Write([]byte("test data")) - if err != nil { - logger.Info(fmt.Sprintf("Outgoing connection write error: %v", err)) - return - } - logger.Info(fmt.Sprintf("Outgoing connection wrote %v bytes", n)) - - // Wait for data receipt - time.Sleep(1 * time.Second) - return supportsHairpin -} - -func Probe(logger *slog.Logger) (caps UPNPCapabilities, err error) { - logger.Info("Probing for UPnP!") - - intPort, extPort := 8001, 8001 - - nat, listener, ext, err := makeUPNPListener(intPort, extPort, logger) - if err != nil { - return - } - caps.PortMapping = true - - // Deferred cleanup - defer func() { - if err := nat.DeletePortMapping("tcp", intPort, extPort); err != nil { - logger.Error(fmt.Sprintf("Port mapping delete error: %v", err)) - } - if err := listener.Close(); err != nil { - logger.Error(fmt.Sprintf("Listener closing error: %v", err)) - } - }() - - supportsHairpin := testHairpin(listener, fmt.Sprintf("%v:%v", 
ext, extPort), logger) - if supportsHairpin { - caps.Hairpin = true - } - - return -} diff --git a/tm2/pkg/p2p/upnp/upnp.go b/tm2/pkg/p2p/upnp/upnp.go deleted file mode 100644 index cd47ac35553..00000000000 --- a/tm2/pkg/p2p/upnp/upnp.go +++ /dev/null @@ -1,392 +0,0 @@ -// Taken from taipei-torrent. -// Just enough UPnP to be able to forward ports -// For more information, see: http://www.upnp-hacks.org/upnp.html -package upnp - -// TODO: use syscalls to get actual ourIP, see issue #712 - -import ( - "bytes" - "encoding/xml" - "errors" - "fmt" - "io" - "net" - "net/http" - "strconv" - "strings" - "time" -) - -type upnpNAT struct { - serviceURL string - ourIP string - urnDomain string -} - -// protocol is either "udp" or "tcp" -type NAT interface { - GetExternalAddress() (addr net.IP, err error) - AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) - DeletePortMapping(protocol string, externalPort, internalPort int) (err error) -} - -func Discover() (nat NAT, err error) { - ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") - if err != nil { - return - } - conn, err := net.ListenPacket("udp4", ":0") - if err != nil { - return - } - socket := conn.(*net.UDPConn) - defer socket.Close() //nolint: errcheck - - if err := socket.SetDeadline(time.Now().Add(3 * time.Second)); err != nil { - return nil, err - } - - st := "InternetGatewayDevice:1" - - buf := bytes.NewBufferString( - "M-SEARCH * HTTP/1.1\r\n" + - "HOST: 239.255.255.250:1900\r\n" + - "ST: ssdp:all\r\n" + - "MAN: \"ssdp:discover\"\r\n" + - "MX: 2\r\n\r\n") - message := buf.Bytes() - answerBytes := make([]byte, 1024) - for i := 0; i < 3; i++ { - _, err = socket.WriteToUDP(message, ssdp) - if err != nil { - return - } - var n int - _, _, err = socket.ReadFromUDP(answerBytes) - if err != nil { - return - } - for { - n, _, err = socket.ReadFromUDP(answerBytes) - if err != nil { - break - } - answer := 
string(answerBytes[0:n]) - if !strings.Contains(answer, st) { - continue - } - // HTTP header field names are case-insensitive. - // http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 - locString := "\r\nlocation:" - answer = strings.ToLower(answer) - locIndex := strings.Index(answer, locString) - if locIndex < 0 { - continue - } - loc := answer[locIndex+len(locString):] - endIndex := strings.Index(loc, "\r\n") - if endIndex < 0 { - continue - } - locURL := strings.TrimSpace(loc[0:endIndex]) - var serviceURL, urnDomain string - serviceURL, urnDomain, err = getServiceURL(locURL) - if err != nil { - return - } - var ourIP net.IP - ourIP, err = localIPv4() - if err != nil { - return - } - nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP.String(), urnDomain: urnDomain} - return - } - } - err = errors.New("UPnP port discovery failed") - return nat, err -} - -type Envelope struct { - XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Envelope"` - Soap *SoapBody -} - -type SoapBody struct { - XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Body"` - ExternalIP *ExternalIPAddressResponse -} - -type ExternalIPAddressResponse struct { - XMLName xml.Name `xml:"GetExternalIPAddressResponse"` - IPAddress string `xml:"NewExternalIPAddress"` -} - -type ExternalIPAddress struct { - XMLName xml.Name `xml:"NewExternalIPAddress"` - IP string -} - -type UPNPService struct { - ServiceType string `xml:"serviceType"` - ControlURL string `xml:"controlURL"` -} - -type DeviceList struct { - Device []Device `xml:"device"` -} - -type ServiceList struct { - Service []UPNPService `xml:"service"` -} - -type Device struct { - XMLName xml.Name `xml:"device"` - DeviceType string `xml:"deviceType"` - DeviceList DeviceList `xml:"deviceList"` - ServiceList ServiceList `xml:"serviceList"` -} - -type Root struct { - Device Device -} - -func getChildDevice(d *Device, deviceType string) *Device { - dl := d.DeviceList.Device - for i := 0; i < len(dl); i++ { - if 
strings.Contains(dl[i].DeviceType, deviceType) { - return &dl[i] - } - } - return nil -} - -func getChildService(d *Device, serviceType string) *UPNPService { - sl := d.ServiceList.Service - for i := 0; i < len(sl); i++ { - if strings.Contains(sl[i].ServiceType, serviceType) { - return &sl[i] - } - } - return nil -} - -func localIPv4() (net.IP, error) { - tt, err := net.Interfaces() - if err != nil { - return nil, err - } - for _, t := range tt { - aa, err := t.Addrs() - if err != nil { - return nil, err - } - for _, a := range aa { - ipnet, ok := a.(*net.IPNet) - if !ok { - continue - } - v4 := ipnet.IP.To4() - if v4 == nil || v4[0] == 127 { // loopback address - continue - } - return v4, nil - } - } - return nil, errors.New("cannot find local IP address") -} - -func getServiceURL(rootURL string) (url, urnDomain string, err error) { - r, err := http.Get(rootURL) //nolint: gosec - if err != nil { - return - } - defer r.Body.Close() //nolint: errcheck - - if r.StatusCode >= 400 { - err = errors.New(fmt.Sprint(r.StatusCode)) - return - } - var root Root - err = xml.NewDecoder(r.Body).Decode(&root) - if err != nil { - return - } - a := &root.Device - if !strings.Contains(a.DeviceType, "InternetGatewayDevice:1") { - err = errors.New("no InternetGatewayDevice") - return - } - b := getChildDevice(a, "WANDevice:1") - if b == nil { - err = errors.New("no WANDevice") - return - } - c := getChildDevice(b, "WANConnectionDevice:1") - if c == nil { - err = errors.New("no WANConnectionDevice") - return - } - d := getChildService(c, "WANIPConnection:1") - if d == nil { - // Some routers don't follow the UPnP spec, and put WanIPConnection under WanDevice, - // instead of under WanConnectionDevice - d = getChildService(b, "WANIPConnection:1") - - if d == nil { - err = errors.New("no WANIPConnection") - return - } - } - // Extract the domain name, which isn't always 'schemas-upnp-org' - urnDomain = strings.Split(d.ServiceType, ":")[1] - url = combineURL(rootURL, d.ControlURL) - 
return url, urnDomain, err -} - -func combineURL(rootURL, subURL string) string { - protocolEnd := "://" - protoEndIndex := strings.Index(rootURL, protocolEnd) - a := rootURL[protoEndIndex+len(protocolEnd):] - rootIndex := strings.Index(a, "/") - return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL -} - -func soapRequest(url, function, message, domain string) (r *http.Response, err error) { - fullMessage := "" + - "\r\n" + - "" + message + "" - - req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage)) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"") - req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3") - // req.Header.Set("Transfer-Encoding", "chunked") - req.Header.Set("SOAPAction", "\"urn:"+domain+":service:WANIPConnection:1#"+function+"\"") - req.Header.Set("Connection", "Close") - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Pragma", "no-cache") - - // log.Stderr("soapRequest ", req) - - r, err = http.DefaultClient.Do(req) - if err != nil { - return nil, err - } - /*if r.Body != nil { - defer r.Body.Close() - }*/ - - if r.StatusCode >= 400 { - // log.Stderr(function, r.StatusCode) - err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function) - r = nil - return - } - return r, err -} - -type statusInfo struct { - externalIPAddress string -} - -func (n *upnpNAT) getExternalIPAddress() (info statusInfo, err error) { - message := "\r\n" + - "" - - var response *http.Response - response, err = soapRequest(n.serviceURL, "GetExternalIPAddress", message, n.urnDomain) - if response != nil { - defer response.Body.Close() //nolint: errcheck - } - if err != nil { - return - } - var envelope Envelope - data, err := io.ReadAll(response.Body) - if err != nil { - return - } - reader := bytes.NewReader(data) - err = xml.NewDecoder(reader).Decode(&envelope) - if err != nil { - return - } - - info = 
statusInfo{envelope.Soap.ExternalIP.IPAddress} - - if err != nil { - return - } - - return info, err -} - -// GetExternalAddress returns an external IP. If GetExternalIPAddress action -// fails or IP returned is invalid, GetExternalAddress returns an error. -func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { - info, err := n.getExternalIPAddress() - if err != nil { - return - } - addr = net.ParseIP(info.externalIPAddress) - if addr == nil { - err = fmt.Errorf("failed to parse IP: %v", info.externalIPAddress) - } - return -} - -func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) { - // A single concatenation would break ARM compilation. - message := "\r\n" + - "" + strconv.Itoa(externalPort) - message += "" + protocol + "" - message += "" + strconv.Itoa(internalPort) + "" + - "" + n.ourIP + "" + - "1" - message += description + - "" + strconv.Itoa(timeout) + - "" - - var response *http.Response - response, err = soapRequest(n.serviceURL, "AddPortMapping", message, n.urnDomain) - if response != nil { - defer response.Body.Close() //nolint: errcheck - } - if err != nil { - return - } - - // TODO: check response to see if the port was forwarded - // log.Println(message, response) - // JAE: - // body, err := io.ReadAll(response.Body) - // fmt.Println(string(body), err) - mappedExternalPort = externalPort - _ = response - return -} - -func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { - message := "\r\n" + - "" + strconv.Itoa(externalPort) + - "" + protocol + "" + - "" - - var response *http.Response - response, err = soapRequest(n.serviceURL, "DeletePortMapping", message, n.urnDomain) - if response != nil { - defer response.Body.Close() //nolint: errcheck - } - if err != nil { - return - } - - // TODO: check response to see if the port was deleted - // log.Println(message, response) - _ = response - return -} 
diff --git a/tm2/pkg/sdk/auth/ante.go b/tm2/pkg/sdk/auth/ante.go index 997478fe4b5..f05a8eff0a7 100644 --- a/tm2/pkg/sdk/auth/ante.go +++ b/tm2/pkg/sdk/auth/ante.go @@ -387,9 +387,10 @@ func EnsureSufficientMempoolFees(ctx sdk.Context, fee std.Fee) sdk.Result { if prod1.Cmp(prod2) >= 0 { return sdk.Result{} } else { + fee := new(big.Int).Quo(prod2, gpg) return abciResult(std.ErrInsufficientFee( fmt.Sprintf( - "insufficient fees; got: {Gas-Wanted: %d, Gas-Fee %s}, fee required: %+v as minimum gas price set by the node", feeGasPrice.Gas, feeGasPrice.Price, gp, + "insufficient fees; got: {Gas-Wanted: %d, Gas-Fee %s}, fee required: %d with %+v as minimum gas price set by the node", feeGasPrice.Gas, feeGasPrice.Price, fee, gp, ), )) } @@ -418,16 +419,20 @@ func SetGasMeter(simulate bool, ctx sdk.Context, gasLimit int64) sdk.Context { // GetSignBytes returns a slice of bytes to sign over for a given transaction // and an account. func GetSignBytes(chainID string, tx std.Tx, acc std.Account, genesis bool) ([]byte, error) { - var accNum uint64 + var ( + accNum uint64 + accSequence uint64 + ) if !genesis { accNum = acc.GetAccountNumber() + accSequence = acc.GetSequence() } return std.GetSignaturePayload( std.SignDoc{ ChainID: chainID, AccountNumber: accNum, - Sequence: acc.GetSequence(), + Sequence: accSequence, Fee: tx.Fee, Msgs: tx.Msgs, Memo: tx.Memo, diff --git a/tm2/pkg/sdk/auth/ante_test.go b/tm2/pkg/sdk/auth/ante_test.go index 78018b415eb..7c6ace51e4e 100644 --- a/tm2/pkg/sdk/auth/ante_test.go +++ b/tm2/pkg/sdk/auth/ante_test.go @@ -209,8 +209,8 @@ func TestAnteHandlerAccountNumbersAtBlockHeightZero(t *testing.T) { tx = tu.NewTestTx(t, ctx.ChainID(), msgs, privs, []uint64{1}, seqs, fee) checkInvalidTx(t, anteHandler, ctx, tx, false, std.UnauthorizedError{}) - // from correct account number - seqs = []uint64{1} + // At genesis account number is zero + seqs = []uint64{0} tx = tu.NewTestTx(t, ctx.ChainID(), msgs, privs, []uint64{0}, seqs, fee) checkValidTx(t, 
anteHandler, ctx, tx, false) @@ -223,7 +223,7 @@ func TestAnteHandlerAccountNumbersAtBlockHeightZero(t *testing.T) { checkInvalidTx(t, anteHandler, ctx, tx, false, std.UnauthorizedError{}) // correct account numbers - privs, accnums, seqs = []crypto.PrivKey{priv1, priv2}, []uint64{0, 0}, []uint64{2, 0} + privs, accnums, seqs = []crypto.PrivKey{priv1, priv2}, []uint64{0, 0}, []uint64{0, 0} tx = tu.NewTestTx(t, ctx.ChainID(), msgs, privs, accnums, seqs, fee) checkValidTx(t, anteHandler, ctx, tx, false) } diff --git a/tm2/pkg/sdk/auth/params.go b/tm2/pkg/sdk/auth/params.go index 3fe08ed444d..fda85c7a3d6 100644 --- a/tm2/pkg/sdk/auth/params.go +++ b/tm2/pkg/sdk/auth/params.go @@ -69,15 +69,16 @@ func DefaultParams() Params { // String implements the stringer interface. func (p Params) String() string { - var sb strings.Builder + var builder strings.Builder + sb := &builder // Pointer for use with fmt.Fprintf sb.WriteString("Params: \n") - sb.WriteString(fmt.Sprintf("MaxMemoBytes: %d\n", p.MaxMemoBytes)) - sb.WriteString(fmt.Sprintf("TxSigLimit: %d\n", p.TxSigLimit)) - sb.WriteString(fmt.Sprintf("TxSizeCostPerByte: %d\n", p.TxSizeCostPerByte)) - sb.WriteString(fmt.Sprintf("SigVerifyCostED25519: %d\n", p.SigVerifyCostED25519)) - sb.WriteString(fmt.Sprintf("SigVerifyCostSecp256k1: %d\n", p.SigVerifyCostSecp256k1)) - sb.WriteString(fmt.Sprintf("GasPricesChangeCompressor: %d\n", p.GasPricesChangeCompressor)) - sb.WriteString(fmt.Sprintf("TargetGasRatio: %d\n", p.TargetGasRatio)) + fmt.Fprintf(sb, "MaxMemoBytes: %d\n", p.MaxMemoBytes) + fmt.Fprintf(sb, "TxSigLimit: %d\n", p.TxSigLimit) + fmt.Fprintf(sb, "TxSizeCostPerByte: %d\n", p.TxSizeCostPerByte) + fmt.Fprintf(sb, "SigVerifyCostED25519: %d\n", p.SigVerifyCostED25519) + fmt.Fprintf(sb, "SigVerifyCostSecp256k1: %d\n", p.SigVerifyCostSecp256k1) + fmt.Fprintf(sb, "GasPricesChangeCompressor: %d\n", p.GasPricesChangeCompressor) + fmt.Fprintf(sb, "TargetGasRatio: %d\n", p.TargetGasRatio) return sb.String() } diff --git 
a/tm2/pkg/sdk/auth/params_test.go b/tm2/pkg/sdk/auth/params_test.go index 4b5a6b15789..36a52ac9001 100644 --- a/tm2/pkg/sdk/auth/params_test.go +++ b/tm2/pkg/sdk/auth/params_test.go @@ -4,6 +4,7 @@ import ( "reflect" "testing" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" ) @@ -105,3 +106,26 @@ func TestNewParams(t *testing.T) { t.Errorf("NewParams() = %+v, want %+v", params, expectedParams) } } + +func TestParamsString(t *testing.T) { + cases := []struct { + name string + params Params + want string + }{ + {"blank params", Params{}, "Params: \nMaxMemoBytes: 0\nTxSigLimit: 0\nTxSizeCostPerByte: 0\nSigVerifyCostED25519: 0\nSigVerifyCostSecp256k1: 0\nGasPricesChangeCompressor: 0\nTargetGasRatio: 0\n"}, + {"some values", Params{ + MaxMemoBytes: 1_000_000, + TxSizeCostPerByte: 8192, + }, "Params: \nMaxMemoBytes: 1000000\nTxSigLimit: 0\nTxSizeCostPerByte: 8192\nSigVerifyCostED25519: 0\nSigVerifyCostSecp256k1: 0\nGasPricesChangeCompressor: 0\nTargetGasRatio: 0\n"}, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got := tt.params.String() + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Fatalf("Mismatch: got - want +\n%s", diff) + } + }) + } +} diff --git a/tm2/pkg/sdk/config/config.go b/tm2/pkg/sdk/config/config.go new file mode 100644 index 00000000000..6e5ededf9a4 --- /dev/null +++ b/tm2/pkg/sdk/config/config.go @@ -0,0 +1,35 @@ +package config + +import ( + "github.com/gnolang/gno/tm2/pkg/errors" + "github.com/gnolang/gno/tm2/pkg/std" +) + +// ----------------------------------------------------------------------------- +// Application Config + +// AppConfig defines the configuration options for the Application +type AppConfig struct { + // Lowest gas prices accepted by a validator in the form of "100tokenA/3gas;10tokenB/5gas" separated by semicolons + MinGasPrices string `json:"min_gas_prices" toml:"min_gas_prices" comment:"Lowest gas prices accepted by a validator"` +} + +// DefaultAppConfig returns a default 
configuration for the application +func DefaultAppConfig() *AppConfig { + return &AppConfig{ + MinGasPrices: "", + } +} + +// ValidateBasic performs basic validation, checking format and param bounds, etc., and +// returns an error if any check fails. +func (cfg *AppConfig) ValidateBasic() error { + if cfg.MinGasPrices == "" { + return nil + } + if _, err := std.ParseGasPrices(cfg.MinGasPrices); err != nil { + return errors.Wrap(err, "invalid min gas prices") + } + + return nil +} diff --git a/tm2/pkg/sdk/config/config_test.go b/tm2/pkg/sdk/config/config_test.go new file mode 100644 index 00000000000..dd0c391849b --- /dev/null +++ b/tm2/pkg/sdk/config/config_test.go @@ -0,0 +1,36 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateAppConfig(t *testing.T) { + c := DefaultAppConfig() + c.MinGasPrices = "" // empty + + testCases := []struct { + testName string + minGasPrices string + expectErr bool + }{ + {"invalid min gas prices invalid gas", "10token/1", true}, + {"invalid min gas prices invalid gas denom", "9token/0gs", true}, + {"invalid min gas prices zero gas", "10token/0gas", true}, + {"invalid min gas prices no gas", "10token/gas", true}, + {"invalid min gas prices negtive gas", "10token/-1gas", true}, + {"invalid min gas prices invalid denom", "10$token/2gas", true}, + {"invalid min gas prices invalid second denom", "10token/2gas;10/3gas", true}, + {"valid min gas prices", "10foo/3gas;5bar/3gas", false}, + } + + cfg := DefaultAppConfig() + for _, tc := range testCases { + tc := tc + t.Run(tc.testName, func(t *testing.T) { + cfg.MinGasPrices = tc.minGasPrices + assert.Equal(t, tc.expectErr, cfg.ValidateBasic() != nil) + }) + } +} diff --git a/tm2/pkg/std/gasprice.go b/tm2/pkg/std/gasprice.go index fd082a93371..82d236c1d04 100644 --- a/tm2/pkg/std/gasprice.go +++ b/tm2/pkg/std/gasprice.go @@ -29,9 +29,11 @@ func ParseGasPrice(gasprice string) (GasPrice, error) { if gas.Denom != "gas" { return GasPrice{}, 
errors.New("invalid gas price: %s (invalid gas denom)", gasprice) } - if gas.Amount == 0 { - return GasPrice{}, errors.New("invalid gas price: %s (gas can not be zero)", gasprice) + + if gas.Amount <= 0 { + return GasPrice{}, errors.New("invalid gas price: %s (invalid gas amount)", gasprice) } + return GasPrice{ Gas: gas.Amount, Price: price, diff --git a/tm2/pkg/telemetry/metrics/metrics.go b/tm2/pkg/telemetry/metrics/metrics.go index e3ae932612f..5eeb664ec8f 100644 --- a/tm2/pkg/telemetry/metrics/metrics.go +++ b/tm2/pkg/telemetry/metrics/metrics.go @@ -16,12 +16,10 @@ import ( ) const ( - broadcastTxTimerKey = "broadcast_tx_hist" - buildBlockTimerKey = "build_block_hist" + buildBlockTimerKey = "build_block_hist" inboundPeersKey = "inbound_peers_gauge" outboundPeersKey = "outbound_peers_gauge" - dialingPeersKey = "dialing_peers_gauge" numMempoolTxsKey = "num_mempool_txs_hist" numCachedTxsKey = "num_cached_txs_hist" @@ -42,11 +40,6 @@ const ( ) var ( - // Misc // - - // BroadcastTxTimer measures the transaction broadcast duration - BroadcastTxTimer metric.Int64Histogram - // Networking // // InboundPeers measures the active number of inbound peers @@ -55,9 +48,6 @@ var ( // OutboundPeers measures the active number of outbound peers OutboundPeers metric.Int64Gauge - // DialingPeers measures the active number of peers in the dialing state - DialingPeers metric.Int64Gauge - // Mempool // // NumMempoolTxs measures the number of transaction inside the mempool @@ -156,14 +146,6 @@ func Init(config config.Config) error { otel.SetMeterProvider(provider) meter := provider.Meter(config.MeterName) - if BroadcastTxTimer, err = meter.Int64Histogram( - broadcastTxTimerKey, - metric.WithDescription("broadcast tx duration"), - metric.WithUnit("ms"), - ); err != nil { - return fmt.Errorf("unable to create histogram, %w", err) - } - if BuildBlockTimer, err = meter.Int64Histogram( buildBlockTimerKey, metric.WithDescription("block build duration"), @@ -188,18 +170,10 @@ func 
Init(config config.Config) error { ); err != nil { return fmt.Errorf("unable to create histogram, %w", err) } + // Initialize OutboundPeers Gauge OutboundPeers.Record(ctx, 0) - if DialingPeers, err = meter.Int64Gauge( - dialingPeersKey, - metric.WithDescription("dialing peer count"), - ); err != nil { - return fmt.Errorf("unable to create histogram, %w", err) - } - // Initialize DialingPeers Gauge - DialingPeers.Record(ctx, 0) - // Mempool // if NumMempoolTxs, err = meter.Int64Histogram( numMempoolTxsKey,