diff --git a/tracer.go b/tracer.go index 58ca99f7e..a5cd7b3ee 100644 --- a/tracer.go +++ b/tracer.go @@ -25,6 +25,24 @@ type TraceQueryEndData struct { Err error } +type MultiQueryTracer struct { + Tracers []QueryTracer +} + +func (m *MultiQueryTracer) TraceQueryStart(ctx context.Context, conn *Conn, data TraceQueryStartData) context.Context { + for _, t := range m.Tracers { + ctx = t.TraceQueryStart(ctx, conn, data) + } + + return ctx +} + +func (m *MultiQueryTracer) TraceQueryEnd(ctx context.Context, conn *Conn, data TraceQueryEndData) { + for _, t := range m.Tracers { + t.TraceQueryEnd(ctx, conn, data) + } +} + // BatchTracer traces SendBatch. type BatchTracer interface { // TraceBatchStart is called at the beginning of SendBatch calls. The returned context is used for the diff --git a/tracer_test.go b/tracer_test.go index a0fea71e6..b2a51e503 100644 --- a/tracer_test.go +++ b/tracer_test.go @@ -97,6 +97,26 @@ func (tt *testTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnect } } +type testMultiQueryTracer []testTracer + +func (tmqt *testMultiQueryTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + for _, tt := range *tmqt { + if tt.traceQueryStart != nil { + return tt.traceQueryStart(ctx, conn, data) + } + } + + return ctx +} + +func (tmqt *testMultiQueryTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + for _, tt := range *tmqt { + if tt.traceQueryEnd != nil { + tt.traceQueryEnd(ctx, conn, data) + } + } +} + func TestTraceExec(t *testing.T) { t.Parallel() @@ -179,6 +199,49 @@ func TestTraceQuery(t *testing.T) { }) } +func TestMultiTraceQuery(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + multiQueryTracer := &testMultiQueryTracer{*tracer} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = multiQueryTracer + return config + } + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + traceQueryStartCalled := false + (*multiQueryTracer)[0].traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + traceQueryStartCalled = true + require.Equal(t, `select $1::text`, data.SQL) + require.Len(t, data.Args, 1) + require.Equal(t, `testing`, data.Args[0]) + return context.WithValue(ctx, ctxKey("fromTraceQueryStart"), "foo") + } + + traceQueryEndCalled := false + (*multiQueryTracer)[0].traceQueryEnd = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + traceQueryEndCalled = true + require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceQueryStart"))) + require.Equal(t, `SELECT 1`, data.CommandTag.String()) + require.NoError(t, data.Err) + } + + var s string + err := conn.QueryRow(ctx, `select $1::text`, "testing").Scan(&s) + require.NoError(t, err) + require.Equal(t, "testing", s) + require.True(t, traceQueryStartCalled) + require.True(t, traceQueryEndCalled) + }) +} + func TestTraceBatchNormal(t *testing.T) { t.Parallel()