Skip to content

Commit

Permalink
Fixed bug that caused poseidon merkle tree to fail on rust (And added…
Browse files Browse the repository at this point in the history
… the same test to cpp test_hash_api)

Signed-off-by: Koren-Brand <[email protected]>
  • Loading branch information
Koren-Brand committed Nov 11, 2024
1 parent b76fda5 commit 029df1d
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 1 deletion.
7 changes: 6 additions & 1 deletion icicle/backend/cpu/src/hash/cpu_merkle_tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,12 @@ namespace icicle {
// This is not the root layer (root)
LayerDB& next_layer = m_layers[next_layer_idx];
const uint64_t next_input_size = next_layer.m_hash.default_input_chunk_size();
const uint64_t next_segment_idx = cur_segment_idx * cur_layer.m_hash.output_size() / next_input_size;
// Ensure next segment does not overflow due to a <NOF_OPERATIONS_PER_TASK+1> sized batch by comparing it to the
// max possible segment index (And taking the smaller one)
const uint64_t max_segment_idx =
(next_layer.m_nof_hashes_2_execute - next_layer.m_last_hash_config.batch) / NOF_OPERATIONS_PER_TASK;
const uint64_t next_segment_idx =
std::min(cur_segment_idx * cur_layer.m_hash.output_size() / next_input_size, max_segment_idx);
const uint64_t next_segment_id = next_segment_idx ^ (next_layer_idx << 56);

// If next_segment does not appear in m_map_segment_id_2_inputs, then add it
Expand Down
83 changes: 83 additions & 0 deletions icicle/tests/test_hash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1152,4 +1152,87 @@ TEST_F(HashApiTest, poseidon3_batch)
ASSERT_EQ(0, memcmp(output_cpu.get(), output_cuda.get(), config.batch * sizeof(scalar_t)));
}

TEST_F(HashApiTest, poseidon_tree)
{
const uint64_t t = 9;
const uint64_t nof_layers = 4;
uint64_t nof_leaves = 1;
for (int i = 0; i < nof_layers; i++) {
nof_leaves *= t;
}
auto leaves = std::make_unique<scalar_t[]>(nof_leaves);
const uint64_t leaf_size = sizeof(scalar_t);
const uint64_t total_input_size = nof_leaves * leaf_size;
ICICLE_LOG_INFO << total_input_size;
randomize<scalar_t>(leaves.get(), nof_leaves);

std::vector<std::vector<std::byte>> device_roots(s_registered_devices.size());
int device_roots_idx = 0;

for (int i = 0; i < s_registered_devices.size(); i++) {
const auto& device = s_registered_devices[i];
ICICLE_LOG_INFO << "MerkleTreeDeviceBig on device=" << device;
ICICLE_CHECK(icicle_set_device(device));

// Create a Keccak256 hasher with an arity of 2: every 64B -> 32B
auto layer_hash = Poseidon::create<scalar_t>(t);
// Create a vector of `Hash` objects, all initialized with the same `layer_hash`
std::vector<Hash> hashes(nof_layers, layer_hash);

// copy leaves to device
std::byte* device_leaves = nullptr;
ICICLE_CHECK(icicle_malloc((void**)&device_leaves, total_input_size));
ICICLE_CHECK(icicle_copy(device_leaves, leaves.get(), total_input_size));

auto config = default_merkle_tree_config();
config.is_leaves_on_device = true;
auto prover_tree = MerkleTree::create(hashes, leaf_size);
auto verifier_tree = MerkleTree::create(hashes, leaf_size);

// build tree
START_TIMER(MerkleTree_build)
ICICLE_CHECK(prover_tree.build(device_leaves, total_input_size, config));
END_TIMER(MerkleTree_build, "Merkle Tree large", true)

auto [root, root_size] = prover_tree.get_merkle_root();
deep_copy_byte_array_to_vec(root, root_size, device_roots[i]);

// proof leaves and verify
for (int test_leaf_idx = 0; test_leaf_idx < 5; test_leaf_idx++) {
const int leaf_idx = rand() % nof_leaves;

// test non-pruned path
MerkleProof merkle_proof{};
bool verification_valid = false;
ICICLE_CHECK(prover_tree.get_merkle_proof(
device_leaves, total_input_size, leaf_idx, false /*=pruned*/, config, merkle_proof));
ICICLE_CHECK(verifier_tree.verify(merkle_proof, verification_valid));
ASSERT_TRUE(verification_valid);

// test pruned path
verification_valid = false;
ICICLE_CHECK(prover_tree.get_merkle_proof(
device_leaves, total_input_size, leaf_idx, true /*=pruned*/, config, merkle_proof));
ICICLE_CHECK(verifier_tree.verify(merkle_proof, verification_valid));
ASSERT_TRUE(verification_valid);
}
ICICLE_CHECK(icicle_free(device_leaves));
}

// Check valid tree of each device by comparing their roots
for (int i = 1; i < device_roots.size(); i++) {
std::vector<std::byte>& first_root = device_roots[0];
std::vector<std::byte>& root = device_roots[i];
ASSERT_EQ(first_root.size(), root.size());
auto size = root.size();
for (int j = 0; j < size; j++) {
ASSERT_EQ(first_root[j], root[j]) << "Different tree roots:\n"
<< s_registered_devices[0] << " =\t0x"
<< HashApiTest::voidPtrToHexString(first_root.data(), size) << "\n"
<< s_registered_devices[i] << " =\t0x"
<< HashApiTest::voidPtrToHexString(root.data(), size);
}
}
}

#endif // POSEIDON

0 comments on commit 029df1d

Please sign in to comment.