Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prevent all transaction methods from nil transaction errors #1001

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion v3/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ require (
google.golang.org/protobuf v1.34.2
)


retract v3.22.0 // release process error corrected in v3.22.1

retract v3.25.0 // release process error corrected in v3.25.1
Expand Down
67 changes: 42 additions & 25 deletions v3/newrelic/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,16 @@ type Transaction struct {
thread *thread
}

// nilTransaction guards against nil errors when handling a transaction.
func nilTransaction(txn *Transaction) bool {
return txn == nil || txn.thread == nil || txn.thread.txn == nil
}

// End finishes the Transaction. After that, subsequent calls to End or
// other Transaction methods have no effect. All segments and
// instrumentation must be completed before End is called.
func (txn *Transaction) End() {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}

Expand All @@ -55,15 +60,15 @@ func (txn *Transaction) End() {
// The set of options should be the complete set you wish to have in effect,
// just as if you were calling StartTransaction now with the same set of options.
func (txn *Transaction) SetOption(options ...TraceOption) {
if txn == nil || txn.thread == nil || txn.thread.txn == nil {
if nilTransaction(txn) {
return
}
txn.thread.txn.setOption(options...)
}

// Ignore prevents this transaction's data from being recorded.
func (txn *Transaction) Ignore() {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}
txn.thread.logAPIError(txn.thread.Ignore(), "ignore transaction", nil)
Expand All @@ -72,7 +77,7 @@ func (txn *Transaction) Ignore() {
// SetName names the transaction. Use a limited set of unique names to
// ensure that Transactions are grouped usefully.
func (txn *Transaction) SetName(name string) {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}
txn.thread.logAPIError(txn.thread.SetName(name), "set transaction name", nil)
Expand All @@ -84,8 +89,7 @@ func (txn *Transaction) Name() string {
// This is called Name rather than GetName to be consistent with the prevailing naming
// conventions for the Go language, even though the underlying internal call must be called
// something else (like GetName) because there's already a Name struct member.

if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return ""
}
return txn.thread.GetName()
Expand Down Expand Up @@ -117,7 +121,7 @@ func (txn *Transaction) Name() string {
// way to directly control the recorded error's message, class, stacktrace,
// and attributes.
func (txn *Transaction) NoticeError(err error) {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}
txn.thread.logAPIError(txn.thread.NoticeError(err, false), "notice error", nil)
Expand Down Expand Up @@ -151,7 +155,7 @@ func (txn *Transaction) NoticeError(err error) {
// way to directly control the recorded error's message, class, stacktrace,
// and attributes.
func (txn *Transaction) NoticeExpectedError(err error) {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}
txn.thread.logAPIError(txn.thread.NoticeError(err, true), "notice error", nil)
Expand All @@ -166,7 +170,7 @@ func (txn *Transaction) NoticeExpectedError(err error) {
// For more information, see:
// https://docs.newrelic.com/docs/agents/manage-apm-agents/agent-metrics/collect-custom-attributes
func (txn *Transaction) AddAttribute(key string, value any) {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}
txn.thread.logAPIError(txn.thread.AddAttribute(key, value), "add attribute", nil)
Expand All @@ -176,10 +180,9 @@ func (txn *Transaction) AddAttribute(key string, value any) {
// belong to or interact with. This will propogate an attribute containing this information to all events that are
// a child of this transaction, like errors and spans.
func (txn *Transaction) SetUserID(userID string) {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}

txn.thread.logAPIError(txn.thread.AddUserID(userID), "set user ID", nil)
}

Expand All @@ -192,6 +195,9 @@ func (txn *Transaction) SetUserID(userID string) {
// as well as log metrics depending on how your application is
// configured.
func (txn *Transaction) RecordLog(log LogData) {
if nilTransaction(txn) {
return
}
event, err := log.toLogEvent()
if err != nil {
txn.Application().app.Error("unable to record log", map[string]any{
Expand All @@ -212,6 +218,9 @@ func (txn *Transaction) RecordLog(log LogData) {
// present, the agent will look for distributed tracing headers using
// Transaction.AcceptDistributedTraceHeaders.
func (txn *Transaction) SetWebRequestHTTP(r *http.Request) {
if nilTransaction(txn) {
return
}
if r == nil {
txn.SetWebRequest(WebRequest{})
return
Expand Down Expand Up @@ -265,7 +274,7 @@ func reqBody(req *http.Request) *BodyBuffer {
// distributed tracing headers using Transaction.AcceptDistributedTraceHeaders.
// Use Transaction.SetWebRequestHTTP if you have a *http.Request.
func (txn *Transaction) SetWebRequest(r WebRequest) {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}
if IsSecurityAgentPresent() {
Expand All @@ -289,7 +298,7 @@ func (txn *Transaction) SetWebRequest(r WebRequest) {
// package middlewares. Therefore, you probably want to use this only if you
// are writing your own instrumentation middleware.
func (txn *Transaction) SetWebResponse(w http.ResponseWriter) http.ResponseWriter {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return w
}
return txn.thread.SetWebResponse(w)
Expand All @@ -304,7 +313,7 @@ func (txn *Transaction) StartSegmentNow() SegmentStartTime {
}

func (txn *Transaction) startSegmentAt(at time.Time) SegmentStartTime {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return SegmentStartTime{}
}
return txn.thread.startSegmentAt(at)
Expand All @@ -324,7 +333,11 @@ func (txn *Transaction) startSegmentAt(at time.Time) SegmentStartTime {
// // ... code you want to time here ...
// segment.End()
func (txn *Transaction) StartSegment(name string) *Segment {
if IsSecurityAgentPresent() && txn != nil && txn.thread != nil && txn.thread.thread != nil && txn.thread.thread.threadID > 0 {
if nilTransaction(txn) {
return &Segment{} // return a non-nil Segment to avoid nil dereference
}

if IsSecurityAgentPresent() && txn.thread.thread != nil && txn.thread.thread.threadID > 0 {
// async segment start
secureAgent.SendEvent("NEW_GOROUTINE_LINKER", txn.thread.getCsecData())
}
Expand All @@ -346,7 +359,7 @@ func (txn *Transaction) StartSegment(name string) *Segment {
// StartExternalSegment calls InsertDistributedTraceHeaders, so you don't need
// to use it for outbound HTTP calls: Just use StartExternalSegment!
func (txn *Transaction) InsertDistributedTraceHeaders(hdrs http.Header) {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}
txn.thread.CreateDistributedTracePayload(hdrs)
Expand All @@ -367,7 +380,7 @@ func (txn *Transaction) InsertDistributedTraceHeaders(hdrs http.Header) {
// context headers. Only when those are not found will it look for the New
// Relic distributed tracing header.
func (txn *Transaction) AcceptDistributedTraceHeaders(t TransportType, hdrs http.Header) {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}
txn.thread.logAPIError(txn.thread.AcceptDistributedTraceHeaders(t, hdrs), "accept trace payload", nil)
Expand All @@ -379,6 +392,10 @@ func (txn *Transaction) AcceptDistributedTraceHeaders(t TransportType, hdrs http
// convert the JSON string to http headers. There is no guarantee that the header data found in JSON
// is correct beyond conforming to the expected types and syntax.
func (txn *Transaction) AcceptDistributedTraceHeadersFromJSON(t TransportType, jsondata string) error {
if nilTransaction(txn) { // do no work if txn is nil
return nil
}

hdrs, err := DistributedTraceHeadersFromJSON(jsondata)
if err != nil {
return err
Expand Down Expand Up @@ -465,7 +482,7 @@ func DistributedTraceHeadersFromJSON(jsondata string) (hdrs http.Header, err err

// Application returns the Application which started the transaction.
func (txn *Transaction) Application() *Application {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return nil
}
return txn.thread.Application()
Expand All @@ -484,7 +501,7 @@ func (txn *Transaction) Application() *Application {
// monitoring is disabled, the application is not connected, or an error
// occurred. It is safe to call the pointer's methods if it is nil.
func (txn *Transaction) BrowserTimingHeader() *BrowserTimingHeader {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return nil
}
b, err := txn.thread.BrowserTimingHeader()
Expand All @@ -506,7 +523,7 @@ func (txn *Transaction) BrowserTimingHeader() *BrowserTimingHeader {
// Note that any segments that end after the transaction ends will not
// be reported.
func (txn *Transaction) NewGoroutine() *Transaction {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return nil
}
newTxn := txn.thread.NewGoroutine()
Expand All @@ -519,7 +536,7 @@ func (txn *Transaction) NewGoroutine() *Transaction {
// GetTraceMetadata returns distributed tracing identifiers. Empty
// string identifiers are returned if the transaction has finished.
func (txn *Transaction) GetTraceMetadata() TraceMetadata {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return TraceMetadata{}
}
return txn.thread.GetTraceMetadata()
Expand All @@ -528,7 +545,7 @@ func (txn *Transaction) GetTraceMetadata() TraceMetadata {
// GetLinkingMetadata returns the fields needed to link data to a trace or
// entity.
func (txn *Transaction) GetLinkingMetadata() LinkingMetadata {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return LinkingMetadata{}
}
return txn.thread.GetLinkingMetadata()
Expand All @@ -539,21 +556,21 @@ func (txn *Transaction) GetLinkingMetadata() LinkingMetadata {
// must be enabled for transactions to be sampled. False is returned if
// the Transaction has finished.
func (txn *Transaction) IsSampled() bool {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return false
}
return txn.thread.IsSampled()
}

func (txn *Transaction) GetCsecAttributes() map[string]any {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return nil
}
return txn.thread.getCsecAttributes()
}

func (txn *Transaction) SetCsecAttributes(key string, value any) {
if txn == nil || txn.thread == nil {
if nilTransaction(txn) {
return
}
txn.thread.setCsecAttributes(key, value)
Expand Down
72 changes: 72 additions & 0 deletions v3/newrelic/transaction_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package newrelic

import (
"fmt"
"net/http"
"testing"
)

func TestTransaction_MethodsWithNilTransaction(t *testing.T) {
var nilTxn *Transaction

defer func() {
if r := recover(); r != nil {
t.Errorf("panics should not occur on methods of Transaction: %v", r)
}
}()

// Ensure no panic occurs when calling methods on a nil transaction
nilTxn.End()
nilTxn.SetOption()
nilTxn.Ignore()
nilTxn.SetName("test")
name := nilTxn.Name()
if name != "" {
t.Errorf("expected empty string, got %s", name)
}
nilTxn.NoticeError(fmt.Errorf("test error"))
nilTxn.NoticeExpectedError(fmt.Errorf("test expected error"))
nilTxn.AddAttribute("key", "value")
nilTxn.SetUserID("user123")
nilTxn.RecordLog(LogData{})
nilTxn.SetWebRequestHTTP(nil)
nilTxn.SetWebRequest(WebRequest{})
nilTxn.SetWebResponse(nil)
nilTxn.StartSegmentNow()
nilTxn.StartSegment("test segment")
nilTxn.InsertDistributedTraceHeaders(http.Header{})
nilTxn.AcceptDistributedTraceHeaders(TransportHTTP, http.Header{})
err := nilTxn.AcceptDistributedTraceHeadersFromJSON(TransportHTTP, "{}")
if err != nil {
t.Errorf("expected no error, got %v", err)
}
app := nilTxn.Application()
if app != nil {
t.Errorf("expected nil, got %v", app)
}
bth := nilTxn.BrowserTimingHeader()
if bth != nil {
t.Errorf("expected nil, got %v", bth)
}
newTxn := nilTxn.NewGoroutine()
if newTxn != nil {
t.Errorf("expected nil, got %v", newTxn)
}
traceMetadata := nilTxn.GetTraceMetadata()
if traceMetadata != (TraceMetadata{}) {
t.Errorf("expected empty TraceMetadata, got %v", traceMetadata)
}
linkingMetadata := nilTxn.GetLinkingMetadata()
if linkingMetadata != (LinkingMetadata{}) {
t.Errorf("expected empty LinkingMetadata, got %v", linkingMetadata)
}
isSampled := nilTxn.IsSampled()
if isSampled {
t.Errorf("expected false, got %v", isSampled)
}
csecAttributes := nilTxn.GetCsecAttributes()
if csecAttributes != nil {
t.Errorf("expected nil, got %v", csecAttributes)
}
nilTxn.SetCsecAttributes("key", "value")
}
Loading