diff --git a/plugins/node/instrumentation-dataloader/src/instrumentation.ts b/plugins/node/instrumentation-dataloader/src/instrumentation.ts index 21a192484f..c97eca3edb 100644 --- a/plugins/node/instrumentation-dataloader/src/instrumentation.ts +++ b/plugins/node/instrumentation-dataloader/src/instrumentation.ts @@ -91,64 +91,75 @@ export class DataloaderInstrumentation extends InstrumentationBase + ): Dataloader.BatchLoadFn { + const instrumentation = this; + + return function patchedBatchLoadFn( + this: DataloaderInternal, + ...args: Parameters> + ) { + if ( + !instrumentation.isEnabled() || + !instrumentation.shouldCreateSpans() + ) { + return batchLoadFn.call(this, ...args); + } + + const parent = context.active(); + const span = instrumentation.tracer.startSpan( + instrumentation.getSpanName(this, 'batch'), + { links: this._batch?.spanLinks as Link[] | undefined }, + parent + ); + + return context.with(trace.setSpan(parent, span), () => { + return (batchLoadFn.apply(this, args) as Promise) + .then(value => { + span.end(); + return value; + }) + .catch(err => { + span.recordException(err); + span.setStatus({ + code: SpanStatusCode.ERROR, + message: err.message, + }); + span.end(); + throw err; + }); + }); + }; + } + private _getPatchedConstructor( constructor: typeof Dataloader ): typeof Dataloader { - const prototype = constructor.prototype; const instrumentation = this; + const prototype = constructor.prototype; + + if (!instrumentation.isEnabled()) { + return constructor; + } function PatchedDataloader( + this: DataloaderInternal, ...args: ConstructorParameters ) { - const inst = new constructor(...args) as DataloaderInternal; - - if (!instrumentation.isEnabled()) { - return inst; - } + // BatchLoadFn is the first constructor argument + // https://github.com/graphql/dataloader/blob/77c2cd7ca97e8795242018ebc212ce2487e729d2/src/index.js#L47 + if (typeof args[0] === 'function') { + if (isWrapped(args[0])) { + instrumentation._unwrap(args, 0); + } - if (isWrapped(inst._batchLoadFn)) { - instrumentation._unwrap(inst, '_batchLoadFn'); + args[0] = instrumentation._wrapBatchLoadFn( + args[0] + ) as Dataloader.BatchLoadFn; } - instrumentation._wrap(inst, '_batchLoadFn', original => { - return function patchedBatchLoadFn( - this: DataloaderInternal, - ...args: Parameters> - ) { - if ( - !instrumentation.isEnabled() || - !instrumentation.shouldCreateSpans() - ) { - return original.call(this, ...args); - } - - const parent = context.active(); - const span = instrumentation.tracer.startSpan( - instrumentation.getSpanName(inst, 'batch'), - { links: this._batch?.spanLinks as Link[] | undefined }, - parent - ); - - return context.with(trace.setSpan(parent, span), () => { - return (original.apply(this, args) as Promise) - .then(value => { - span.end(); - return value; - }) - .catch(err => { - span.recordException(err); - span.setStatus({ - code: SpanStatusCode.ERROR, - message: err.message, - }); - span.end(); - throw err; - }); - }); - }; - }); - - return inst; + return constructor.apply(this, args); } PatchedDataloader.prototype = prototype; diff --git a/plugins/node/instrumentation-dataloader/test/dataloader.test.ts b/plugins/node/instrumentation-dataloader/test/dataloader.test.ts index 319bde4d03..7e10081756 100644 --- a/plugins/node/instrumentation-dataloader/test/dataloader.test.ts +++ b/plugins/node/instrumentation-dataloader/test/dataloader.test.ts @@ -31,6 +31,7 @@ extraInstrumentation.disable(); import * as assert from 'assert'; import * as Dataloader from 'dataloader'; +import * as crypto from 'crypto'; describe('DataloaderInstrumentation', () => { let dataloader: Dataloader; @@ -335,4 +336,28 @@ describe('DataloaderInstrumentation', () => { assert.deepStrictEqual(await alternativeDataloader.loadMany(['test']), [1]); assert.strictEqual(memoryExporter.getFinishedSpans().length, 5); }); + + it('should not prune custom methods', async () => { + const getMd5HashFromIdx = (idx: number) => + crypto.createHash('md5').update(String(idx)).digest('hex'); + + class CustomDataLoader extends Dataloader { + constructor() { + super(async keys => keys.map((_, idx) => getMd5HashFromIdx(idx))); + } + + public async customLoad() { + return this.load('test'); + } + } + + const customDataloader = new CustomDataLoader(); + await customDataloader.customLoad(); + + assert.strictEqual(memoryExporter.getFinishedSpans().length, 2); + const [batchSpan, loadSpan] = memoryExporter.getFinishedSpans(); + + assert.strictEqual(loadSpan.name, 'dataloader.load'); + assert.strictEqual(batchSpan.name, 'dataloader.batch'); + }); });