From a97a469ada10bad6f051ee25a372f9bb68c2da50 Mon Sep 17 00:00:00 2001 From: Hadrien Croubois Date: Thu, 22 Feb 2024 11:42:55 +0100 Subject: [PATCH] add support for custom internal hash in SimpleMerkleTree --- src/core.ts | 20 ++++---- src/format.ts | 2 +- src/hashes.ts | 8 +++ src/simple.test.ts | 125 ++++++++++++++++++++++++++++++++------------- src/simple.ts | 59 +++++++++++++-------- 5 files changed, 145 insertions(+), 69 deletions(-) create mode 100644 src/hashes.ts diff --git a/src/core.ts b/src/core.ts index 7bcef4e..7866588 100644 --- a/src/core.ts +++ b/src/core.ts @@ -1,9 +1,7 @@ -import { keccak256 } from '@ethersproject/keccak256'; import { BytesLike, HexString, toHex, toBytes, concat, compare } from './bytes'; +import { HashPairFn, keccak256SortedPair } from './hashes'; import { throwError } from './utils/throw-error'; -const hashPair = (a: BytesLike, b: BytesLike): HexString => keccak256(concat([a, b].sort(compare))); - const leftChildIndex = (i: number) => 2 * i + 1; const rightChildIndex = (i: number) => 2 * i + 2; const parentIndex = (i: number) => i > 0 ? Math.floor((i - 1) / 2) : throwError('Root has no parent'); @@ -19,7 +17,7 @@ const checkInternalNode = (tree: unknown[], i: number) => void (isInternalNod const checkLeafNode = (tree: unknown[], i: number) => void (isLeafNode(tree, i) || throwError('Index is not a leaf')); const checkValidMerkleNode = (node: BytesLike) => void (isValidMerkleNode(node) || throwError('Merkle tree nodes must be Uint8Array of length 32')); -export function makeMerkleTree(leaves: BytesLike[]): HexString[] { +export function makeMerkleTree(leaves: BytesLike[], hash?: HashPairFn): HexString[] { leaves.forEach(checkValidMerkleNode); if (leaves.length === 0) { @@ -32,7 +30,7 @@ export function makeMerkleTree(leaves: BytesLike[]): HexString[] { tree[tree.length - 1 - i] = toHex(leaf); } for (let i = tree.length - 1 - leaves.length; i >= 0; i--) { - tree[i] = hashPair( + tree[i] = (hash ?? keccak256SortedPair)( tree[leftChildIndex(i)]!, tree[rightChildIndex(i)]!, ); @@ -52,11 +50,11 @@ export function getProof(tree: BytesLike[], index: number): HexString[] { return proof.map(node => toHex(node)); } -export function processProof(leaf: BytesLike, proof: BytesLike[]): HexString { +export function processProof(leaf: BytesLike, proof: BytesLike[], hash?: HashPairFn): HexString { checkValidMerkleNode(leaf); proof.forEach(checkValidMerkleNode); - return toHex(proof.reduce(hashPair, leaf)); + return toHex(proof.reduce(hash ?? keccak256SortedPair, leaf)); } export interface MultiProof { @@ -103,7 +101,7 @@ export function getMultiProof(tree: BytesLike[], indices: number[]): MultiProof< }; } -export function processMultiProof(multiproof: MultiProof): HexString { +export function processMultiProof(multiproof: MultiProof, hash?: HashPairFn): HexString { multiproof.leaves.forEach(checkValidMerkleNode); multiproof.proof.forEach(checkValidMerkleNode); @@ -124,7 +122,7 @@ export function processMultiProof(multiproof: MultiProof): HexString if (a === undefined || b === undefined) { throw new Error('Broken invariant'); } - stack.push(hashPair(a, b)); + stack.push((hash ?? keccak256SortedPair)(a, b)); } if (stack.length + proof.length !== 1) { @@ -134,7 +132,7 @@ export function processMultiProof(multiproof: MultiProof): HexString return toHex(stack.pop() ?? proof.shift()!); } -export function isValidMerkleTree(tree: BytesLike[]): boolean { +export function isValidMerkleTree(tree: BytesLike[], hash?: HashPairFn): boolean { for (const [i, node] of tree.entries()) { if (!isValidMerkleNode(node)) { return false; @@ -147,7 +145,7 @@ export function isValidMerkleTree(tree: BytesLike[]): boolean { if (l < tree.length) { return false; } - } else if (node !== hashPair(tree[l]!, tree[r]!)) { + } else if (node !== (hash ?? keccak256SortedPair)(tree[l]!, tree[r]!)) { return false; } } diff --git a/src/format.ts b/src/format.ts index aac029c..a99a1fc 100644 --- a/src/format.ts +++ b/src/format.ts @@ -2,7 +2,7 @@ import type { HexString } from "./bytes"; // Dump/Load format export type MerkleTreeData = { - format: 'standard-v1'; + format: 'standard-v1' | 'custom-v1'; tree: HexString[]; values: { value: T; diff --git a/src/hashes.ts b/src/hashes.ts new file mode 100644 index 0000000..e3b54aa --- /dev/null +++ b/src/hashes.ts @@ -0,0 +1,8 @@ +import { keccak256 } from '@ethersproject/keccak256'; +import { BytesLike, HexString, concat, compare } from './bytes'; + +export type HashPairFn = (a: BytesLike, b: BytesLike) => HexString; + +export function keccak256SortedPair(a: BytesLike, b: BytesLike): HexString { + return keccak256(concat([a, b].sort(compare))); +} \ No newline at end of file diff --git a/src/simple.test.ts b/src/simple.test.ts index 508016b..91db7ff 100644 --- a/src/simple.test.ts +++ b/src/simple.test.ts @@ -2,12 +2,16 @@ import assert from 'assert/strict'; import { HashZero as zero } from '@ethersproject/constants'; import { keccak256 } from '@ethersproject/keccak256'; import { SimpleMerkleTree } from './simple'; +import { BytesLike, HexString, concat, compare } from './bytes'; + +const reverseHashPair = (a: BytesLike, b: BytesLike): HexString => keccak256(concat([a, b].sort(compare).reverse())); describe('simple merkle tree', () => { for (const opts of [ {}, { sortLeaves: true }, { sortLeaves: false }, + { hashPair: reverseHashPair }, ]) { describe(`with options '${JSON.stringify(opts)}'`, () => { const leaves = 'abcdef'.split('').map(c => keccak256(Buffer.from(c))); @@ -28,7 +32,11 @@ describe('simple merkle tree', () => { assert(tree.verify(id, proof1)); assert(tree.verify(leaf, proof1)); - assert(SimpleMerkleTree.verify(tree.root, leaf, proof1)); + if (opts.hashPair) { + assert(SimpleMerkleTree.verify(tree.root, leaf, proof1, opts.hashPair)); + } else { + assert(SimpleMerkleTree.verify(tree.root, leaf, proof1)); + } } }); @@ -37,7 +45,11 @@ describe('simple merkle tree', () => { const invalidProof = otherTree.getProof(leaf); assert(!tree.verify(leaf, invalidProof)); - assert(!SimpleMerkleTree.verify(tree.root, leaf, invalidProof)); + if (opts.hashPair) { + assert(!SimpleMerkleTree.verify(tree.root, leaf, invalidProof, opts.hashPair)); + } else { + assert(!SimpleMerkleTree.verify(tree.root, leaf, invalidProof)); + } }); it('generates valid multiproofs', () => { @@ -48,7 +60,11 @@ describe('simple merkle tree', () => { assert.deepEqual(proof1, proof2); assert(tree.verifyMultiProof(proof1)); - assert(SimpleMerkleTree.verifyMultiProof(tree.root, proof1)); + if (opts.hashPair) { + assert(SimpleMerkleTree.verifyMultiProof(tree.root, proof1, opts.hashPair)); + } else { + assert(SimpleMerkleTree.verifyMultiProof(tree.root, proof1)); + } } }); @@ -56,45 +72,68 @@ describe('simple merkle tree', () => { const multiProof = otherTree.getMultiProof(leaves.slice(0, 3)); assert(!tree.verifyMultiProof(multiProof)); - assert(!SimpleMerkleTree.verifyMultiProof(tree.root, multiProof)); + if (opts.hashPair) { + assert(!SimpleMerkleTree.verifyMultiProof(tree.root, multiProof, opts.hashPair)); + } else { + assert(!SimpleMerkleTree.verifyMultiProof(tree.root, multiProof)); + } }); it('renders tree representation', () => { - assert.equal( - tree.render(), - opts.sortLeaves == false - ? [ - "0) 0x9012f1e18a87790d2e01faace75aaaca38e53df437cdce2c0552464dda4af49c", - "├─ 1) 0x68203f90e9d07dc5859259d7536e87a6ba9d345f2552b5b9de2999ddce9ce1bf", - "│ ├─ 3) 0xd253a52d4cb00de2895e85f2529e2976e6aaaa5c18106b68ab66813e14415669", - "│ │ ├─ 7) 0xf1918e8562236eb17adc8502332f4c9c82bc14e19bfc0aa10ab674ff75b3d2f3", - "│ │ └─ 8) 0x0b42b6393c1f53060fe3ddbfcd7aadcca894465a5a438f69c87d790b2299b9b2", - "│ └─ 4) 0x805b21d846b189efaeb0377d6bb0d201b3872a363e607c25088f025b0c6ae1f8", - "│ ├─ 9) 0xb5553de315e0edf504d9150af82dafa5c4667fa618ed0a6f19c69b41166c5510", - "│ └─ 10) 0x3ac225168df54212a25c1c01fd35bebfea408fdac2e31ddd6f80a4bbf9a5f1cb", - "└─ 2) 0xf0b49bb4b0d9396e0315755ceafaa280707b32e75e6c9053f5cdf2679dcd5c6a", - " ├─ 5) 0xd1e8aeb79500496ef3dc2e57ba746a8315d048b7a664a2bf948db4fa91960483", - " └─ 6) 0xa8982c89d80987fb9a510e25981ee9170206be21af3c8e0eb312ef1d3382e761", - ].join("\n") - : [ - "0) 0x1b404f199ea828ec5771fb30139c222d8417a82175fefad5cd42bc3a189bd8d5", - "├─ 1) 0xec554bdfb01d31fa838d0830339b0e6e8a70e0d55a8f172ffa8bebbf8e8d5ba0", - "│ ├─ 3) 0x434d51cfeb80272378f4c3a8fd2824561c2cad9fce556ea600d46f20550976a6", - "│ │ ├─ 7) 0xb5553de315e0edf504d9150af82dafa5c4667fa618ed0a6f19c69b41166c5510", - "│ │ └─ 8) 0xa8982c89d80987fb9a510e25981ee9170206be21af3c8e0eb312ef1d3382e761", - "│ └─ 4) 0x7dea550f679f3caab547cbbc5ee1a4c978c8c039b572ba00af1baa6481b88360", - "│ ├─ 9) 0x3ac225168df54212a25c1c01fd35bebfea408fdac2e31ddd6f80a4bbf9a5f1cb", - "│ └─ 10) 0x0b42b6393c1f53060fe3ddbfcd7aadcca894465a5a438f69c87d790b2299b9b2", - "└─ 2) 0xaf46af0745b433e1d5bed9a04b1fdf4002f67a733c20db2fca5b2af6120d9bcb", - " ├─ 5) 0xf1918e8562236eb17adc8502332f4c9c82bc14e19bfc0aa10ab674ff75b3d2f3", - " └─ 6) 0xd1e8aeb79500496ef3dc2e57ba746a8315d048b7a664a2bf948db4fa91960483", - ].join("\n"), - ); + const expected = ( + // standard hash + unsorted + !opts.hashPair && opts.sortLeaves === false + ? [ + "0) 0x9012f1e18a87790d2e01faace75aaaca38e53df437cdce2c0552464dda4af49c", + "├─ 1) 0x68203f90e9d07dc5859259d7536e87a6ba9d345f2552b5b9de2999ddce9ce1bf", + "│ ├─ 3) 0xd253a52d4cb00de2895e85f2529e2976e6aaaa5c18106b68ab66813e14415669", + "│ │ ├─ 7) 0xf1918e8562236eb17adc8502332f4c9c82bc14e19bfc0aa10ab674ff75b3d2f3", + "│ │ └─ 8) 0x0b42b6393c1f53060fe3ddbfcd7aadcca894465a5a438f69c87d790b2299b9b2", + "│ └─ 4) 0x805b21d846b189efaeb0377d6bb0d201b3872a363e607c25088f025b0c6ae1f8", + "│ ├─ 9) 0xb5553de315e0edf504d9150af82dafa5c4667fa618ed0a6f19c69b41166c5510", + "│ └─ 10) 0x3ac225168df54212a25c1c01fd35bebfea408fdac2e31ddd6f80a4bbf9a5f1cb", + "└─ 2) 0xf0b49bb4b0d9396e0315755ceafaa280707b32e75e6c9053f5cdf2679dcd5c6a", + " ├─ 5) 0xd1e8aeb79500496ef3dc2e57ba746a8315d048b7a664a2bf948db4fa91960483", + " └─ 6) 0xa8982c89d80987fb9a510e25981ee9170206be21af3c8e0eb312ef1d3382e761", + ] + // sortLeaves = true | undefined --- standard hash + sorted + : !opts.hashPair + ? [ + "0) 0x1b404f199ea828ec5771fb30139c222d8417a82175fefad5cd42bc3a189bd8d5", + "├─ 1) 0xec554bdfb01d31fa838d0830339b0e6e8a70e0d55a8f172ffa8bebbf8e8d5ba0", + "│ ├─ 3) 0x434d51cfeb80272378f4c3a8fd2824561c2cad9fce556ea600d46f20550976a6", + "│ │ ├─ 7) 0xb5553de315e0edf504d9150af82dafa5c4667fa618ed0a6f19c69b41166c5510", + "│ │ └─ 8) 0xa8982c89d80987fb9a510e25981ee9170206be21af3c8e0eb312ef1d3382e761", + "│ └─ 4) 0x7dea550f679f3caab547cbbc5ee1a4c978c8c039b572ba00af1baa6481b88360", + "│ ├─ 9) 0x3ac225168df54212a25c1c01fd35bebfea408fdac2e31ddd6f80a4bbf9a5f1cb", + "│ └─ 10) 0x0b42b6393c1f53060fe3ddbfcd7aadcca894465a5a438f69c87d790b2299b9b2", + "└─ 2) 0xaf46af0745b433e1d5bed9a04b1fdf4002f67a733c20db2fca5b2af6120d9bcb", + " ├─ 5) 0xf1918e8562236eb17adc8502332f4c9c82bc14e19bfc0aa10ab674ff75b3d2f3", + " └─ 6) 0xd1e8aeb79500496ef3dc2e57ba746a8315d048b7a664a2bf948db4fa91960483", + ] + // non standard hash + : [ + "0) 0x8f0a1adb058c628fa4ce2e7bd26024180b888fec77087d4e5ee6890746e9c6ec", + "├─ 1) 0xb9f5a6bc1b75fadcd9765163dfc8d4865d1608337a2a310ff51fecb431faaee4", + "│ ├─ 3) 0x37d657e93dfbae50b18241610418794b51124af5ca872f1b56c08490cb2905ac", + "│ │ ├─ 7) 0xb5553de315e0edf504d9150af82dafa5c4667fa618ed0a6f19c69b41166c5510", + "│ │ └─ 8) 0xa8982c89d80987fb9a510e25981ee9170206be21af3c8e0eb312ef1d3382e761", + "│ └─ 4) 0xed90ef72e95e6692b91b020dc6cb5c4db9dc149a496799c4318fa8075960c48e", + "│ ├─ 9) 0x3ac225168df54212a25c1c01fd35bebfea408fdac2e31ddd6f80a4bbf9a5f1cb", + "│ └─ 10) 0x0b42b6393c1f53060fe3ddbfcd7aadcca894465a5a438f69c87d790b2299b9b2", + "└─ 2) 0x138c55cca8f6430d75b6bbcea643a7afa8ee74c22643ad76723ecafd4fcd21d4", + " ├─ 5) 0xf1918e8562236eb17adc8502332f4c9c82bc14e19bfc0aa10ab674ff75b3d2f3", + " └─ 6) 0xd1e8aeb79500496ef3dc2e57ba746a8315d048b7a664a2bf948db4fa91960483", + ] + ).join("\n"); + + assert.equal(tree.render(), expected); }); it('dump and load', () => { - const recoveredTree = SimpleMerkleTree.load(tree.dump()); - + const recoveredTree = opts.hashPair + ? SimpleMerkleTree.load(tree.dump(), opts.hashPair) + : SimpleMerkleTree.load(tree.dump()); recoveredTree.validate(); assert.deepEqual(tree, recoveredTree); }); @@ -124,6 +163,20 @@ describe('simple merkle tree', () => { ); }); + it('reject standard tree dump with a custom hash', () => { + assert.throws( + () => SimpleMerkleTree.load({ format: 'standard-v1'} as any, reverseHashPair), + /^Error: Format 'standard-v1' does not support custom hashing functions$/, + ); + }); + + it('reject custom tree dump without a custom hash', () => { + assert.throws( + () => SimpleMerkleTree.load({ format: 'custom-v1'} as any), + /^Error: Format 'custom-v1' requires a hashing function$/, + ); + }); + it('reject malformed tree dump', () => { const loadedTree1 = SimpleMerkleTree.load({ format: 'standard-v1', diff --git a/src/simple.ts b/src/simple.ts index ebdbf6f..6f727df 100644 --- a/src/simple.ts +++ b/src/simple.ts @@ -7,6 +7,10 @@ import { compare, } from './bytes'; +import { + HashPairFn, +} from './hashes'; + import { MerkleTreeData, } from './format'; @@ -32,6 +36,7 @@ export class SimpleMerkleTree { private constructor( private readonly tree: HexString[], private readonly values: { value: HexString, treeIndex: number }[], + private readonly hashPair?: HashPairFn, ) { this.hashLookup = Object.fromEntries(values.map(({ value }, valueIndex) => [ @@ -40,8 +45,8 @@ export class SimpleMerkleTree { ])); } - static of(values: BytesLike[], options: MerkleTreeOptions = {}) { - const { sortLeaves } = { ...defaultOptions, ...options }; + static of(values: BytesLike[], options: MerkleTreeOptions & { hashPair?: HashPairFn } = {}) { + const { sortLeaves, hashPair } = { ...defaultOptions, ...options }; values.forEach((value, i) => { if (toBytes(value).length !== 32) { @@ -55,43 +60,55 @@ export class SimpleMerkleTree { hashedValues.sort((a, b) => compare(a.hash, b.hash)); } - const tree = makeMerkleTree(hashedValues.map(v => v.hash)); + const tree = makeMerkleTree(hashedValues.map(v => v.hash), hashPair); const indexedValues = values.map(value => ({ value: toHex(value), treeIndex: 0 })); for (const [leafIndex, { valueIndex }] of hashedValues.entries()) { indexedValues[valueIndex]!.treeIndex = tree.length - leafIndex - 1; } - return new SimpleMerkleTree(tree, indexedValues); + return new SimpleMerkleTree(tree, indexedValues, hashPair); } - static load(data: MerkleTreeData): SimpleMerkleTree { - if (data.format !== 'standard-v1') { - throwError(`Unknown format '${data.format}'`); + static load(data: MerkleTreeData, hashPair ?: HashPairFn): SimpleMerkleTree { + switch (data.format) { + case 'standard-v1': + if (hashPair !== undefined) throwError(`Format '${data.format}' does not support custom hashing functions`); + break; + case 'custom-v1': + if (hashPair === undefined) throwError(`Format '${data.format}' requires a hashing function`); + // TODO: check that the hash matches the data. + break; + default: + throwError(`Unknown format '${data.format}'`); } return new SimpleMerkleTree( data.tree, data.values.map(({ value, treeIndex }) => ({ value: toHex(value), treeIndex })), + hashPair, ); } - static verify(root: BytesLike, leaf: BytesLike, proof: BytesLike[]): boolean { - return toHex(root) === processProof(leaf, proof); + static verify(root: BytesLike, leaf: BytesLike, proof: BytesLike[], hashPair ?: HashPairFn): boolean { + return toHex(root) === processProof(leaf, proof, hashPair); } - static verifyMultiProof(root: BytesLike, multiproof: MultiProof): boolean { - return toHex(root) === processMultiProof({ - leaves: multiproof.leaves, - proof: multiproof.proof, - proofFlags: multiproof.proofFlags, - }); + static verifyMultiProof(root: BytesLike, multiproof: MultiProof, hashPair ?: HashPairFn): boolean { + return toHex(root) === processMultiProof( + { + leaves: multiproof.leaves, + proof: multiproof.proof, + proofFlags: multiproof.proofFlags, + }, + hashPair, + ); } dump(): MerkleTreeData { return { - format: 'standard-v1', - tree: this.tree, - values: this.values, + format: this.hashPair === undefined ? 'standard-v1' : 'custom-v1', + tree: this.tree, + values: this.values, }; } @@ -113,7 +130,7 @@ export class SimpleMerkleTree { for (let i = 0; i < this.values.length; i++) { this.validateValue(i); } - if (!isValidMerkleTree(this.tree)) { + if (!isValidMerkleTree(this.tree, this.hashPair)) { throwError('Merkle tree is invalid'); } } @@ -167,7 +184,7 @@ export class SimpleMerkleTree { } private _verify(leafHash: BytesLike, proof: BytesLike[]): boolean { - return this.root === processProof(leafHash, proof); + return this.root === processProof(leafHash, proof, this.hashPair); } verifyMultiProof(multiproof: MultiProof): boolean { @@ -179,7 +196,7 @@ export class SimpleMerkleTree { } private _verifyMultiProof(multiproof: MultiProof): boolean { - return this.root === processMultiProof(multiproof); + return this.root === processMultiProof(multiproof, this.hashPair); } private validateValue(valueIndex: number): HexString {