diff --git a/graphbolt/src/concurrent_id_hash_map.cc b/graphbolt/src/concurrent_id_hash_map.cc index e9c81bb67f2e..c43d06320bf9 100644 --- a/graphbolt/src/concurrent_id_hash_map.cc +++ b/graphbolt/src/concurrent_id_hash_map.cc @@ -65,10 +65,10 @@ torch::Tensor ConcurrentIdHashMap::Init( // This code block is to fill the ids into hash_map_. auto unique_ids = torch::empty_like(ids); IdType* unique_ids_data = unique_ids.data_ptr(); - // Fill in the first `num_seeds` ids. - torch::parallel_for(0, num_seeds, kGrainSize, [&](int64_t s, int64_t e) { + // Insert all ids into the hash map. + torch::parallel_for(0, num_ids, kGrainSize, [&](int64_t s, int64_t e) { for (int64_t i = s; i < e; i++) { - InsertAndSet(ids_data[i], static_cast(i)); + InsertAndSetMin(ids_data[i], static_cast(i)); } }); // Place the first `num_seeds` ids. @@ -82,13 +82,16 @@ torch::Tensor ConcurrentIdHashMap::Init( const int64_t num_threads = torch::get_num_threads(); std::vector block_offset(num_threads + 1, 0); - // Insert all elements in this loop. + + // Count the valid numbers in each thread. 
torch::parallel_for( num_seeds, num_ids, kGrainSize, [&](int64_t s, int64_t e) { size_t count = 0; for (int64_t i = s; i < e; i++) { - valid[i] = Insert(ids_data[i]); - count += valid[i]; + if (MapId(ids_data[i]) == i) { + count++; + valid[i] = 1; + } } auto thread_id = torch::get_thread_num(); block_offset[thread_id + 1] = count; @@ -199,6 +202,27 @@ inline void ConcurrentIdHashMap::InsertAndSet(IdType id, IdType value) { hash_map_.data_ptr()[getValueIndex(pos)] = value; } +template +void ConcurrentIdHashMap::InsertAndSetMin(IdType id, IdType value) { + IdType pos = (id & mask_), delta = 1; + IdType* hash_map_data = hash_map_.data_ptr(); + InsertState state = AttemptInsertAt(pos, id); + while (state == InsertState::OCCUPIED) { + Next(&pos, &delta); + state = AttemptInsertAt(pos, id); + } + + IdType empty_key = static_cast(kEmptyKey); + IdType val_pos = getValueIndex(pos); + IdType old_val = empty_key; + while (old_val == empty_key || old_val > value) { + IdType replaced_val = + CompareAndSwap(&(hash_map_data[val_pos]), old_val, value); + if (old_val == replaced_val) break; + old_val = replaced_val; + } +} + template inline typename ConcurrentIdHashMap::InsertState ConcurrentIdHashMap::AttemptInsertAt(int64_t pos, IdType key) { diff --git a/graphbolt/src/concurrent_id_hash_map.h b/graphbolt/src/concurrent_id_hash_map.h index fe3b306b537d..bb9d54d3cb83 100644 --- a/graphbolt/src/concurrent_id_hash_map.h +++ b/graphbolt/src/concurrent_id_hash_map.h @@ -158,6 +158,16 @@ class ConcurrentIdHashMap { */ inline void InsertAndSet(IdType key, IdType value); + /** + * @brief Insert a key into the hash map. If the key already exists, keep the + * smaller of its current value and `value`. + * + * @param id The key to be inserted. + * @param value The value to be set for `id`. + * + */ + inline void InsertAndSetMin(IdType id, IdType value); + /** * @brief Attempt to insert the key into the hash map at the given position. *