From 029df1d78b75bb2d5c2d7288fc25d4ae5d59e4a4 Mon Sep 17 00:00:00 2001 From: Koren-Brand Date: Mon, 11 Nov 2024 15:41:35 +0200 Subject: [PATCH] Fixed bug that caused poseidon merkle tree to fail on rust (And added the same test to cpp test_hash_api) Signed-off-by: Koren-Brand --- .../backend/cpu/src/hash/cpu_merkle_tree.cpp | 7 +- icicle/tests/test_hash_api.cpp | 83 +++++++++++++++++++ 2 files changed, 89 insertions(+), 1 deletion(-) diff --git a/icicle/backend/cpu/src/hash/cpu_merkle_tree.cpp b/icicle/backend/cpu/src/hash/cpu_merkle_tree.cpp index 180742218..38fea5492 100644 --- a/icicle/backend/cpu/src/hash/cpu_merkle_tree.cpp +++ b/icicle/backend/cpu/src/hash/cpu_merkle_tree.cpp @@ -501,7 +501,12 @@ namespace icicle { // This is not the root layer (root) LayerDB& next_layer = m_layers[next_layer_idx]; const uint64_t next_input_size = next_layer.m_hash.default_input_chunk_size(); - const uint64_t next_segment_idx = cur_segment_idx * cur_layer.m_hash.output_size() / next_input_size; + // Ensure next segment does not overflow due to a sized batch by comparing it to the + // max possible segment index (And taking the smaller one) + const uint64_t max_segment_idx = + (next_layer.m_nof_hashes_2_execute - next_layer.m_last_hash_config.batch) / NOF_OPERATIONS_PER_TASK; + const uint64_t next_segment_idx = + std::min(cur_segment_idx * cur_layer.m_hash.output_size() / next_input_size, max_segment_idx); const uint64_t next_segment_id = next_segment_idx ^ (next_layer_idx << 56); // If next_segment does not appear in m_map_segment_id_2_inputs, then add it diff --git a/icicle/tests/test_hash_api.cpp b/icicle/tests/test_hash_api.cpp index 5f9a7e5db..10f91fea6 100644 --- a/icicle/tests/test_hash_api.cpp +++ b/icicle/tests/test_hash_api.cpp @@ -1152,4 +1152,87 @@ TEST_F(HashApiTest, poseidon3_batch) ASSERT_EQ(0, memcmp(output_cpu.get(), output_cuda.get(), config.batch * sizeof(scalar_t))); } +TEST_F(HashApiTest, poseidon_tree) +{ + const uint64_t t = 9; + const uint64_t nof_layers = 4; + uint64_t nof_leaves = 1; + for (int i = 0; i < nof_layers; i++) { + nof_leaves *= t; + } + auto leaves = std::make_unique(nof_leaves); + const uint64_t leaf_size = sizeof(scalar_t); + const uint64_t total_input_size = nof_leaves * leaf_size; + ICICLE_LOG_INFO << total_input_size; + randomize(leaves.get(), nof_leaves); + + std::vector> device_roots(s_registered_devices.size()); + int device_roots_idx = 0; + + for (int i = 0; i < s_registered_devices.size(); i++) { + const auto& device = s_registered_devices[i]; + ICICLE_LOG_INFO << "MerkleTreeDeviceBig on device=" << device; + ICICLE_CHECK(icicle_set_device(device)); + + // Create a Keccak256 hasher with an arity of 2: every 64B -> 32B + auto layer_hash = Poseidon::create(t); + // Create a vector of `Hash` objects, all initialized with the same `layer_hash` + std::vector hashes(nof_layers, layer_hash); + + // copy leaves to device + std::byte* device_leaves = nullptr; + ICICLE_CHECK(icicle_malloc((void**)&device_leaves, total_input_size)); + ICICLE_CHECK(icicle_copy(device_leaves, leaves.get(), total_input_size)); + + auto config = default_merkle_tree_config(); + config.is_leaves_on_device = true; + auto prover_tree = MerkleTree::create(hashes, leaf_size); + auto verifier_tree = MerkleTree::create(hashes, leaf_size); + + // build tree + START_TIMER(MerkleTree_build) + ICICLE_CHECK(prover_tree.build(device_leaves, total_input_size, config)); + END_TIMER(MerkleTree_build, "Merkle Tree large", true) + + auto [root, root_size] = prover_tree.get_merkle_root(); + deep_copy_byte_array_to_vec(root, root_size, device_roots[i]); + + // proof leaves and verify + for (int test_leaf_idx = 0; test_leaf_idx < 5; test_leaf_idx++) { + const int leaf_idx = rand() % nof_leaves; + + // test non-pruned path + MerkleProof merkle_proof{}; + bool verification_valid = false; + ICICLE_CHECK(prover_tree.get_merkle_proof( + device_leaves, total_input_size, leaf_idx, false /*=pruned*/, config, merkle_proof)); + ICICLE_CHECK(verifier_tree.verify(merkle_proof, verification_valid)); + ASSERT_TRUE(verification_valid); + + // test pruned path + verification_valid = false; + ICICLE_CHECK(prover_tree.get_merkle_proof( + device_leaves, total_input_size, leaf_idx, true /*=pruned*/, config, merkle_proof)); + ICICLE_CHECK(verifier_tree.verify(merkle_proof, verification_valid)); + ASSERT_TRUE(verification_valid); + } + ICICLE_CHECK(icicle_free(device_leaves)); + } + + // Check valid tree of each device by comparing their roots + for (int i = 1; i < device_roots.size(); i++) { + std::vector& first_root = device_roots[0]; + std::vector& root = device_roots[i]; + ASSERT_EQ(first_root.size(), root.size()); + auto size = root.size(); + for (int j = 0; j < size; j++) { + ASSERT_EQ(first_root[j], root[j]) << "Different tree roots:\n" + << s_registered_devices[0] << " =\t0x" + << HashApiTest::voidPtrToHexString(first_root.data(), size) << "\n" + << s_registered_devices[i] << " =\t0x" + << HashApiTest::voidPtrToHexString(root.data(), size); + } + } +} + #endif // POSEIDON