From 65416972f9977f66c09da521c0df480db9d5eda5 Mon Sep 17 00:00:00 2001 From: Elliot Winkler Date: Fri, 24 May 2024 10:07:21 -0600 Subject: [PATCH] TokenRatesController: providerConfig -> selectedNetworkClientId (#4317) The `providerConfig` state property is being removed from NetworkController. Currently this property is used in TokenRatesController to get the currently selected chain and ticker, but `selectedNetworkClientId` can be used instead to get the currently selected network client, and then the appropriate information can be read from that object. This commit makes that transition so that we can fully drop `providerConfig`. --- packages/assets-controllers/CHANGELOG.md | 2 + .../src/TokenRatesController.test.ts | 296 ++++++++++++------ .../src/TokenRatesController.ts | 8 +- 3 files changed, 216 insertions(+), 90 deletions(-) diff --git a/packages/assets-controllers/CHANGELOG.md b/packages/assets-controllers/CHANGELOG.md index 4b4d811666..950f291d68 100644 --- a/packages/assets-controllers/CHANGELOG.md +++ b/packages/assets-controllers/CHANGELOG.md @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - This should be functionally equivalent, but is being noted anyway. - NftDetectionController now makes use of `selectedNetworkClientId` when responding to changes in NetworkController state to capture the currently selected chain rather than `providerConfig` ([#4307](https://github.com/MetaMask/core/pull/4307)) - This should be functionally equivalent, but is being noted anyway. +- TokenRatesController now makes use of `selectedNetworkClientId` when responding to changes in NetworkController state to capture the currently selected chain rather than `providerConfig` ([#4317](https://github.com/MetaMask/core/pull/4317)) + - This should be functionally equivalent, but is being noted anyway. ## [29.0.0] diff --git a/packages/assets-controllers/src/TokenRatesController.test.ts b/packages/assets-controllers/src/TokenRatesController.test.ts index a82a195ff7..d5d0ae8701 100644 --- a/packages/assets-controllers/src/TokenRatesController.test.ts +++ b/packages/assets-controllers/src/TokenRatesController.test.ts @@ -1,16 +1,28 @@ import { + ChainId, + InfuraNetworkType, NetworksTicker, toChecksumHexAddress, toHex, } from '@metamask/controller-utils'; -import type { NetworkState } from '@metamask/network-controller'; +import type { + NetworkClientConfiguration, + NetworkClientId, + NetworkState, +} from '@metamask/network-controller'; +import { defaultState as defaultNetworkState } from '@metamask/network-controller'; import type { PreferencesState } from '@metamask/preferences-controller'; import type { Hex } from '@metamask/utils'; import { add0x } from '@metamask/utils'; +import assert from 'assert'; import nock from 'nock'; import { useFakeTimers } from 'sinon'; import { advanceTime, flushPromises } from '../../../tests/helpers'; +import { + buildCustomNetworkClientConfiguration, + buildMockGetNetworkClientById, +} from '../../network-controller/tests/helpers'; import { TOKEN_PRICES_BATCH_SIZE } from './assetsUtil'; import type { AbstractTokenPricesService, @@ -738,9 +750,13 @@ describe('TokenRatesController', () => { describe('when polling is active', () => { it('should update exchange rates when ticker changes', async () => { - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let networkStateChangeListener: (state: any) => Promise; + const getNetworkClientById = buildMockGetNetworkClientById({ + 'AAAA-BBBB-CCCC-DDDD': buildCustomNetworkClientConfiguration({ + chainId: toHex(1337), + ticker: 'NEW', + }), + }); + let networkStateChangeListener: (state: NetworkState) => Promise; const onNetworkStateChange = jest .fn() .mockImplementation((listener) => { @@ -748,7 +764,7 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, - getNetworkClientById: jest.fn(), + getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', selectedAddress: defaultSelectedAddress, @@ -764,16 +780,21 @@ describe('TokenRatesController', () => { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await networkStateChangeListener!({ - providerConfig: { chainId: toHex(1337), ticker: 'NEW' }, + ...defaultNetworkState, + selectedNetworkClientId: 'AAAA-BBBB-CCCC-DDDD', }); expect(updateExchangeRatesSpy).toHaveBeenCalledTimes(1); }); it('should update exchange rates when chain ID changes', async () => { - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let networkStateChangeListener: (state: any) => Promise; + const getNetworkClientById = buildMockGetNetworkClientById({ + 'AAAA-BBBB-CCCC-DDDD': buildCustomNetworkClientConfiguration({ + chainId: toHex(1338), + ticker: 'TEST', + }), + }); + let networkStateChangeListener: (state: NetworkState) => Promise; const onNetworkStateChange = jest .fn() .mockImplementation((listener) => { @@ -781,7 +802,7 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, - getNetworkClientById: jest.fn(), + getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', selectedAddress: defaultSelectedAddress, @@ -797,16 +818,21 @@ describe('TokenRatesController', () => { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await networkStateChangeListener!({ - providerConfig: { chainId: toHex(1338), ticker: 'TEST' }, + ...defaultNetworkState, + selectedNetworkClientId: 'AAAA-BBBB-CCCC-DDDD', }); expect(updateExchangeRatesSpy).toHaveBeenCalledTimes(1); }); it('should clear contractExchangeRates state when ticker changes', async () => { - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let networkStateChangeListener: (state: any) => Promise; + const getNetworkClientById = buildMockGetNetworkClientById({ + 'AAAA-BBBB-CCCC-DDDD': buildCustomNetworkClientConfiguration({ + chainId: toHex(1337), + ticker: 'NEW', + }), + }); + let networkStateChangeListener: (state: NetworkState) => Promise; const onNetworkStateChange = jest .fn() .mockImplementation((listener) => { @@ -814,7 +840,7 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, - getNetworkClientById: jest.fn(), + getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', selectedAddress: defaultSelectedAddress, @@ -828,16 +854,21 @@ describe('TokenRatesController', () => { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await networkStateChangeListener!({ - providerConfig: { chainId: toHex(1337), ticker: 'NEW' }, + ...defaultNetworkState, + selectedNetworkClientId: 'AAAA-BBBB-CCCC-DDDD', }); expect(controller.state.contractExchangeRates).toStrictEqual({}); }); it('should clear contractExchangeRates state when chain ID changes', async () => { - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let networkStateChangeListener: (state: any) => Promise; + const getNetworkClientById = buildMockGetNetworkClientById({ + 'AAAA-BBBB-CCCC-DDDD': buildCustomNetworkClientConfiguration({ + chainId: toHex(1338), + ticker: 'TEST', + }), + }); + let networkStateChangeListener: (state: NetworkState) => Promise; const onNetworkStateChange = jest .fn() .mockImplementation((listener) => { @@ -845,7 +876,7 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, - getNetworkClientById: jest.fn(), + getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', selectedAddress: defaultSelectedAddress, @@ -859,16 +890,21 @@ describe('TokenRatesController', () => { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await networkStateChangeListener!({ - providerConfig: { chainId: toHex(1338), ticker: 'TEST' }, + ...defaultNetworkState, + selectedNetworkClientId: 'AAAA-BBBB-CCCC-DDDD', }); expect(controller.state.contractExchangeRates).toStrictEqual({}); }); it('should not update exchange rates when network state changes without a ticker/chain id change', async () => { - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let networkStateChangeListener: (state: any) => Promise; + const getNetworkClientById = buildMockGetNetworkClientById({ + 'AAAA-BBBB-CCCC-DDDD': buildCustomNetworkClientConfiguration({ + chainId: toHex(1337), + ticker: 'TEST', + }), + }); + let networkStateChangeListener: (state: NetworkState) => Promise; const onNetworkStateChange = jest .fn() .mockImplementation((listener) => { @@ -876,7 +912,7 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, - getNetworkClientById: jest.fn(), + getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', selectedAddress: defaultSelectedAddress, @@ -892,7 +928,8 @@ describe('TokenRatesController', () => { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await networkStateChangeListener!({ - providerConfig: { chainId: toHex(1337), ticker: 'TEST' }, + ...defaultNetworkState, + selectedNetworkClientId: 'AAAA-BBBB-CCCC-DDDD', }); expect(updateExchangeRatesSpy).not.toHaveBeenCalled(); @@ -901,9 +938,13 @@ describe('TokenRatesController', () => { describe('when polling is inactive', () => { it('should not update exchange rates when ticker changes', async () => { - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let networkStateChangeListener: (state: any) => Promise; + const getNetworkClientById = buildMockGetNetworkClientById({ + 'AAAA-BBBB-CCCC-DDDD': buildCustomNetworkClientConfiguration({ + chainId: toHex(1337), + ticker: 'NEW', + }), + }); + let networkStateChangeListener: (state: NetworkState) => Promise; const onNetworkStateChange = jest .fn() .mockImplementation((listener) => { @@ -911,7 +952,7 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, - getNetworkClientById: jest.fn(), + getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', selectedAddress: defaultSelectedAddress, @@ -926,16 +967,21 @@ describe('TokenRatesController', () => { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await networkStateChangeListener!({ - providerConfig: { chainId: toHex(1337), ticker: 'NEW' }, + ...defaultNetworkState, + selectedNetworkClientId: 'AAAA-BBBB-CCCC-DDDD', }); expect(updateExchangeRatesSpy).not.toHaveBeenCalled(); }); it('should not update exchange rates when chain ID changes', async () => { - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let networkStateChangeListener: (state: any) => Promise; + const getNetworkClientById = buildMockGetNetworkClientById({ + 'AAAA-BBBB-CCCC-DDDD': buildCustomNetworkClientConfiguration({ + chainId: toHex(1338), + ticker: 'TEST', + }), + }); + let networkStateChangeListener: (state: NetworkState) => Promise; const onNetworkStateChange = jest .fn() .mockImplementation((listener) => { @@ -943,7 +989,7 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, - getNetworkClientById: jest.fn(), + getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', selectedAddress: defaultSelectedAddress, @@ -958,16 +1004,21 @@ describe('TokenRatesController', () => { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await networkStateChangeListener!({ - providerConfig: { chainId: toHex(1338), ticker: 'TEST' }, + ...defaultNetworkState, + selectedNetworkClientId: 'AAAA-BBBB-CCCC-DDDD', }); expect(updateExchangeRatesSpy).not.toHaveBeenCalled(); }); it('should clear contractExchangeRates state when ticker changes', async () => { - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let networkStateChangeListener: (state: any) => Promise; + const getNetworkClientById = buildMockGetNetworkClientById({ + 'AAAA-BBBB-CCCC-DDDD': buildCustomNetworkClientConfiguration({ + chainId: toHex(1337), + ticker: 'NEW', + }), + }); + let networkStateChangeListener: (state: NetworkState) => Promise; const onNetworkStateChange = jest .fn() .mockImplementation((listener) => { @@ -975,7 +1026,7 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, - getNetworkClientById: jest.fn(), + getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', selectedAddress: defaultSelectedAddress, @@ -988,16 +1039,21 @@ describe('TokenRatesController', () => { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await networkStateChangeListener!({ - providerConfig: { chainId: toHex(1337), ticker: 'NEW' }, + ...defaultNetworkState, + selectedNetworkClientId: 'AAAA-BBBB-CCCC-DDDD', }); expect(controller.state.contractExchangeRates).toStrictEqual({}); }); it('should clear contractExchangeRates state when chain ID changes', async () => { - // TODO: Replace `any` with type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let networkStateChangeListener: (state: any) => Promise; + const getNetworkClientById = buildMockGetNetworkClientById({ + 'AAAA-BBBB-CCCC-DDDD': buildCustomNetworkClientConfiguration({ + chainId: toHex(1338), + ticker: 'TEST', + }), + }); + let networkStateChangeListener: (state: NetworkState) => Promise; const onNetworkStateChange = jest .fn() .mockImplementation((listener) => { @@ -1005,7 +1061,7 @@ describe('TokenRatesController', () => { }); const controller = new TokenRatesController({ interval: 100, - getNetworkClientById: jest.fn(), + getNetworkClientById, chainId: toHex(1337), ticker: 'TEST', selectedAddress: defaultSelectedAddress, @@ -1018,7 +1074,8 @@ describe('TokenRatesController', () => { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion await networkStateChangeListener!({ - providerConfig: { chainId: toHex(1338), ticker: 'TEST' }, + ...defaultNetworkState, + selectedNetworkClientId: 'AAAA-BBBB-CCCC-DDDD', }); expect(controller.state.contractExchangeRates).toStrictEqual({}); @@ -1604,7 +1661,7 @@ describe('TokenRatesController', () => { await callUpdateExchangeRatesMethod({ allTokens: { - [toHex(1)]: { + [ChainId.mainnet]: { [controller.config.selectedAddress]: [ { address: tokenAddress, @@ -1615,11 +1672,12 @@ describe('TokenRatesController', () => { ], }, }, - chainId: toHex(1), + chainId: ChainId.mainnet, controller, controllerEvents, method, nativeCurrency: 'ETH', + selectedNetworkClientId: InfuraNetworkType.mainnet, }); expect(controller.state.contractExchangeRates).toStrictEqual({}); @@ -1638,7 +1696,7 @@ describe('TokenRatesController', () => { await callUpdateExchangeRatesMethod({ allTokens: { // These tokens are for the right chain but wrong account - [toHex(1)]: { + [ChainId.mainnet]: { [differentAccount]: [ { address: tokenAddress, @@ -1660,11 +1718,12 @@ describe('TokenRatesController', () => { ], }, }, - chainId: toHex(1), + chainId: ChainId.mainnet, controller, controllerEvents, method, nativeCurrency: 'ETH', + selectedNetworkClientId: InfuraNetworkType.mainnet, }); expect(controller.state.contractExchangeRates).toStrictEqual({}); @@ -1689,7 +1748,7 @@ describe('TokenRatesController', () => { async () => await callUpdateExchangeRatesMethod({ allTokens: { - [toHex(1)]: { + [ChainId.mainnet]: { [controller.config.selectedAddress]: [ { address: tokenAddress, @@ -1700,11 +1759,12 @@ describe('TokenRatesController', () => { ], }, }, - chainId: toHex(1), + chainId: ChainId.mainnet, controller, controllerEvents, method, nativeCurrency: 'ETH', + selectedNetworkClientId: InfuraNetworkType.mainnet, }), ).rejects.toThrow('Failed to fetch'); expect(controller.state.contractExchangeRates).toStrictEqual({}); @@ -1716,7 +1776,7 @@ describe('TokenRatesController', () => { }); it('fetches rates for all tokens in batches', async () => { - const chainId = toHex(1); + const chainId = ChainId.mainnet; const ticker = 'ETH'; const tokenAddresses = [...new Array(200).keys()] .map(buildAddress) @@ -1750,6 +1810,7 @@ describe('TokenRatesController', () => { controllerEvents, method, nativeCurrency: ticker, + selectedNetworkClientId: InfuraNetworkType.mainnet, }); const numBatches = Math.ceil( @@ -1795,7 +1856,7 @@ describe('TokenRatesController', () => { async ({ controller, controllerEvents }) => { await callUpdateExchangeRatesMethod({ allTokens: { - [toHex(1)]: { + [ChainId.mainnet]: { [controller.config.selectedAddress]: [ { address: tokenAddresses[0], @@ -1812,11 +1873,12 @@ describe('TokenRatesController', () => { ], }, }, - chainId: toHex(1), + chainId: ChainId.mainnet, controller, controllerEvents, method, nativeCurrency: 'ETH', + selectedNetworkClientId: InfuraNetworkType.mainnet, }); expect(controller.state).toMatchInlineSnapshot(` @@ -1908,6 +1970,12 @@ describe('TokenRatesController', () => { } it('updates exchange rates when native currency is not supported by the Price API', async () => { + const selectedNetworkClientId = 'AAAA-BBBB-CCCC-DDDD'; + const selectedNetworkClientConfiguration = + buildCustomNetworkClientConfiguration({ + chainId: toHex(137), + ticker: 'UNSUPPORTED', + }); const tokenAddresses = [ '0x0000000000000000000000000000000000000001', '0x0000000000000000000000000000000000000002', @@ -1935,16 +2003,23 @@ describe('TokenRatesController', () => { .get('/data/price') .query({ fsym: 'ETH', - tsyms: 'UNSUPPORTED', + tsyms: selectedNetworkClientConfiguration.ticker, }) - .reply(200, { UNSUPPORTED: 0.5 }); // .5 eth to 1 matic + .reply(200, { [selectedNetworkClientConfiguration.ticker]: 0.5 }); // .5 eth to 1 matic await withController( - { options: { tokenPricesService } }, + { + options: { + tokenPricesService, + }, + mockNetworkClientConfigurationsByNetworkClientId: { + [selectedNetworkClientId]: selectedNetworkClientConfiguration, + }, + }, async ({ controller, controllerEvents }) => { await callUpdateExchangeRatesMethod({ allTokens: { - [toHex(137)]: { + [selectedNetworkClientConfiguration.chainId]: { [controller.config.selectedAddress]: [ { address: tokenAddresses[0], @@ -1961,11 +2036,12 @@ describe('TokenRatesController', () => { ], }, }, - chainId: toHex(137), + chainId: selectedNetworkClientConfiguration.chainId, controller, controllerEvents, method, - nativeCurrency: 'UNSUPPORTED', + nativeCurrency: selectedNetworkClientConfiguration.ticker, + selectedNetworkClientId, }); // token value in terms of matic should be (token value in eth) * (eth value in matic) @@ -1990,15 +2066,19 @@ describe('TokenRatesController', () => { }); it('fetches rates for all tokens in batches when native currency is not supported by the Price API', async () => { - const chainId = toHex(1); - const ticker = 'UNSUPPORTED'; + const selectedNetworkClientId = 'AAAA-BBBB-CCCC-DDDD'; + const selectedNetworkClientConfiguration = + buildCustomNetworkClientConfiguration({ + chainId: toHex(999), + ticker: 'UNSUPPORTED', + }); const tokenAddresses = [...new Array(200).keys()] .map(buildAddress) .sort(); const tokenPricesService = buildMockTokenPricesService({ fetchTokenPrices: fetchTokenPricesWithIncreasingPriceForEachToken, validateCurrencySupported: (currency: unknown): currency is string => { - return currency !== ticker; + return currency !== selectedNetworkClientConfiguration.ticker; }, }); const fetchTokenPricesSpy = jest.spyOn( @@ -2012,28 +2092,31 @@ describe('TokenRatesController', () => { .get('/data/price') .query({ fsym: 'ETH', - tsyms: ticker, + tsyms: selectedNetworkClientConfiguration.ticker, }) - .reply(200, { [ticker]: 0.5 }); + .reply(200, { [selectedNetworkClientConfiguration.ticker]: 0.5 }); await withController( { options: { - ticker, tokenPricesService, }, + mockNetworkClientConfigurationsByNetworkClientId: { + [selectedNetworkClientId]: selectedNetworkClientConfiguration, + }, }, async ({ controller, controllerEvents }) => { await callUpdateExchangeRatesMethod({ allTokens: { - [chainId]: { + [selectedNetworkClientConfiguration.chainId]: { [controller.config.selectedAddress]: tokens, }, }, - chainId, + chainId: selectedNetworkClientConfiguration.chainId, controller, controllerEvents, method, - nativeCurrency: ticker, + nativeCurrency: selectedNetworkClientConfiguration.ticker, + selectedNetworkClientId, }); const numBatches = Math.ceil( @@ -2043,7 +2126,7 @@ describe('TokenRatesController', () => { for (let i = 1; i <= numBatches; i++) { expect(fetchTokenPricesSpy).toHaveBeenNthCalledWith(i, { - chainId, + chainId: selectedNetworkClientConfiguration.chainId, tokenAddresses: tokenAddresses.slice( (i - 1) * TOKEN_PRICES_BATCH_SIZE, i * TOKEN_PRICES_BATCH_SIZE, @@ -2056,6 +2139,12 @@ describe('TokenRatesController', () => { }); it('sets rates to undefined when chain is not supported by the Price API', async () => { + const selectedNetworkClientId = 'AAAA-BBBB-CCCC-DDDD'; + const selectedNetworkClientConfiguration = + buildCustomNetworkClientConfiguration({ + chainId: toHex(999), + ticker: 'TST', + }); const tokenAddresses = [ '0x0000000000000000000000000000000000000001', '0x0000000000000000000000000000000000000002', @@ -2080,11 +2169,18 @@ describe('TokenRatesController', () => { ) as unknown as AbstractTokenPricesService['validateChainIdSupported'], }); await withController( - { options: { tokenPricesService } }, + { + options: { + tokenPricesService, + }, + mockNetworkClientConfigurationsByNetworkClientId: { + [selectedNetworkClientId]: selectedNetworkClientConfiguration, + }, + }, async ({ controller, controllerEvents }) => { await callUpdateExchangeRatesMethod({ allTokens: { - [toHex(999)]: { + [selectedNetworkClientConfiguration.chainId]: { [controller.config.selectedAddress]: [ { address: tokenAddresses[0], @@ -2101,11 +2197,12 @@ describe('TokenRatesController', () => { ], }, }, - chainId: toHex(999), + chainId: selectedNetworkClientConfiguration.chainId, controller, controllerEvents, method, - nativeCurrency: 'TST', + nativeCurrency: selectedNetworkClientConfiguration.ticker, + selectedNetworkClientId, }); expect(controller.state).toMatchInlineSnapshot(` @@ -2171,7 +2268,8 @@ describe('TokenRatesController', () => { ], }, }, - chainId: toHex(1), + chainId: ChainId.mainnet, + selectedNetworkClientId: InfuraNetworkType.mainnet, controller, controllerEvents, method, @@ -2232,6 +2330,10 @@ type PartialConstructorParameters = { options?: Partial[0]>; config?: Partial; state?: Partial; + mockNetworkClientConfigurationsByNetworkClientId?: Record< + NetworkClientId, + NetworkClientConfiguration + >; }; type WithControllerArgs = @@ -2250,20 +2352,29 @@ type WithControllerArgs = async function withController( ...args: WithControllerArgs ) { - const [{ options, config, state }, testFunction] = - args.length === 2 - ? args - : [{ options: undefined, config: undefined, state: undefined }, args[0]]; + const [ + { + options = {}, + config = {}, + state = {}, + mockNetworkClientConfigurationsByNetworkClientId = {}, + }, + testFunction, + ] = args.length === 2 ? args : [{}, args[0]]; // explit cast used here because we know the `on____` functions are always // set in the constructor. const controllerEvents = {} as ControllerEvents; + const getNetworkClientById = buildMockGetNetworkClientById( + mockNetworkClientConfigurationsByNetworkClientId, + ); + const controllerOptions: ConstructorParameters< typeof TokenRatesController >[0] = { chainId: toHex(1), - getNetworkClientById: jest.fn(), + getNetworkClientById, onNetworkStateChange: (listener) => { controllerEvents.networkStateChange = listener; }, @@ -2315,6 +2426,8 @@ async function withController( * network we're getting updated exchange rates for. * @param args.setChainAsCurrent - When calling `updateExchangeRatesByChainId`, * this determines whether to set the chain as the globally selected chain. + * @param args.selectedNetworkClientId - The network client ID to use if + * `setChainAsCurrent` is true. */ async function callUpdateExchangeRatesMethod({ allTokens, @@ -2323,14 +2436,16 @@ async function callUpdateExchangeRatesMethod({ controllerEvents, method, nativeCurrency, + selectedNetworkClientId, setChainAsCurrent = true, }: { allTokens: TokenRatesConfig['allTokens']; - chainId: TokenRatesConfig['chainId']; + chainId: Hex; controller: TokenRatesController; controllerEvents: ControllerEvents; method: 'updateExchangeRates' | 'updateExchangeRatesByChainId'; nativeCurrency: TokenRatesConfig['nativeCurrency']; + selectedNetworkClientId?: NetworkClientId; setChainAsCurrent?: boolean; }) { if (method === 'updateExchangeRates' && !setChainAsCurrent) { @@ -2346,17 +2461,22 @@ async function callUpdateExchangeRatesMethod({ controllerEvents.tokensStateChange({ allDetectedTokens: {}, allTokens }); if (setChainAsCurrent) { + assert( + selectedNetworkClientId, + 'The "selectedNetworkClientId" option must be given if the "setChainAsCurrent" flag is also given', + ); + // We're using controller events here instead of calling `configure` // because `configure` does not update internal controller state correctly. // As with many BaseControllerV1-based controllers, runtime config // modification is allowed by the API but not supported in practice. + // + // @ts-expect-error Note that the state given here is intentionally + // incomplete because the controller only uses this one property, and the + // tests are written to only consider it. We want this to break if we start + // relying on more properties, as we'd need to update the tests accordingly. controllerEvents.networkStateChange({ - // Note that the state given here is intentionally incomplete because the - // controller only uses these two properties, and the tests are written to - // only consider these two. We want this to break if we start relying on - // more, as we'd need to update the tests accordingly. - // @ts-expect-error Intentionally incomplete state - providerConfig: { chainId, ticker: nativeCurrency }, + selectedNetworkClientId, }); } diff --git a/packages/assets-controllers/src/TokenRatesController.ts b/packages/assets-controllers/src/TokenRatesController.ts index 099e500004..a564e8b5a0 100644 --- a/packages/assets-controllers/src/TokenRatesController.ts +++ b/packages/assets-controllers/src/TokenRatesController.ts @@ -254,8 +254,12 @@ export class TokenRatesController extends StaticIntervalPollingControllerV1< } }); - onNetworkStateChange(async ({ providerConfig }) => { - const { chainId, ticker } = providerConfig; + onNetworkStateChange(async ({ selectedNetworkClientId }) => { + const selectedNetworkClient = getNetworkClientById( + selectedNetworkClientId, + ); + const { chainId, ticker } = selectedNetworkClient.configuration; + if ( this.config.chainId !== chainId || this.config.nativeCurrency !== ticker