diff --git a/runner/full.go b/runner/full.go index 1b0cb489..63724231 100644 --- a/runner/full.go +++ b/runner/full.go @@ -19,28 +19,18 @@ import ( "github.com/pganalyze/collector/util" ) -func collectDiffAndSubmit(ctx context.Context, server *state.Server, globalCollectionOpts state.CollectionOpts, logger *util.Logger) (state.PersistedState, state.CollectionStatus, error) { +func collectDiffAndSubmit(ctx context.Context, db *sql.DB, server *state.Server, globalCollectionOpts state.CollectionOpts, logger *util.Logger) (state.PersistedState, state.CollectionStatus, error) { var newState state.PersistedState var err error - var connection *sql.DB - connection, err = postgres.EstablishConnection(ctx, server, logger, globalCollectionOpts, "") + newState, transientState, err := input.CollectFull(ctx, server, db, globalCollectionOpts, logger) if err != nil { - return newState, state.CollectionStatus{}, fmt.Errorf("Failed to connect to database: %s", err) - } - - newState, transientState, err := input.CollectFull(ctx, server, connection, globalCollectionOpts, logger) - if err != nil { - connection.Close() return newState, state.CollectionStatus{}, err } if globalCollectionOpts.TestRun { logger.PrintInfo(" Test collection successful for %s", transientState.Version.Full) } - // This is the easiest way to avoid opening multiple connections to different databases on the same instance - connection.Close() - logsDisabled, logsIgnoreStatement, logsIgnoreDuration, logsDisabledReason := logs.ValidateLogCollectionConfig(server, transientState.Settings) collectionStatus := state.CollectionStatus{ LogSnapshotDisabled: logsDisabled, @@ -89,7 +79,13 @@ func processServer(ctx context.Context, server *state.Server, globalCollectionOp var collectionStatus state.CollectionStatus var err error - err = checkReplicaCollectionDisabled(ctx, server, globalCollectionOpts, logger) + db, err := postgres.EstablishConnection(ctx, server, logger, globalCollectionOpts, "") + if err != nil { + return state.PersistedState{}, state.Grant{}, state.CollectionStatus{}, fmt.Errorf("Failed to connect to database: %s", err) + } + defer db.Close() + + err = checkReplicaCollectionDisabled(ctx, db, server, globalCollectionOpts, logger) if err != nil { return state.PersistedState{}, state.Grant{}, state.CollectionStatus{}, err } @@ -119,7 +115,7 @@ func processServer(ctx context.Context, server *state.Server, globalCollectionOp } runFunc := func() { - newState, collectionStatus, err = collectDiffAndSubmit(ctx, server, globalCollectionOpts, logger) + newState, collectionStatus, err = collectDiffAndSubmit(ctx, db, server, globalCollectionOpts, logger) } var panicErr interface{} @@ -152,19 +148,12 @@ func runCompletionCallback(callbackType string, callbackCmd string, sectionName } } -func checkReplicaCollectionDisabled(ctx context.Context, server *state.Server, opts state.CollectionOpts, logger *util.Logger) error { +func checkReplicaCollectionDisabled(ctx context.Context, db *sql.DB, server *state.Server, opts state.CollectionOpts, logger *util.Logger) error { if !server.Config.SkipIfReplica { return nil } - connection, err := postgres.EstablishConnection(ctx, server, logger, opts, "") - if err != nil { - return fmt.Errorf("Failed to connect to database: %s", err) - } - defer connection.Close() - - var isReplica bool - isReplica, err = postgres.GetIsReplica(ctx, logger, connection) + isReplica, err := postgres.GetIsReplica(ctx, logger, db) if err != nil { return fmt.Errorf("Error checking replication status") }