From cd189f169602292146a4870d8ccc7b9b906d53bb Mon Sep 17 00:00:00 2001 From: jiexi Date: Tue, 7 May 2024 11:56:57 -0700 Subject: [PATCH] fix: Fix `StaticIntervalPollingController` not properly stopping polling if `_executePoll` was still pending (#4230) ## Explanation Currently, there is a bug in the StaticIntervalPollingController that causes polling to not be properly stopped if a stop is requested while `_executePoll` has not yet resolved for the current loop. This PR adds a guard to check if the id returned by setTimeout still matches what is in state, indicating that polling is still active. ## References ## Changelog ### `@metamask/polling-controller` - **FIXED**: `StaticIntervalPollingControllerOnly`, `StaticIntervalPollingController`, and `StaticIntervalPollingControllerV1` now properly stops polling when a stop is requested while `_executePoll` has not yet resolved for the current loop ## Checklist - [x] I've updated the test suite for new or updated code as appropriate - [x] I've updated documentation (JSDoc, Markdown, etc.) for new or updated code as appropriate - [ ] I've highlighted breaking changes using the "BREAKING" category above as appropriate --------- Co-authored-by: Alex Donesky --- .../StaticIntervalPollingController.test.ts | 72 ++++++++++++++++--- .../src/StaticIntervalPollingController.ts | 9 ++- 2 files changed, 68 insertions(+), 13 deletions(-) diff --git a/packages/polling-controller/src/StaticIntervalPollingController.test.ts b/packages/polling-controller/src/StaticIntervalPollingController.test.ts index 20bf3ba753..2238fbc111 100644 --- a/packages/polling-controller/src/StaticIntervalPollingController.test.ts +++ b/packages/polling-controller/src/StaticIntervalPollingController.test.ts @@ -1,4 +1,5 @@ import { ControllerMessenger } from '@metamask/base-controller'; +import { createDeferredPromise } from '@metamask/utils'; import { useFakeTimers } from 'sinon'; import { advanceTime } from '../../../tests/helpers'; @@ -6,13 +7,6 @@ import { StaticIntervalPollingController } from './StaticIntervalPollingControll const TICK_TIME = 5; -const createExecutePollMock = () => { - const executePollMock = jest.fn().mockImplementation(async () => { - return true; - }); - return executePollMock; -}; - class ChildBlockTrackerPollingController extends StaticIntervalPollingController< // TODO: Replace `any` with type // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -24,7 +18,18 @@ class ChildBlockTrackerPollingController extends StaticIntervalPollingController // eslint-disable-next-line @typescript-eslint/no-explicit-any any > { - _executePoll = createExecutePollMock(); + executePollPromises: { + reject: (err: unknown) => void; + resolve: () => void; + }[] = []; + + _executePoll = jest.fn().mockImplementation(() => { + const { promise, reject, resolve } = createDeferredPromise({ + suppressUnhandledRejection: true, + }); + this.executePollPromises.push({ reject, resolve }); + return promise; + }); } describe('StaticIntervalPollingController', () => { @@ -48,6 +53,7 @@ describe('StaticIntervalPollingController', () => { controller.setIntervalLength(TICK_TIME); clock = useFakeTimers(); }); + afterEach(() => { clock.restore(); }); @@ -57,6 +63,7 @@ describe('StaticIntervalPollingController', () => { controller.startPollingByNetworkClientId('mainnet'); await advanceTime({ clock, duration: 0 }); expect(controller._executePoll).toHaveBeenCalledTimes(1); + controller.executePollPromises[0].resolve(); await advanceTime({ clock, duration: TICK_TIME }); expect(controller._executePoll).toHaveBeenCalledTimes(2); controller.stopAllPolling(); @@ -70,11 +77,16 @@ describe('StaticIntervalPollingController', () => { await advanceTime({ clock, duration: 0 }); expect(controller._executePoll).toHaveBeenCalledTimes(1); - await advanceTime({ clock, duration: TICK_TIME * 2 }); + controller.executePollPromises[0].resolve(); + await advanceTime({ clock, duration: TICK_TIME }); + controller.executePollPromises[1].resolve(); + await advanceTime({ clock, duration: TICK_TIME }); + controller.executePollPromises[2].resolve(); expect(controller._executePoll).toHaveBeenCalledTimes(3); controller.stopAllPolling(); }); + describe('multiple networkClientIds', () => { it('should poll for each networkClientId', async () => { controller.startPollingByNetworkClientId('mainnet'); @@ -87,6 +99,9 @@ describe('StaticIntervalPollingController', () => { ['mainnet', {}], ['rinkeby', {}], ]); + + controller.executePollPromises[0].resolve(); + controller.executePollPromises[1].resolve(); await advanceTime({ clock, duration: TICK_TIME }); expect(controller._executePoll.mock.calls).toMatchObject([ @@ -95,6 +110,9 @@ describe('StaticIntervalPollingController', () => { ['mainnet', {}], ['rinkeby', {}], ]); + + controller.executePollPromises[2].resolve(); + controller.executePollPromises[3].resolve(); await advanceTime({ clock, duration: TICK_TIME }); expect(controller._executePoll.mock.calls).toMatchObject([ @@ -116,6 +134,7 @@ describe('StaticIntervalPollingController', () => { expect(controller._executePoll.mock.calls).toMatchObject([ ['mainnet', {}], ]); + controller.executePollPromises[0].resolve(); await advanceTime({ clock, duration: TICK_TIME }); controller.startPollingByNetworkClientId('sepolia'); @@ -125,6 +144,8 @@ describe('StaticIntervalPollingController', () => { ['mainnet', {}], ['sepolia', {}], ]); + + controller.executePollPromises[1].resolve(); await advanceTime({ clock, duration: TICK_TIME }); expect(controller._executePoll.mock.calls).toMatchObject([ @@ -132,6 +153,8 @@ describe('StaticIntervalPollingController', () => { ['sepolia', {}], ['mainnet', {}], ]); + + controller.executePollPromises[2].resolve(); await advanceTime({ clock, duration: TICK_TIME }); expect(controller._executePoll.mock.calls).toMatchObject([ @@ -140,6 +163,8 @@ describe('StaticIntervalPollingController', () => { ['mainnet', {}], ['sepolia', {}], ]); + + controller.executePollPromises[3].resolve(); await advanceTime({ clock, duration: TICK_TIME }); expect(controller._executePoll.mock.calls).toMatchObject([ @@ -149,6 +174,8 @@ describe('StaticIntervalPollingController', () => { ['sepolia', {}], ['mainnet', {}], ]); + + controller.executePollPromises[4].resolve(); await advanceTime({ clock, duration: TICK_TIME }); expect(controller._executePoll.mock.calls).toMatchObject([ @@ -168,24 +195,29 @@ describe('StaticIntervalPollingController', () => { const pollingToken = controller.startPollingByNetworkClientId('mainnet'); await advanceTime({ clock, duration: 0 }); expect(controller._executePoll).toHaveBeenCalledTimes(1); + controller.executePollPromises[0].resolve(); await advanceTime({ clock, duration: TICK_TIME }); controller.stopPollingByPollingToken(pollingToken); await advanceTime({ clock, duration: TICK_TIME }); expect(controller._executePoll).toHaveBeenCalledTimes(2); controller.stopAllPolling(); }); + it('should not stop polling if called with one of multiple active polling tokens for a given networkClient', async () => { const pollingToken1 = controller.startPollingByNetworkClientId('mainnet'); await advanceTime({ clock, duration: 0 }); controller.startPollingByNetworkClientId('mainnet'); expect(controller._executePoll).toHaveBeenCalledTimes(1); + controller.executePollPromises[0].resolve(); await advanceTime({ clock, duration: TICK_TIME }); controller.stopPollingByPollingToken(pollingToken1); + controller.executePollPromises[1].resolve(); await advanceTime({ clock, duration: TICK_TIME }); expect(controller._executePoll).toHaveBeenCalledTimes(3); controller.stopAllPolling(); }); + it('should error if no pollingToken is passed', () => { controller.startPollingByNetworkClientId('mainnet'); expect(() => { @@ -195,10 +227,10 @@ describe('StaticIntervalPollingController', () => { }); it('should start and stop polling sessions for different networkClientIds with the same options', async () => { - controller.setIntervalLength(TICK_TIME); const pollToken1 = controller.startPollingByNetworkClientId('mainnet', { address: '0x1', }); + await advanceTime({ clock, duration: 0 }); controller.startPollingByNetworkClientId('mainnet', { address: '0x2' }); await advanceTime({ clock, duration: 0 }); @@ -210,6 +242,10 @@ describe('StaticIntervalPollingController', () => { ['mainnet', { address: '0x2' }], ['sepolia', { address: '0x2' }], ]); + + controller.executePollPromises[0].resolve(); + controller.executePollPromises[1].resolve(); + controller.executePollPromises[2].resolve(); await advanceTime({ clock, duration: TICK_TIME }); expect(controller._executePoll.mock.calls).toMatchObject([ @@ -221,6 +257,9 @@ describe('StaticIntervalPollingController', () => { ['sepolia', { address: '0x2' }], ]); controller.stopPollingByPollingToken(pollToken1); + controller.executePollPromises[3].resolve(); + controller.executePollPromises[4].resolve(); + controller.executePollPromises[5].resolve(); await advanceTime({ clock, duration: TICK_TIME }); expect(controller._executePoll.mock.calls).toMatchObject([ @@ -234,6 +273,19 @@ describe('StaticIntervalPollingController', () => { ['sepolia', { address: '0x2' }], ]); }); + + it('should stop polling session after current iteration if stop is requested while current iteration is still executing', async () => { + const pollingToken = controller.startPollingByNetworkClientId('mainnet'); + await advanceTime({ clock, duration: 0 }); + expect(controller._executePoll).toHaveBeenCalledTimes(1); + controller.stopPollingByPollingToken(pollingToken); + controller.executePollPromises[0].resolve(); + await advanceTime({ clock, duration: TICK_TIME }); + expect(controller._executePoll).toHaveBeenCalledTimes(1); + await advanceTime({ clock, duration: TICK_TIME }); + expect(controller._executePoll).toHaveBeenCalledTimes(1); + controller.stopAllPolling(); + }); }); describe('onPollingCompleteByNetworkClientId', () => { diff --git a/packages/polling-controller/src/StaticIntervalPollingController.ts b/packages/polling-controller/src/StaticIntervalPollingController.ts index 96f1b57656..3d49f02304 100644 --- a/packages/polling-controller/src/StaticIntervalPollingController.ts +++ b/packages/polling-controller/src/StaticIntervalPollingController.ts @@ -50,17 +50,20 @@ function StaticIntervalPollingControllerMixin( const existingInterval = this.#intervalIds[key]; this._stopPollingByPollingTokenSetId(key); - this.#intervalIds[key] = setTimeout( + // eslint-disable-next-line no-multi-assign + const intervalId = (this.#intervalIds[key] = setTimeout( async () => { try { await this._executePoll(networkClientId, options); } catch (error) { console.error(error); } - this._startPollingByNetworkClientId(networkClientId, options); + if (intervalId === this.#intervalIds[key]) { + this._startPollingByNetworkClientId(networkClientId, options); + } }, existingInterval ? this.#intervalLength : 0, - ); + )); } _stopPollingByPollingTokenSetId(key: PollingTokenSetId) {