From ae1b28aab6caaafde0a5159afaa94690ecdf2d01 Mon Sep 17 00:00:00 2001 From: salimtb Date: Tue, 22 Apr 2025 11:36:17 +0200 Subject: [PATCH 1/3] feat: remove current chainId dependency from asset contract controller --- .../src/AssetsContractController.test.ts | 144 +++++++++--------- .../src/AssetsContractController.ts | 112 +++++--------- 2 files changed, 109 insertions(+), 147 deletions(-) diff --git a/packages/assets-controllers/src/AssetsContractController.test.ts b/packages/assets-controllers/src/AssetsContractController.test.ts index c51221372e4..7944646c9cf 100644 --- a/packages/assets-controllers/src/AssetsContractController.test.ts +++ b/packages/assets-controllers/src/AssetsContractController.test.ts @@ -34,7 +34,6 @@ import { AssetsContractController, MISSING_PROVIDER_ERROR, } from './AssetsContractController'; -import { SupportedTokenDetectionNetworks } from './assetsUtil'; const ERC20_UNI_ADDRESS = '0x1f9840a85d5af5bf1d1762f925bdaddc4201f984'; const ERC20_SAI_ADDRESS = '0x89d24a6b4ccb1b6faa2625fe562bdd9a23260359'; @@ -192,10 +191,8 @@ describe('AssetsContractController', () => { it('should set default config', async () => { const { assetsContract, messenger } = await setupAssetContractControllers(); expect({ - chainId: assetsContract.chainId, ipfsGateway: assetsContract.ipfsGateway, }).toStrictEqual({ - chainId: SupportedTokenDetectionNetworks.mainnet, ipfsGateway: IPFS_DEFAULT_GATEWAY_URL, }); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); @@ -205,10 +202,8 @@ describe('AssetsContractController', () => { const { assetsContract, messenger, triggerPreferencesStateChange } = await setupAssetContractControllers(); expect({ - chainId: assetsContract.chainId, ipfsGateway: assetsContract.ipfsGateway, }).toStrictEqual({ - chainId: SupportedTokenDetectionNetworks.mainnet, ipfsGateway: IPFS_DEFAULT_GATEWAY_URL, }); @@ -218,45 +213,42 @@ describe('AssetsContractController', () => { }); expect({ - chainId: assetsContract.chainId, ipfsGateway: assetsContract.ipfsGateway, }).toStrictEqual({ ipfsGateway: 'newIPFSGateWay', - chainId: SupportedTokenDetectionNetworks.mainnet, }); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); }); it('should throw missing provider error when getting ERC-20 token balance when missing provider', async () => { - const { assetsContract, messenger } = await setupAssetContractControllers(); - assetsContract.setProvider(undefined); + const { messenger } = await setupAssetContractControllers(); await expect( messenger.call( `AssetsContractController:getERC20BalanceOf`, ERC20_UNI_ADDRESS, TEST_ACCOUNT_PUBLIC_ADDRESS, + undefined as unknown as NetworkClientId, ), ).rejects.toThrow(MISSING_PROVIDER_ERROR); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); }); it('should throw missing provider error when getting ERC-20 token decimal when missing provider', async () => { - const { assetsContract, messenger } = await setupAssetContractControllers(); - assetsContract.setProvider(undefined); + const { messenger } = await setupAssetContractControllers(); await expect( messenger.call( `AssetsContractController:getERC20TokenDecimals`, ERC20_UNI_ADDRESS, + undefined as unknown as NetworkClientId, ), ).rejects.toThrow(MISSING_PROVIDER_ERROR); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); }); it('should get balance of ERC-20 token contract correctly', async () => { - const { assetsContract, messenger, provider, networkClientConfiguration } = + const { messenger, networkClientConfiguration } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); mockNetworkWithDefaultChainId({ networkClientConfiguration, mocks: [ @@ -298,11 +290,13 @@ describe('AssetsContractController', () => { `AssetsContractController:getERC20BalanceOf`, ERC20_UNI_ADDRESS, TEST_ACCOUNT_PUBLIC_ADDRESS, + 'mainnet', ); const UNINoBalance = await messenger.call( `AssetsContractController:getERC20BalanceOf`, ERC20_UNI_ADDRESS, '0x202637dAAEfbd7f131f90338a4A6c69F6Cd5CE91', + 'mainnet', ); expect(UNIBalance.toString(16)).not.toBe('0'); expect(UNINoBalance.toString(16)).toBe('0'); @@ -310,9 +304,8 @@ describe('AssetsContractController', () => { }); it('should get ERC-721 NFT tokenId correctly', async () => { - const { assetsContract, messenger, provider, networkClientConfiguration } = + const { messenger, networkClientConfiguration } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); mockNetworkWithDefaultChainId({ networkClientConfiguration, mocks: [ @@ -339,33 +332,33 @@ describe('AssetsContractController', () => { ERC721_GODS_ADDRESS, '0x9a90bd8d1149a88b42a99cf62215ad955d6f498a', 0, + 'mainnet', ); expect(tokenId).not.toBe(0); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); }); it('should throw missing provider error when getting ERC-721 token standard and details when missing provider', async () => { - const { assetsContract, messenger } = await setupAssetContractControllers(); - assetsContract.setProvider(undefined); + const { messenger } = await setupAssetContractControllers(); await expect( messenger.call( `AssetsContractController:getTokenStandardAndDetails`, ERC20_UNI_ADDRESS, TEST_ACCOUNT_PUBLIC_ADDRESS, + undefined as unknown as NetworkClientId, ), ).rejects.toThrow(MISSING_PROVIDER_ERROR); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); }); it('should throw contract standard error when getting ERC-20 token standard and details when provided with invalid ERC-20 address', async () => { - const { assetsContract, messenger, provider } = - await setupAssetContractControllers(); - assetsContract.setProvider(provider); + const { messenger } = await setupAssetContractControllers(); const error = 'Unable to determine contract standard'; await expect( messenger.call( `AssetsContractController:getTokenStandardAndDetails`, 'BaDeRc20AdDrEsS', + 'mainnet', TEST_ACCOUNT_PUBLIC_ADDRESS, ), ).rejects.toThrow(error); @@ -373,9 +366,8 @@ describe('AssetsContractController', () => { }); it('should get ERC-721 token standard and details', async () => { - const { assetsContract, messenger, provider, networkClientConfiguration } = + const { messenger, networkClientConfiguration } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); mockNetworkWithDefaultChainId({ networkClientConfiguration, mocks: [ @@ -432,6 +424,7 @@ describe('AssetsContractController', () => { const standardAndDetails = await messenger.call( `AssetsContractController:getTokenStandardAndDetails`, ERC721_GODS_ADDRESS, + 'mainnet', TEST_ACCOUNT_PUBLIC_ADDRESS, ); expect(standardAndDetails.standard).toBe('ERC721'); @@ -439,9 +432,8 @@ describe('AssetsContractController', () => { }); it('should get ERC-1155 token standard and details', async () => { - const { assetsContract, messenger, provider, networkClientConfiguration } = + const { messenger, networkClientConfiguration } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); mockNetworkWithDefaultChainId({ networkClientConfiguration, mocks: [ @@ -514,6 +506,7 @@ describe('AssetsContractController', () => { const standardAndDetails = await messenger.call( `AssetsContractController:getTokenStandardAndDetails`, ERC1155_ADDRESS, + 'mainnet', TEST_ACCOUNT_PUBLIC_ADDRESS, ); @@ -525,9 +518,8 @@ describe('AssetsContractController', () => { }); it('should get ERC-20 token standard and details', async () => { - const { assetsContract, messenger, provider, networkClientConfiguration } = + const { messenger, networkClientConfiguration } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); mockNetworkWithDefaultChainId({ networkClientConfiguration, mocks: [ @@ -616,6 +608,7 @@ describe('AssetsContractController', () => { const standardAndDetails = await messenger.call( `AssetsContractController:getTokenStandardAndDetails`, ERC20_UNI_ADDRESS, + 'mainnet', TEST_ACCOUNT_PUBLIC_ADDRESS, ); expect(standardAndDetails.standard).toBe('ERC20'); @@ -623,9 +616,8 @@ describe('AssetsContractController', () => { }); it('should get ERC-721 NFT tokenURI correctly', async () => { - const { assetsContract, messenger, provider, networkClientConfiguration } = + const { messenger, networkClientConfiguration } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); mockNetworkWithDefaultChainId({ networkClientConfiguration, mocks: [ @@ -667,15 +659,15 @@ describe('AssetsContractController', () => { `AssetsContractController:getERC721TokenURI`, ERC721_GODS_ADDRESS, '0', + 'mainnet', ); expect(tokenId).toBe('https://api.godsunchained.com/card/0'); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); }); it('should not throw an error when address given does not support NFT Metadata interface', async () => { - const { assetsContract, messenger, provider, networkClientConfiguration } = + const { messenger, networkClientConfiguration } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); const errorLogSpy = jest .spyOn(console, 'error') .mockImplementationOnce(() => { @@ -721,6 +713,7 @@ describe('AssetsContractController', () => { `AssetsContractController:getERC721TokenURI`, '0x0000000000000000000000000000000000000000', '0', + 'mainnet', ); expect(uri).toBe('https://api.godsunchained.com/card/0'); expect(errorLogSpy).toHaveBeenCalledTimes(1); @@ -732,9 +725,8 @@ describe('AssetsContractController', () => { }); it('should get ERC-721 NFT name', async () => { - const { assetsContract, messenger, provider, networkClientConfiguration } = + const { messenger, networkClientConfiguration } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); mockNetworkWithDefaultChainId({ networkClientConfiguration, mocks: [ @@ -759,15 +751,15 @@ describe('AssetsContractController', () => { const name = await messenger.call( `AssetsContractController:getERC721AssetName`, ERC721_GODS_ADDRESS, + 'mainnet', ); expect(name).toBe('Gods Unchained'); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); }); it('should get ERC-721 NFT symbol', async () => { - const { assetsContract, messenger, provider, networkClientConfiguration } = + const { messenger, networkClientConfiguration } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); mockNetworkWithDefaultChainId({ networkClientConfiguration, mocks: [ @@ -792,6 +784,7 @@ describe('AssetsContractController', () => { const symbol = await messenger.call( `AssetsContractController:getERC721AssetSymbol`, ERC721_GODS_ADDRESS, + 'mainnet', ); expect(symbol).toBe('GODS'); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); @@ -803,15 +796,15 @@ describe('AssetsContractController', () => { messenger.call( `AssetsContractController:getERC721AssetSymbol`, ERC721_GODS_ADDRESS, + undefined as unknown as string, ), ).rejects.toThrow(MISSING_PROVIDER_ERROR); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); }); it('should get ERC-20 token decimals', async () => { - const { assetsContract, messenger, provider, networkClientConfiguration } = + const { messenger, networkClientConfiguration } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); mockNetworkWithDefaultChainId({ networkClientConfiguration, mocks: [ @@ -836,15 +829,15 @@ describe('AssetsContractController', () => { const decimals = await messenger.call( `AssetsContractController:getERC20TokenDecimals`, ERC20_SAI_ADDRESS, + 'mainnet', ); expect(Number(decimals)).toBe(18); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); }); it('should get ERC-20 token name', async () => { - const { assetsContract, messenger, provider, networkClientConfiguration } = + const { messenger, networkClientConfiguration } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); mockNetworkWithDefaultChainId({ networkClientConfiguration, mocks: [ @@ -870,6 +863,7 @@ describe('AssetsContractController', () => { const name = await messenger.call( `AssetsContractController:getERC20TokenName`, ERC20_DAI_ADDRESS, + 'mainnet', ); expect(name).toBe('Dai Stablecoin'); @@ -877,9 +871,8 @@ describe('AssetsContractController', () => { }); it('should get ERC-721 NFT ownership', async () => { - const { assetsContract, messenger, provider, networkClientConfiguration } = + const { messenger, networkClientConfiguration } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); mockNetworkWithDefaultChainId({ networkClientConfiguration, mocks: [ @@ -905,6 +898,7 @@ describe('AssetsContractController', () => { `AssetsContractController:getERC721OwnerOf`, ERC721_GODS_ADDRESS, '148332', + 'mainnet', ); expect(tokenId).not.toBe(''); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); @@ -917,15 +911,15 @@ describe('AssetsContractController', () => { `AssetsContractController:getERC721OwnerOf`, ERC721_GODS_ADDRESS, '148332', + undefined as unknown as string, ), ).rejects.toThrow(MISSING_PROVIDER_ERROR); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); }); it('should get balance of ERC-20 token in a single call on network with token detection support', async () => { - const { assetsContract, messenger, provider, networkClientConfiguration } = + const { messenger, networkClientConfiguration } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); mockNetworkWithDefaultChainId({ networkClientConfiguration, mocks: [ @@ -951,6 +945,7 @@ describe('AssetsContractController', () => { `AssetsContractController:getBalancesInSingleCall`, ERC20_SAI_ADDRESS, [ERC20_SAI_ADDRESS], + 'mainnet', ); expect(balances[ERC20_SAI_ADDRESS]).toBeDefined(); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); @@ -1024,20 +1019,19 @@ describe('AssetsContractController', () => { }, ], }); - const { assetsContract, messenger, provider } = - await setupAssetContractControllers({ - options: { - chainId: ChainId.mainnet, - }, - useNetworkControllerProvider: true, - infuraProjectId, - }); - assetsContract.setProvider(provider); + const { messenger } = await setupAssetContractControllers({ + options: { + chainId: ChainId.mainnet, + }, + useNetworkControllerProvider: true, + infuraProjectId, + }); const balancesOnMainnet = await messenger.call( 'AssetsContractController:getBalancesInSingleCall', ERC20_SAI_ADDRESS, [ERC20_SAI_ADDRESS], + 'mainnet', ); expect(balancesOnMainnet).toStrictEqual({ [ERC20_SAI_ADDRESS]: BigNumber.from('0x0733ed8ef4c4a0155d09'), @@ -1052,6 +1046,7 @@ describe('AssetsContractController', () => { 'AssetsContractController:getBalancesInSingleCall', ERC20_SAI_ADDRESS, [ERC20_SAI_ADDRESS], + 'linea-mainnet', ); expect(balancesOnLineaMainnet).toStrictEqual({ [ERC20_SAI_ADDRESS]: BigNumber.from('0xa0155d09733ed8ef4c4'), @@ -1060,14 +1055,8 @@ describe('AssetsContractController', () => { }); it('should not have balance in a single call after switching to network without token detection support', async () => { - const { - assetsContract, - messenger, - network, - provider, - networkClientConfiguration, - } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); + const { messenger, network, networkClientConfiguration } = + await setupAssetContractControllers(); mockNetworkWithDefaultChainId({ networkClientConfiguration, mocks: [ @@ -1125,6 +1114,7 @@ describe('AssetsContractController', () => { `AssetsContractController:getBalancesInSingleCall`, ERC20_SAI_ADDRESS, [ERC20_SAI_ADDRESS], + 'mainnet', ); expect(balances[ERC20_SAI_ADDRESS]).toBeDefined(); @@ -1134,14 +1124,14 @@ describe('AssetsContractController', () => { `AssetsContractController:getBalancesInSingleCall`, ERC20_SAI_ADDRESS, [ERC20_SAI_ADDRESS], + 'sepolia', ); expect(noBalances).toStrictEqual({}); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); }); it('should throw missing provider error when transferring single ERC-1155 when missing provider', async () => { - const { assetsContract, messenger } = await setupAssetContractControllers(); - assetsContract.setProvider(undefined); + const { messenger } = await setupAssetContractControllers(); await expect( messenger.call( `AssetsContractController:transferSingleERC1155`, @@ -1150,15 +1140,15 @@ describe('AssetsContractController', () => { TEST_ACCOUNT_PUBLIC_ADDRESS, ERC1155_ID, '1', + undefined as unknown as string, ), ).rejects.toThrow(MISSING_PROVIDER_ERROR); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); }); it('should throw when ERC1155 function transferSingle is not defined', async () => { - const { assetsContract, messenger, provider, networkClientConfiguration } = + const { messenger, networkClientConfiguration } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); mockNetworkWithDefaultChainId({ networkClientConfiguration, mocks: [ @@ -1188,15 +1178,15 @@ describe('AssetsContractController', () => { TEST_ACCOUNT_PUBLIC_ADDRESS, ERC1155_ID, '1', + 'sepolia', ), ).rejects.toThrow('contract.transferSingle is not a function'); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); }); it('should get the balance of a ERC-1155 NFT for a given address', async () => { - const { assetsContract, messenger, provider, networkClientConfiguration } = + const { messenger, networkClientConfiguration } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); mockNetworkWithDefaultChainId({ networkClientConfiguration, mocks: [ @@ -1223,6 +1213,7 @@ describe('AssetsContractController', () => { TEST_ACCOUNT_PUBLIC_ADDRESS, ERC1155_ADDRESS, ERC1155_ID, + 'sepolia', ); expect(Number(balance)).toBeGreaterThan(0); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); @@ -1236,15 +1227,15 @@ describe('AssetsContractController', () => { TEST_ACCOUNT_PUBLIC_ADDRESS, ERC1155_ADDRESS, ERC1155_ID, + undefined as unknown as string, ), ).rejects.toThrow(MISSING_PROVIDER_ERROR); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); }); it('should get the URI of a ERC-1155 NFT', async () => { - const { assetsContract, messenger, provider, networkClientConfiguration } = + const { messenger, networkClientConfiguration } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); mockNetworkWithDefaultChainId({ networkClientConfiguration, mocks: [ @@ -1271,15 +1262,15 @@ describe('AssetsContractController', () => { `AssetsContractController:getERC1155TokenURI`, ERC1155_ADDRESS, ERC1155_ID, + 'mainnet', ); expect(uri.toLowerCase()).toStrictEqual(expectedUri); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); }); it('should get the staked ethereum balance for an address', async () => { - const { assetsContract, messenger, provider, networkClientConfiguration } = + const { assetsContract, messenger, networkClientConfiguration } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); mockNetworkWithDefaultChainId({ networkClientConfiguration, @@ -1323,6 +1314,7 @@ describe('AssetsContractController', () => { const balance = await assetsContract.getStakedBalanceForChain( TEST_ACCOUNT_PUBLIC_ADDRESS, + 'mainnet', ); // exchange rate shares = 1e18 @@ -1339,9 +1331,8 @@ describe('AssetsContractController', () => { it('should return default of zero hex as staked ethereum balance if user has no shares', async () => { const errorSpy = jest.spyOn(console, 'error'); - const { assetsContract, messenger, provider, networkClientConfiguration } = + const { assetsContract, messenger, networkClientConfiguration } = await setupAssetContractControllers(); - assetsContract.setProvider(provider); mockNetworkWithDefaultChainId({ networkClientConfiguration, @@ -1368,6 +1359,7 @@ describe('AssetsContractController', () => { const balance = await assetsContract.getStakedBalanceForChain( TEST_ACCOUNT_PUBLIC_ADDRESS, + 'mainnet', ); expect(balance).toBeDefined(); @@ -1386,12 +1378,11 @@ describe('AssetsContractController', () => { .mockImplementationOnce((e) => { error = e; }); - const { assetsContract, messenger, provider } = - await setupAssetContractControllers(); - assetsContract.setProvider(provider); + const { assetsContract, messenger } = await setupAssetContractControllers(); const balance = await assetsContract.getStakedBalanceForChain( TEST_ACCOUNT_PUBLIC_ADDRESS, + 'mainnet', ); expect(balance).toBeDefined(); @@ -1407,7 +1398,10 @@ describe('AssetsContractController', () => { it('should throw missing provider error when getting staked ethereum balance and missing provider', async () => { const { assetsContract, messenger } = await setupAssetContractControllers(); await expect( - assetsContract.getStakedBalanceForChain(TEST_ACCOUNT_PUBLIC_ADDRESS), + assetsContract.getStakedBalanceForChain( + TEST_ACCOUNT_PUBLIC_ADDRESS, + undefined as unknown as string, + ), ).rejects.toThrow(MISSING_PROVIDER_ERROR); messenger.clearEventSubscriptions('NetworkController:networkDidChange'); }); diff --git a/packages/assets-controllers/src/AssetsContractController.ts b/packages/assets-controllers/src/AssetsContractController.ts index 5e8a9398d67..4b396cf4505 100644 --- a/packages/assets-controllers/src/AssetsContractController.ts +++ b/packages/assets-controllers/src/AssetsContractController.ts @@ -14,7 +14,6 @@ import type { NetworkControllerGetSelectedNetworkClientAction, NetworkControllerGetStateAction, NetworkControllerNetworkDidChangeEvent, - Provider, } from '@metamask/network-controller'; import type { PreferencesControllerStateChangeEvent } from '@metamask/preferences-controller'; import { getKnownPropertyNames, type Hex } from '@metamask/utils'; @@ -219,11 +218,9 @@ export class AssetsContractController { protected messagingSystem: AssetsContractControllerMessenger; - #provider: Provider | undefined; - #ipfsGateway: string; - #chainId: Hex; + // #chainId: Hex; /** * Creates a AssetsContractController instance. @@ -234,15 +231,14 @@ export class AssetsContractController { */ constructor({ messenger, - chainId: initialChainId, + // chainId: initialChainId, }: { messenger: AssetsContractControllerMessenger; chainId: Hex; }) { this.messagingSystem = messenger; - this.#provider = undefined; this.#ipfsGateway = IPFS_DEFAULT_GATEWAY_URL; - this.#chainId = initialChainId; + // this.#chainId = initialChainId; this.#registerActionHandlers(); this.#registerEventSubscriptions(); @@ -284,37 +280,15 @@ export class AssetsContractController { this.#ipfsGateway = ipfsGateway; }, ); - - this.messagingSystem.subscribe( - `NetworkController:networkDidChange`, - ({ selectedNetworkClientId }) => { - const chainId = this.#getCorrectChainId(selectedNetworkClientId); - - if (this.#chainId !== chainId) { - this.#chainId = chainId; - // @ts-expect-error TODO: remove this annotation once the `Eip1193Provider` class is released - this.#provider = this.#getCorrectProvider(); - } - }, - ); - } - - /** - * Sets a new provider. - * - * @param provider - Provider used to create a new underlying Web3 instance - */ - setProvider(provider: Provider | undefined) { - this.#provider = provider; } get ipfsGateway() { return this.#ipfsGateway; } - get chainId() { - return this.#chainId; - } + // get chainId() { + // return this.#chainId; + // } /** * Get the relevant provider instance. @@ -322,20 +296,16 @@ export class AssetsContractController { * @param networkClientId - Network Client ID. * @returns Web3Provider instance. */ - #getCorrectProvider(networkClientId?: NetworkClientId): Web3Provider { - const provider = networkClientId - ? this.messagingSystem.call( - `NetworkController:getNetworkClientById`, - networkClientId, - ).provider - : (this.messagingSystem.call('NetworkController:getSelectedNetworkClient') - ?.provider ?? this.#provider); - - if (provider === undefined) { + #getCorrectProvider(networkClientId: NetworkClientId): Web3Provider { + try { + const { provider } = this.messagingSystem.call( + `NetworkController:getNetworkClientById`, + networkClientId, + ); + return new Web3Provider(provider); + } catch (err) { throw new Error(MISSING_PROVIDER_ERROR); } - - return new Web3Provider(provider); } /** @@ -344,15 +314,13 @@ export class AssetsContractController { * @param networkClientId - Network Client ID used to get the provider. * @returns Hex chain ID. */ - #getCorrectChainId(networkClientId?: NetworkClientId): Hex { - if (networkClientId) { - const networkClientConfiguration = this.messagingSystem.call( - 'NetworkController:getNetworkConfigurationByNetworkClientId', - networkClientId, - ); - if (networkClientConfiguration) { - return networkClientConfiguration.chainId; - } + #getCorrectChainId(networkClientId: NetworkClientId): Hex { + const networkClientConfiguration = this.messagingSystem.call( + 'NetworkController:getNetworkConfigurationByNetworkClientId', + networkClientId, + ); + if (networkClientConfiguration) { + return networkClientConfiguration.chainId; } const { selectedNetworkClientId } = this.messagingSystem.call( 'NetworkController:getState', @@ -361,7 +329,7 @@ export class AssetsContractController { 'NetworkController:getNetworkClientById', selectedNetworkClientId, ); - return networkClient.configuration?.chainId ?? this.#chainId; + return networkClient.configuration.chainId; } /** @@ -370,7 +338,7 @@ export class AssetsContractController { * @param networkClientId - Network Client ID used to get the provider. * @returns ERC20Standard instance. */ - getERC20Standard(networkClientId?: NetworkClientId): ERC20Standard { + getERC20Standard(networkClientId: NetworkClientId): ERC20Standard { const provider = this.#getCorrectProvider(networkClientId); return new ERC20Standard(provider); } @@ -381,7 +349,7 @@ export class AssetsContractController { * @param networkClientId - Network Client ID used to get the provider. * @returns ERC721Standard instance. */ - getERC721Standard(networkClientId?: NetworkClientId): ERC721Standard { + getERC721Standard(networkClientId: NetworkClientId): ERC721Standard { const provider = this.#getCorrectProvider(networkClientId); return new ERC721Standard(provider); } @@ -392,7 +360,7 @@ export class AssetsContractController { * @param networkClientId - Network Client ID used to get the provider. * @returns ERC1155Standard instance. */ - getERC1155Standard(networkClientId?: NetworkClientId): ERC1155Standard { + getERC1155Standard(networkClientId: NetworkClientId): ERC1155Standard { const provider = this.#getCorrectProvider(networkClientId); return new ERC1155Standard(provider); } @@ -408,7 +376,7 @@ export class AssetsContractController { async getERC20BalanceOf( address: string, selectedAddress: string, - networkClientId?: NetworkClientId, + networkClientId: NetworkClientId, ): Promise { const erc20Standard = this.getERC20Standard(networkClientId); return erc20Standard.getBalanceOf(address, selectedAddress); @@ -423,7 +391,7 @@ export class AssetsContractController { */ async getERC20TokenDecimals( address: string, - networkClientId?: NetworkClientId, + networkClientId: NetworkClientId, ): Promise { const erc20Standard = this.getERC20Standard(networkClientId); return erc20Standard.getTokenDecimals(address); @@ -438,7 +406,7 @@ export class AssetsContractController { */ async getERC20TokenName( address: string, - networkClientId?: NetworkClientId, + networkClientId: NetworkClientId, ): Promise { const erc20Standard = this.getERC20Standard(networkClientId); return erc20Standard.getTokenName(address); @@ -457,7 +425,7 @@ export class AssetsContractController { address: string, selectedAddress: string, index: number, - networkClientId?: NetworkClientId, + networkClientId: NetworkClientId, ): Promise { const erc721Standard = this.getERC721Standard(networkClientId); return erc721Standard.getNftTokenId(address, selectedAddress, index); @@ -467,16 +435,16 @@ export class AssetsContractController { * Enumerate assets assigned to an owner. * * @param tokenAddress - ERC721 asset contract address. + * @param networkClientId - Network Client ID to fetch the provider with. * @param userAddress - Current account public address. * @param tokenId - ERC721 asset identifier. - * @param networkClientId - Network Client ID to fetch the provider with. * @returns Promise resolving to an object containing the token standard and a set of details which depend on which standard the token supports. */ async getTokenStandardAndDetails( tokenAddress: string, + networkClientId: NetworkClientId, userAddress?: string, tokenId?: string, - networkClientId?: NetworkClientId, ): Promise<{ standard: string; tokenURI?: string | undefined; @@ -540,7 +508,7 @@ export class AssetsContractController { async getERC721TokenURI( address: string, tokenId: string, - networkClientId?: NetworkClientId, + networkClientId: NetworkClientId, ): Promise { const erc721Standard = this.getERC721Standard(networkClientId); return erc721Standard.getTokenURI(address, tokenId); @@ -555,7 +523,7 @@ export class AssetsContractController { */ async getERC721AssetName( address: string, - networkClientId?: NetworkClientId, + networkClientId: NetworkClientId, ): Promise { const erc721Standard = this.getERC721Standard(networkClientId); return erc721Standard.getAssetName(address); @@ -570,7 +538,7 @@ export class AssetsContractController { */ async getERC721AssetSymbol( address: string, - networkClientId?: NetworkClientId, + networkClientId: NetworkClientId, ): Promise { const erc721Standard = this.getERC721Standard(networkClientId); return erc721Standard.getAssetSymbol(address); @@ -587,7 +555,7 @@ export class AssetsContractController { async getERC721OwnerOf( address: string, tokenId: string, - networkClientId?: NetworkClientId, + networkClientId: NetworkClientId, ): Promise { const erc721Standard = this.getERC721Standard(networkClientId); return erc721Standard.getOwnerOf(address, tokenId); @@ -604,7 +572,7 @@ export class AssetsContractController { async getERC1155TokenURI( address: string, tokenId: string, - networkClientId?: NetworkClientId, + networkClientId: NetworkClientId, ): Promise { const erc1155Standard = this.getERC1155Standard(networkClientId); return erc1155Standard.getTokenURI(address, tokenId); @@ -623,7 +591,7 @@ export class AssetsContractController { userAddress: string, nftAddress: string, nftId: string, - networkClientId?: NetworkClientId, + networkClientId: NetworkClientId, ): Promise { const erc1155Standard = this.getERC1155Standard(networkClientId); return erc1155Standard.getBalanceOf(nftAddress, userAddress, nftId); @@ -646,7 +614,7 @@ export class AssetsContractController { recipientAddress: string, nftId: string, qty: string, - networkClientId?: NetworkClientId, + networkClientId: NetworkClientId, ): Promise { const erc1155Standard = this.getERC1155Standard(networkClientId); return erc1155Standard.transferSingle( @@ -670,7 +638,7 @@ export class AssetsContractController { async getBalancesInSingleCall( selectedAddress: string, tokensToDetect: string[], - networkClientId?: NetworkClientId, + networkClientId: NetworkClientId, ) { const chainId = this.#getCorrectChainId(networkClientId); const provider = this.#getCorrectProvider(networkClientId); @@ -712,7 +680,7 @@ export class AssetsContractController { */ async getStakedBalanceForChain( address: string, - networkClientId?: NetworkClientId, + networkClientId: NetworkClientId, ): Promise { const chainId = this.#getCorrectChainId(networkClientId); const provider = this.#getCorrectProvider(networkClientId); From 501834463915f37d93e9120ff68dc28fb6350c06 Mon Sep 17 00:00:00 2001 From: salimtb Date: Sun, 20 Apr 2025 22:32:34 +0200 Subject: [PATCH 2/3] fix: parallelize fetch native balances --- .../src/AccountTrackerController.test.ts | 50 ++++---- .../src/AccountTrackerController.ts | 118 ++++++++++++------ 2 files changed, 106 insertions(+), 62 deletions(-) diff --git a/packages/assets-controllers/src/AccountTrackerController.test.ts b/packages/assets-controllers/src/AccountTrackerController.test.ts index 59c40417a02..e29b537cd89 100644 --- a/packages/assets-controllers/src/AccountTrackerController.test.ts +++ b/packages/assets-controllers/src/AccountTrackerController.test.ts @@ -127,7 +127,7 @@ describe('AccountTrackerController', () => { listAccounts: [mockAccount1, mockAccount2], }, async ({ controller }) => { - await controller.refresh(); + await controller.refresh(['mainnet']); expect(controller.state).toStrictEqual({ accountsByChainId: { '0x1': { @@ -154,7 +154,7 @@ describe('AccountTrackerController', () => { listAccounts: [ACCOUNT_1], }, async ({ controller }) => { - await controller.refresh(); + await controller.refresh(['mainnet']); expect(controller.state).toStrictEqual({ accountsByChainId: { @@ -181,7 +181,7 @@ describe('AccountTrackerController', () => { listAccounts: [ACCOUNT_1, ACCOUNT_2], }, async ({ controller }) => { - await controller.refresh(); + await controller.refresh(['mainnet']); expect(controller.state).toStrictEqual({ accountsByChainId: { @@ -207,7 +207,7 @@ describe('AccountTrackerController', () => { listAccounts: [ACCOUNT_1, ACCOUNT_2], }, async ({ controller }) => { - await controller.refresh(); + await controller.refresh(['mainnet']); expect(controller.state).toStrictEqual({ accountsByChainId: { @@ -237,7 +237,7 @@ describe('AccountTrackerController', () => { listAccounts: [ACCOUNT_1, ACCOUNT_2], }, async ({ controller }) => { - await controller.refresh(); + await controller.refresh(['mainnet']); expect(controller.state).toStrictEqual({ accountsByChainId: { @@ -272,7 +272,7 @@ describe('AccountTrackerController', () => { listAccounts: [ACCOUNT_1, ACCOUNT_2], }, async ({ controller }) => { - await controller.refresh(); + await controller.refresh(['mainnet']); expect(controller.state).toStrictEqual({ accountsByChainId: { @@ -306,7 +306,7 @@ describe('AccountTrackerController', () => { listAccounts: [ACCOUNT_1, ACCOUNT_2], }, async ({ controller }) => { - await controller.refresh(); + await controller.refresh(['mainnet']); expect(controller.state).toStrictEqual({ accountsByChainId: { @@ -366,7 +366,7 @@ describe('AccountTrackerController', () => { }, }, async ({ controller }) => { - await controller.refresh(networkClientId); + await controller.refresh(['networkClientId1']); expect(controller.state).toStrictEqual({ accountsByChainId: { '0x1': { @@ -403,7 +403,7 @@ describe('AccountTrackerController', () => { }, }, async ({ controller }) => { - await controller.refresh(networkClientId); + await controller.refresh(['networkClientId1']); expect(controller.state).toStrictEqual({ accountsByChainId: { @@ -441,7 +441,7 @@ describe('AccountTrackerController', () => { }, }, async ({ controller }) => { - await controller.refresh(networkClientId); + await controller.refresh(['networkClientId1']); expect(controller.state).toStrictEqual({ accountsByChainId: { @@ -477,7 +477,7 @@ describe('AccountTrackerController', () => { }, }, async ({ controller }) => { - await controller.refresh(networkClientId); + await controller.refresh(['networkClientId1']); expect(controller.state).toStrictEqual({ accountsByChainId: { @@ -517,7 +517,7 @@ describe('AccountTrackerController', () => { }, }, async ({ controller }) => { - await controller.refresh(); + await controller.refresh(['mainnet']); expect(controller.state).toStrictEqual({ accountsByChainId: { @@ -558,7 +558,7 @@ describe('AccountTrackerController', () => { }, }, async ({ controller }) => { - await controller.refresh(); + await controller.refresh(['mainnet']); expect(controller.state).toStrictEqual({ accountsByChainId: { @@ -598,7 +598,7 @@ describe('AccountTrackerController', () => { }, }, async ({ controller }) => { - await controller.refresh(); + await controller.refresh(['mainnet']); expect(controller.state).toStrictEqual({ accountsByChainId: { @@ -640,7 +640,7 @@ describe('AccountTrackerController', () => { }, }, async ({ controller }) => { - await controller.refresh(); + await controller.refresh(['mainnet']); expect(controller.state).toStrictEqual({ accountsByChainId: { @@ -726,7 +726,7 @@ describe('AccountTrackerController', () => { jest.spyOn(controller, 'refresh').mockResolvedValue(); await controller.startPolling({ - networkClientId: 'networkClientId1', + networkClientIds: ['networkClientId1'], }); await advanceTime({ clock, duration: 1 }); @@ -759,34 +759,34 @@ describe('AccountTrackerController', () => { .mockResolvedValue(); controller.startPolling({ - networkClientId: networkClientId1, + networkClientIds: [networkClientId1], }); await advanceTime({ clock, duration: 0 }); - expect(refreshSpy).toHaveBeenNthCalledWith(1, networkClientId1); + expect(refreshSpy).toHaveBeenNthCalledWith(1, [networkClientId1]); expect(refreshSpy).toHaveBeenCalledTimes(1); await advanceTime({ clock, duration: 50 }); expect(refreshSpy).toHaveBeenCalledTimes(1); await advanceTime({ clock, duration: 50 }); - expect(refreshSpy).toHaveBeenNthCalledWith(2, networkClientId1); + expect(refreshSpy).toHaveBeenNthCalledWith(2, [networkClientId1]); expect(refreshSpy).toHaveBeenCalledTimes(2); const pollToken = controller.startPolling({ - networkClientId: networkClientId2, + networkClientIds: [networkClientId2], }); await advanceTime({ clock, duration: 0 }); - expect(refreshSpy).toHaveBeenNthCalledWith(3, networkClientId2); + expect(refreshSpy).toHaveBeenNthCalledWith(3, [networkClientId2]); expect(refreshSpy).toHaveBeenCalledTimes(3); await advanceTime({ clock, duration: 100 }); - expect(refreshSpy).toHaveBeenNthCalledWith(4, networkClientId1); - expect(refreshSpy).toHaveBeenNthCalledWith(5, networkClientId2); + expect(refreshSpy).toHaveBeenNthCalledWith(4, [networkClientId1]); + expect(refreshSpy).toHaveBeenNthCalledWith(5, [networkClientId2]); expect(refreshSpy).toHaveBeenCalledTimes(5); controller.stopPollingByPollingToken(pollToken); await advanceTime({ clock, duration: 100 }); - expect(refreshSpy).toHaveBeenNthCalledWith(6, networkClientId1); + expect(refreshSpy).toHaveBeenNthCalledWith(6, [networkClientId1]); expect(refreshSpy).toHaveBeenCalledTimes(6); controller.stopAllPolling(); @@ -810,7 +810,7 @@ describe('AccountTrackerController', () => { expect(refreshSpy).not.toHaveBeenCalled(); controller.startPolling({ - networkClientId: 'networkClientId1', + networkClientIds: ['networkClientId1'], }); await advanceTime({ clock, duration: 1 }); diff --git a/packages/assets-controllers/src/AccountTrackerController.ts b/packages/assets-controllers/src/AccountTrackerController.ts index 56a073bbaa2..c8d9f0871e3 100644 --- a/packages/assets-controllers/src/AccountTrackerController.ts +++ b/packages/assets-controllers/src/AccountTrackerController.ts @@ -124,7 +124,7 @@ export type AccountTrackerControllerMessenger = RestrictedMessenger< /** The input to start polling for the {@link AccountTrackerController} */ type AccountTrackerPollingInput = { - networkClientId: NetworkClientId; + networkClientIds: NetworkClientId[]; }; /** @@ -194,7 +194,7 @@ export class AccountTrackerController extends StaticIntervalPollingController this.refresh(), + () => this.refresh(this.#getNetworkClientIds()), ); } @@ -282,18 +282,35 @@ export class AccountTrackerController extends StaticIntervalPollingController + networkConfiguration.rpcEndpoints.map( + (rpcEndpoint) => rpcEndpoint.networkClientId, + ), + ); + } + /** * Refreshes the balances of the accounts using the networkClientId * * @param input - The input for the poll. - * @param input.networkClientId - The network client ID used to get balances. + * @param input.networkClientIds - The network client IDs used to get balances. */ async _executePoll({ - networkClientId, + networkClientIds, }: AccountTrackerPollingInput): Promise { // TODO: Either fix this lint violation or explain why it's necessary to ignore. // eslint-disable-next-line @typescript-eslint/no-floating-promises - this.refresh(networkClientId); + this.refresh(networkClientIds); } /** @@ -301,50 +318,77 @@ export class AccountTrackerController extends StaticIntervalPollingController { + const { chainId, ethQuery } = + this.#getCorrectNetworkClient(networkClientId); + this.syncAccounts(chainId); + const { accountsByChainId } = this.state; + const { isMultiAccountBalancesEnabled } = this.messagingSystem.call( + 'PreferencesController:getState', + ); + + const accountsToUpdate = isMultiAccountBalancesEnabled + ? Object.keys(accountsByChainId[chainId]) + : [toChecksumHexAddress(selectedAccount.address)]; + + const accountsForChain = { ...accountsByChainId[chainId] }; + + // Create an array of promises for balance and staked balance fetching + const balancePromises = accountsToUpdate.map(async (address) => { + const balancePromise = this.#getBalanceFromChain(address, ethQuery); + const stakedBalancePromise = this.#includeStakedAssets + ? this.#getStakedBalanceForChain(address, networkClientId) + : Promise.resolve(null); + + const [balanceResult, stakedBalanceResult] = await Promise.allSettled( + [balancePromise, stakedBalancePromise], ); - if (stakedBalance) { + + // Update account balances + if (balanceResult.status === 'fulfilled' && balanceResult.value) { + accountsForChain[address] = { + balance: balanceResult.value, + }; + } + + if ( + stakedBalanceResult.status === 'fulfilled' && + stakedBalanceResult.value + ) { accountsForChain[address] = { ...accountsForChain[address], - stakedBalance, + stakedBalance: stakedBalanceResult.value, }; } - } - } + }); - this.update((state) => { - state.accountsByChainId[chainId] = accountsForChain; + // Wait for all balance-related promises to settle + await Promise.allSettled(balancePromises); + + // After all promises for this networkClientId are settled, return the updated data + return { chainId, accountsForChain }; + }); + + // Wait for all networkClientId updates to settle in parallel + const allResults = await Promise.allSettled(updatePromises); + + // Update the state once all networkClientId updates are completed + allResults.forEach((result) => { + if (result.status === 'fulfilled') { + const { chainId, accountsForChain } = result.value; + this.update((state) => { + state.accountsByChainId[chainId] = accountsForChain; + }); + } }); } finally { releaseLock(); From 30c0f69f0b27d17b3326e5951c075e6cc4d82e13 Mon Sep 17 00:00:00 2001 From: salimtb Date: Wed, 23 Apr 2025 11:00:58 +0200 Subject: [PATCH 3/3] fix: fix linter --- packages/assets-controllers/src/AssetsContractController.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/assets-controllers/src/AssetsContractController.ts b/packages/assets-controllers/src/AssetsContractController.ts index 4b396cf4505..b593abdb3b2 100644 --- a/packages/assets-controllers/src/AssetsContractController.ts +++ b/packages/assets-controllers/src/AssetsContractController.ts @@ -303,7 +303,7 @@ export class AssetsContractController { networkClientId, ); return new Web3Provider(provider); - } catch (err) { + } catch { throw new Error(MISSING_PROVIDER_ERROR); } }