Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: refactoring rpcFactory #480

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pages/api/rpc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ const rpc = rpcFactory({
allowedCallAddresses,
allowedLogsAddresses,
maxBatchCount: config.PROVIDER_MAX_BATCH,
disallowEmptyAddressGetLogs: false,
disallowEmptyAddressGetLogs: true,
});

export default wrapNextRequest([
Expand Down
288 changes: 157 additions & 131 deletions utilsApi/rpcFactory.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
import { Readable, Transform } from 'node:stream';
import { Readable } from 'node:stream';
import { ReadableStream } from 'node:stream/web';
import type { NextApiRequest, NextApiResponse } from 'next';
import { Counter, Registry } from 'prom-client';
import type { TrackedFetchRPC } from '@lidofinance/api-rpc';
import type { FetchRpcInitBody } from '@lidofinance/rpc';
import { iterateUrls } from '@lidofinance/rpc';

type EthCallParams = [{ to: string; [key: string]: any }];
type EthGetLogsParams = [{ address?: string | string[]; [key: string]: any }];

type ValidatingMethodsType = 'eth_call' | 'eth_getLogs';
type MethodParams = {
eth_call: EthCallParams;
eth_getLogs: EthGetLogsParams;
};

export type RpcProviders = Record<string | number, [string, ...string[]]>;

export const DEFAULT_API_ERROR_MESSAGE =
'Something went wrong. Sorry, try again later :(';

export const HEALTHY_RPC_SERVICES_ARE_OVER = 'Healthy RPC services are over!';
export const HTTP_METHOD_POST = 'POST';
export const CONTENT_TYPE_JSON = 'application/json';

export class ClientError extends Error {}
export class UnsupportedChainIdError extends ClientError {
Expand All @@ -32,32 +42,126 @@ export class InvalidRequestError extends ClientError {
}
}

export class SizeTooLargeError extends ClientError {
constructor(message?: string) {
super(message || 'Invalid Request');
type ValidatorContext = {
allowedCallAddressMap: Record<string, Set<string>>;
allowedLogsAddressMap: Record<string, Set<string>>;
chainId: number;
rpcRequestBlocked: Counter<string>;
disallowEmptyAddressGetLogs: boolean;
};

const validateEthCall = (params: EthCallParams, context: ValidatorContext) => {
const { allowedCallAddressMap, chainId, rpcRequestBlocked } = context;
if (!allowedCallAddressMap[chainId]) return;

const [{ to }] = params;
if (typeof to !== 'string') {
throw new InvalidRequestError(`Invalid eth_call params`);
}
}
if (!allowedCallAddressMap[chainId].has(to.toLowerCase())) {
rpcRequestBlocked.inc();
throw new InvalidRequestError(`Address not allowed for eth_call`);
}
};

const createSizeLogger = (MAX_SIZE: number) => {
let bytesWritten = 0;
const logSizeStream = new Transform({
transform(chunk, _encoding, callback) {
bytesWritten += chunk.length;
if (bytesWritten > MAX_SIZE) {
// Emit an error if size exceeds MAX_SIZE
return callback(
new SizeTooLargeError(
`Stream size exceeds the maximum limit of ${MAX_SIZE} bytes`,
),
);
}
return callback(null, chunk); // Pass the chunk through
},
flush(callback) {
callback();
},
});
return logSizeStream;
const validateEthGetLogs = (
params: EthGetLogsParams,
context: ValidatorContext,
) => {
const {
disallowEmptyAddressGetLogs,
chainId,
rpcRequestBlocked,
allowedLogsAddressMap,
} = context;
if (!disallowEmptyAddressGetLogs && !allowedLogsAddressMap[chainId]) return;

const [{ address }] = params;
if (
disallowEmptyAddressGetLogs &&
(!address || (Array.isArray(address) && address.length === 0))
) {
rpcRequestBlocked.inc();
throw new InvalidRequestError(`No empty address on eth_getLogs`);
}

const addresses = Array.isArray(address) ? address : [address];
const isInvalidAddress = (addr: any) =>
typeof addr !== 'string' ||
!allowedLogsAddressMap[chainId].has(addr.toLowerCase());

if (addresses.some(isInvalidAddress)) {
rpcRequestBlocked.inc();
throw new InvalidRequestError(`Address not allowed for eth_getLogs`);
}
};

const methodValidators: Record<
ValidatingMethodsType,
(params: any, context: ValidatorContext) => void
> = {
eth_call: validateEthCall,
eth_getLogs: validateEthGetLogs,
};

const validateMethod = (
method: ValidatingMethodsType,
params: MethodParams[ValidatingMethodsType],
context: ValidatorContext,
) => {
const validator = methodValidators[method];
if (validator) {
validator(params, context);
}
};

type ValidateRequestContentParams = {
req: NextApiRequest;
allowedRPCMethods: string[];
rpcRequestBlocked: Counter<string>;
allowedCallAddressMap: Record<string, Set<string>>;
allowedLogsAddressMap: Record<string, Set<string>>;
chainId: number;
disallowEmptyAddressGetLogs: boolean;
maxBatchCount?: number;
};

const validateRequestContent = ({
req,
allowedRPCMethods,
rpcRequestBlocked,
allowedCallAddressMap,
allowedLogsAddressMap,
chainId,
disallowEmptyAddressGetLogs,
maxBatchCount,
}: ValidateRequestContentParams) => {
const content = Array.isArray(req.body) ? req.body : [req.body];

if (typeof maxBatchCount === 'number' && content.length > maxBatchCount) {
throw new InvalidRequestError(`Too many batched requests`);
}

const context: ValidatorContext = {
allowedCallAddressMap,
allowedLogsAddressMap,
chainId,
rpcRequestBlocked,
disallowEmptyAddressGetLogs,
};

for (const { method, params } of content) {
if (typeof method !== 'string') {
throw new InvalidRequestError(`RPC method isn't string`);
}

if (!allowedRPCMethods.includes(method)) {
rpcRequestBlocked.inc();
throw new InvalidRequestError(`RPC method ${method} isn't allowed`);
}

validateMethod(method as ValidatingMethodsType, params, context);
}
};

export type RPCFactoryParams = {
Expand All @@ -77,7 +181,18 @@ export type RPCFactoryParams = {
allowedLogsAddresses?: Record<number, string[]>;
disallowEmptyAddressGetLogs?: boolean;
maxBatchCount?: number;
maxResponseSize?: number;
};

const createAllowedAddressMap = (
addresses: Record<number, string[]>,
): Record<string, Set<string>> => {
return Object.entries(addresses).reduce(
(acc, [chainId, addressList]) => {
acc[chainId] = new Set(addressList.map((a) => a.toLowerCase()));
return acc;
},
{} as Record<string, Set<string>>,
);
};

export const rpcFactory = ({
Expand All @@ -89,7 +204,6 @@ export const rpcFactory = ({
allowedCallAddresses = {},
allowedLogsAddresses = {},
maxBatchCount,
maxResponseSize = 1_000_000, // ~1MB,
disallowEmptyAddressGetLogs = false,
}: RPCFactoryParams) => {
const rpcRequestBlocked = new Counter({
Expand All @@ -100,26 +214,13 @@ export const rpcFactory = ({
});
registry.registerMetric(rpcRequestBlocked);

const allowedCallAddressMap = Object.entries(allowedCallAddresses).reduce(
(acc, [chainId, addresses]) => {
acc[chainId] = new Set(addresses.map((a) => a.toLowerCase()));
return acc;
},
{} as Record<string, Set<string>>,
);

const allowedLogsAddressMap = Object.entries(allowedLogsAddresses).reduce(
(acc, [chainId, addresses]) => {
acc[chainId] = new Set(addresses.map((a) => a.toLowerCase()));
return acc;
},
{} as Record<string, Set<string>>,
);
const allowedCallAddressMap = createAllowedAddressMap(allowedCallAddresses);
const allowedLogsAddressMap = createAllowedAddressMap(allowedLogsAddresses);

return async (req: NextApiRequest, res: NextApiResponse): Promise<void> => {
try {
// Accept only POST requests
if (req.method !== 'POST') {
if (req.method !== HTTP_METHOD_POST) {
// We don't care about tracking blocked requests here
throw new UnsupportedHTTPMethodError();
}
Expand All @@ -132,73 +233,17 @@ export const rpcFactory = ({
throw new UnsupportedChainIdError();
}

const requests = Array.isArray(req.body) ? req.body : [req.body];

if (
typeof maxBatchCount === 'number' &&
requests.length > maxBatchCount
) {
throw new InvalidRequestError(`Too many batched requests`);
}

// Check if provided methods are allowed
// We throw HTTP error for ANY invalid RPC request out of batch
// because we assume that frontend must not send invalid requests
for (const { method, params } of requests) {
if (typeof method !== 'string') {
throw new InvalidRequestError(`RPC method isn't string`);
}
if (!allowedRPCMethods.includes(method)) {
rpcRequestBlocked.inc();
throw new InvalidRequestError(`RPC method ${method} isn't allowed`);
}
if (method === 'eth_call' && allowedCallAddressMap[chainId]) {
if (
Array.isArray(params) &&
typeof params[0] === 'object' &&
typeof params[0].to === 'string'
) {
if (
!allowedCallAddressMap[chainId].has(params[0].to.toLowerCase())
) {
rpcRequestBlocked.inc();
throw new InvalidRequestError(`Address not allowed for eth_call`);
}
} else
throw new InvalidRequestError(`RPC method eth_call is invalid`);
}
if (
method === 'eth_getLogs' &&
(disallowEmptyAddressGetLogs || allowedLogsAddressMap[chainId])
) {
if (Array.isArray(params) && typeof params[0] === 'object') {
const address = params[0].address;
if (
disallowEmptyAddressGetLogs &&
(!address || (Array.isArray(address) && address.length === 0))
) {
rpcRequestBlocked.inc();
throw new InvalidRequestError(`No empty address on eth_getLogs`);
}
const addresses = Array.isArray(address) ? address : [address];
if (
addresses.some(
(eventAddress) =>
// needs this check before toLowerCase
typeof eventAddress !== 'string' ||
!allowedLogsAddressMap[chainId].has(
eventAddress.toLowerCase(),
),
)
) {
rpcRequestBlocked.inc();
throw new InvalidRequestError(
`Address not allowed for eth_getLogs`,
);
}
} else throw new InvalidRequestError(`Invalid eth_getLogs`);
}
}
// Validate request content
validateRequestContent({
req,
allowedRPCMethods,
rpcRequestBlocked,
allowedCallAddressMap,
allowedLogsAddressMap,
chainId,
disallowEmptyAddressGetLogs,
maxBatchCount,
});

const requested = await iterateUrls(
providers[chainId],
Expand All @@ -214,26 +259,7 @@ export const rpcFactory = ({
requested.headers.get('Content-Type') ?? 'application/json',
);
if (requested.body) {
const sizeLimit = createSizeLogger(maxResponseSize);
const readableStream = Readable.fromWeb(
requested.body as ReadableStream,
);
readableStream
.pipe(sizeLimit)
.on('error', (error) => {
if (error instanceof SizeTooLargeError) {
console.warn(
`[rpcFactory] RPC response too large: ${JSON.stringify(requests)}`,
);
// Payload Too Large
res.status(413).end();
} else {
res.statusCode = 500;
res.end(DEFAULT_API_ERROR_MESSAGE);
}
readableStream.destroy();
})
.pipe(res);
Readable.fromWeb(requested.body as ReadableStream).pipe(res);
} else {
res
.status(requested.status)
Expand Down
Loading