Skip to content

[tmpnet] Ensure all node runtime methods accept a context #3894

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/fixture/e2e/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func NewTestEnvironment(tc tests.TestContext, flagVars *FlagVars, desiredNetwork

if len(networkDir) > 0 {
var err error
network, err = tmpnet.ReadNetwork(tc.Log(), networkDir)
network, err = tmpnet.ReadNetwork(tc.DefaultContext(), tc.Log(), networkDir)
require.NoError(err)
tc.Log().Info("loaded a network",
zap.String("networkDir", networkDir),
Expand Down Expand Up @@ -223,7 +223,7 @@ func (te *TestEnvironment) GetRandomNodeURI() tmpnet.NodeURI {
// Retrieve the network to target for testing.
func (te *TestEnvironment) GetNetwork() *tmpnet.Network {
tc := te.testContext
network, err := tmpnet.ReadNetwork(tc.Log(), te.NetworkDir)
network, err := tmpnet.ReadNetwork(tc.DefaultContext(), tc.Log(), te.NetworkDir)
require.NoError(tc, err)
return network
}
Expand All @@ -238,7 +238,7 @@ func (te *TestEnvironment) StartPrivateNetwork(network *tmpnet.Network) {
tc := te.testContext
require := require.New(tc)
// Use the same configuration as the shared network
sharedNetwork, err := tmpnet.ReadNetwork(tc.Log(), te.NetworkDir)
sharedNetwork, err := tmpnet.ReadNetwork(tc.DefaultContext(), tc.Log(), te.NetworkDir)
require.NoError(err)
network.DefaultRuntimeConfig = sharedNetwork.DefaultRuntimeConfig

Expand Down
12 changes: 6 additions & 6 deletions tests/fixture/tmpnet/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func BootstrapNewNetwork(

// Stops the nodes of the network configured in the provided directory.
func StopNetwork(ctx context.Context, log logging.Logger, dir string) error {
network, err := ReadNetwork(log, dir)
network, err := ReadNetwork(ctx, log, dir)
if err != nil {
return err
}
Expand All @@ -184,15 +184,15 @@ func StopNetwork(ctx context.Context, log logging.Logger, dir string) error {

// Restarts the nodes of the network configured in the provided directory.
func RestartNetwork(ctx context.Context, log logging.Logger, dir string) error {
network, err := ReadNetwork(log, dir)
network, err := ReadNetwork(ctx, log, dir)
if err != nil {
return err
}
return network.Restart(ctx)
}

// Reads a network from the provided directory.
func ReadNetwork(log logging.Logger, dir string) (*Network, error) {
func ReadNetwork(ctx context.Context, log logging.Logger, dir string) (*Network, error) {
canonicalDir, err := toCanonicalDir(dir)
if err != nil {
return nil, err
Expand All @@ -201,7 +201,7 @@ func ReadNetwork(log logging.Logger, dir string) (*Network, error) {
Dir: canonicalDir,
log: log,
}
if err := network.Read(); err != nil {
if err := network.Read(ctx); err != nil {
return nil, fmt.Errorf("failed to read network: %w", err)
}
if network.DefaultFlags == nil {
Expand Down Expand Up @@ -475,7 +475,7 @@ func (n *Network) StartNode(ctx context.Context, node *Node) error {
return fmt.Errorf("writing node flags: %w", err)
}

if err := node.Start(); err != nil {
if err := node.Start(ctx); err != nil {
// Attempt to stop an unhealthy node to provide some assurance to the caller
// that an error condition will not result in a lingering process.
err = errors.Join(err, node.Stop(ctx))
Expand Down Expand Up @@ -513,7 +513,7 @@ func (n *Network) RestartNode(ctx context.Context, node *Node) error {
// Stops all nodes in the network.
func (n *Network) Stop(ctx context.Context) error {
// Ensure the node state is up-to-date
if err := n.readNodes(); err != nil {
if err := n.readNodes(ctx); err != nil {
return err
}

Expand Down
9 changes: 5 additions & 4 deletions tests/fixture/tmpnet/network_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package tmpnet

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand All @@ -21,11 +22,11 @@ import (
var errMissingNetworkDir = errors.New("failed to write network: missing network directory")

// Read network and node configuration from disk.
func (n *Network) Read() error {
func (n *Network) Read(ctx context.Context) error {
if err := n.readNetwork(); err != nil {
return err
}
if err := n.readNodes(); err != nil {
if err := n.readNodes(ctx); err != nil {
return err
}
return n.readSubnets()
Expand Down Expand Up @@ -57,7 +58,7 @@ func (n *Network) readNetwork() error {
}

// Read the nodes associated with the network from disk.
func (n *Network) readNodes() error {
func (n *Network) readNodes(ctx context.Context) error {
nodes := []*Node{}

// Node configuration is stored in child directories
Expand All @@ -72,7 +73,7 @@ func (n *Network) readNodes() error {

node := NewNode()
dataDir := filepath.Join(n.Dir, entry.Name())
err := node.Read(n, dataDir)
err := node.Read(ctx, n, dataDir)
if errors.Is(err, os.ErrNotExist) {
// If no config file exists, assume this is not the path of a node
continue
Expand Down
7 changes: 5 additions & 2 deletions tests/fixture/tmpnet/network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package tmpnet

import (
"context"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -16,6 +17,8 @@ func TestNetworkSerialization(t *testing.T) {

tmpDir := t.TempDir()

ctx := context.Background()

network := NewDefaultNetwork("testnet")
// Validate round-tripping of primary subnet configuration
network.PrimarySubnetConfig = ConfigMap{
Expand All @@ -24,9 +27,9 @@ func TestNetworkSerialization(t *testing.T) {
require.NoError(network.EnsureDefaultConfig(logging.NoLog{}))
require.NoError(network.Create(tmpDir))
// Ensure node runtime is initialized
require.NoError(network.readNodes())
require.NoError(network.readNodes(ctx))

loadedNetwork, err := ReadNetwork(logging.NoLog{}, network.Dir)
loadedNetwork, err := ReadNetwork(ctx, logging.NoLog{}, network.Dir)
require.NoError(err)
for _, key := range loadedNetwork.PreFundedKeys {
// Address() enables comparison with the original network by
Expand Down
16 changes: 8 additions & 8 deletions tests/fixture/tmpnet/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ var (

// NodeRuntime defines the methods required to support running a node.
type NodeRuntime interface {
readState() error
readState(ctx context.Context) error
GetLocalURI(ctx context.Context) (string, func(), error)
GetLocalStakingAddress(ctx context.Context) (netip.AddrPort, func(), error)
InitiateStop() error
Start() error
Start(ctx context.Context) error
InitiateStop(ctx context.Context) error
WaitForStopped(ctx context.Context) error
IsHealthy(ctx context.Context) (bool, error)
}
Expand Down Expand Up @@ -140,23 +140,23 @@ func (n *Node) IsHealthy(ctx context.Context) (bool, error) {
return n.getRuntime().IsHealthy(ctx)
}

func (n *Node) Start() error {
return n.getRuntime().Start()
func (n *Node) Start(ctx context.Context) error {
return n.getRuntime().Start(ctx)
}

func (n *Node) InitiateStop(ctx context.Context) error {
if err := n.SaveMetricsSnapshot(ctx); err != nil {
return err
}
return n.getRuntime().InitiateStop()
return n.getRuntime().InitiateStop(ctx)
}

func (n *Node) WaitForStopped(ctx context.Context) error {
return n.getRuntime().WaitForStopped(ctx)
}

func (n *Node) readState() error {
return n.getRuntime().readState()
func (n *Node) readState(ctx context.Context) error {
return n.getRuntime().readState(ctx)
}

func (n *Node) GetLocalURI(ctx context.Context) (string, func(), error) {
Expand Down
5 changes: 3 additions & 2 deletions tests/fixture/tmpnet/node_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package tmpnet

import (
"context"
"encoding/json"
"fmt"
"os"
Expand Down Expand Up @@ -70,7 +71,7 @@ func (n *Node) writeConfig() error {
return nil
}

func (n *Node) Read(network *Network, dataDir string) error {
func (n *Node) Read(ctx context.Context, network *Network, dataDir string) error {
n.network = network
n.DataDir = dataDir

Expand All @@ -80,7 +81,7 @@ func (n *Node) Read(network *Network, dataDir string) error {
if err := n.EnsureNodeID(); err != nil {
return err
}
return n.readState()
return n.readState(ctx)
}

func (n *Node) Write() error {
Expand Down
14 changes: 8 additions & 6 deletions tests/fixture/tmpnet/process_runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (p *ProcessRuntime) setProcessContext(processContext node.ProcessContext) {
p.node.StakingAddress = processContext.StakingAddress
}

func (p *ProcessRuntime) readState() error {
func (p *ProcessRuntime) readState(_ context.Context) error {
path := p.getProcessContextPath()
bytes, err := os.ReadFile(path)
if errors.Is(err, fs.ErrNotExist) {
Expand All @@ -81,7 +81,7 @@ func (p *ProcessRuntime) readState() error {
// its staking port. The network will start faster with this
// synchronization due to the avoidance of exponential backoff
// if a node tries to connect to a beacon that is not ready.
func (p *ProcessRuntime) Start() error {
func (p *ProcessRuntime) Start(ctx context.Context) error {
log := p.node.network.log

// Avoid attempting to start an already running node.
Expand Down Expand Up @@ -122,7 +122,7 @@ func (p *ProcessRuntime) Start() error {
// a configuration error preventing startup. Such a log entry will be provided to the
// cancelWithCause function so that waitForProcessContext can exit early with an error
// that includes the log entry.
ctx, cancelWithCause := context.WithCancelCause(context.Background())
ctx, cancelWithCause := context.WithCancelCause(ctx)
defer cancelWithCause(nil)
logPath := p.node.DataDir + "/logs/main.log"
go watchLogFileForFatal(ctx, cancelWithCause, log, logPath)
Expand All @@ -145,7 +145,7 @@ func (p *ProcessRuntime) Start() error {
}

// Signals the node process to stop.
func (p *ProcessRuntime) InitiateStop() error {
func (p *ProcessRuntime) InitiateStop(_ context.Context) error {
proc, err := p.getProcess()
if err != nil {
return fmt.Errorf("failed to retrieve process to stop: %w", err)
Expand Down Expand Up @@ -211,7 +211,7 @@ func (p *ProcessRuntime) waitForProcessContext(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, defaultNodeInitTimeout)
defer cancel()
for len(p.node.URI) == 0 {
err := p.readState()
err := p.readState(ctx)
if err != nil {
return fmt.Errorf("failed to read process context for node %q: %w", p.node.NodeID, err)
}
Expand All @@ -229,9 +229,11 @@ func (p *ProcessRuntime) waitForProcessContext(ctx context.Context) error {
// process liveness, the node's process context will be refreshed if
// live or cleared if not running.
func (p *ProcessRuntime) getProcess() (*os.Process, error) {
// This context is not used but a non-nil value must be supplied to satisfy the linter
ctx := context.TODO()
// Read the process context to ensure freshness. The node may have
// stopped or been restarted since last read.
if err := p.readState(); err != nil {
if err := p.readState(ctx); err != nil {
return nil, fmt.Errorf("failed to read process context: %w", err)
}

Expand Down