diff --git a/Cargo.toml b/Cargo.toml
index 6eb0bbf..4922458 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "fastbloom_rs"
-version = "0.5.0"
+version = "0.5.1"
 edition = "2021"
 authors = ["Yan Kun "]
 description = "Some fast bloom filter implemented by Rust for Python and Rust! 10x faster than pybloom!"
diff --git a/fastbloom-rs/Cargo.toml b/fastbloom-rs/Cargo.toml
index 353d17a..4a23547 100644
--- a/fastbloom-rs/Cargo.toml
+++ b/fastbloom-rs/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "fastbloom-rs"
-version = "0.5.0"
+version = "0.5.1"
 edition = "2021"
 authors = ["Yan Kun "]
 description = "Some fast bloom filter implemented by Rust for Python and Rust!"
diff --git a/fastbloom-rs/src/bloom.rs b/fastbloom-rs/src/bloom.rs
index 1871e4d..8d39f49 100644
--- a/fastbloom-rs/src/bloom.rs
+++ b/fastbloom-rs/src/bloom.rs
@@ -1,4 +1,5 @@
 use std::clone;
+use std::cmp::min;
 use std::ptr::slice_from_raw_parts;
 
 use fastmurmur3::murmur3_x64_128;
@@ -459,6 +460,31 @@ from_array!(from_u16_array, u16, 4);
 from_array!(from_u32_array, u32, 8);
 from_array!(from_u64_array, u64, 16);
 
+impl CountingBloomFilter {
+    /// Estimate how many times `element` was added: the smallest counter among its hash slots (0 if any slot is empty).
+    /// See: https://github.com/yankun1992/fastbloom/issues/3
+    pub fn estimate_count(&self, element: &[u8]) -> usize {
+        let m = self.config.size;
+        let hash1 = xxh3_64_with_seed(element, 0) % m;
+        let hash2 = xxh3_64_with_seed(element, 32) % m;
+
+        let mut res = self.counting_vec.get(hash1 as usize);
+        if res == 0 { return 0; }
+
+        for i in 1..self.config.hashes as u64 {
+            let mo = ((hash1 + i * hash2) % m) as usize;
+            let count = self.counting_vec.get(mo);
+            if count == 0 { return 0; } else { res = min(count, res) }
+        }
+
+        res
+    }
+
+    /// Get the raw value stored in the counter slot at `index`.
+    pub fn counter_at(&self, index: u64) -> usize {
+        self.counting_vec.get(index as usize)
+    }
+}
 
 impl Membership for CountingBloomFilter {
     fn add(&mut self, element: &[u8]) {
@@ -785,4 +811,26 @@ fn counting_bloom_hash_indices_test() {
     bloom.remove(b"hello");
     assert_eq!(bloom.contains(b"hello"), false);
     assert_eq!(bloom.contains_hash_indices(&bloom.get_hash_indices(b"hello")), false);
-}
\ No newline at end of file
+}
+
+#[test]
+fn counting_bloom_estimate_count() {
+    let mut builder =
+        FilterBuilder::new(10_000, 0.01);
+    let mut bloom = builder.build_counting_bloom_filter();
+
+    bloom.add(b"hello");
+    bloom.add(b"world");
+
+    assert_eq!(bloom.estimate_count(b"hello"), 1);
+    let indices = bloom.get_hash_indices(b"hello");
+
+    for index in indices {
+        assert_eq!(bloom.counter_at(index), 1)
+    }
+
+    assert_eq!(bloom.estimate_count(b"world"), 1);
+    for index in bloom.get_hash_indices(b"world") {
+        assert!(bloom.counter_at(index) <= 2);
+    }
+}
diff --git a/fastbloom_rs/fastbloom_rs.pyi b/fastbloom_rs/fastbloom_rs.pyi
index ba73d56..3ec98da 100644
--- a/fastbloom_rs/fastbloom_rs.pyi
+++ b/fastbloom_rs/fastbloom_rs.pyi
@@ -172,6 +172,18 @@ class PyCountingBloomFilter(object):
     def get_hash_indices_str(self, element: str) -> Sequence[int]:
         ...
 
+    def estimate_count(self, element: bytes) -> int:
+        ...
+
+    def estimate_count_int(self, element: int) -> int:
+        ...
+
+    def estimate_count_str(self, element: str) -> int:
+        ...
+
+    def counter_at(self, index: int) -> int:
+        ...
+
     @staticmethod
     def from_bytes(array: bytes, hashes: int, enable_repeat_insert: bool) -> PyCountingBloomFilter:
         ...
diff --git a/fastbloom_rs/filter.py b/fastbloom_rs/filter.py
index 15e1ebf..55cd14f 100644
--- a/fastbloom_rs/filter.py
+++ b/fastbloom_rs/filter.py
@@ -494,6 +494,33 @@ def get_hash_indices(self, element: Union[str, int, bytes]) -> Sequence[int]:
         else:
             return self._py_counting_bloom.get_hash_indices_str(str(element))
 
+    def estimate_count(self, element: Union[str, int, bytes]) -> int:
+        """
+        Get the estimate count for element in this counting bloom filter.
+        See: https://github.com/yankun1992/fastbloom/issues/3
+
+        :param element: value to look up; int, str and bytes are passed through as-is, any other type is converted with str().
+        :return: estimated number of times the element was added; 0 when any of its counters is empty.
+        """
+        if isinstance(element, int):
+            return self._py_counting_bloom.estimate_count_int(element)
+        elif isinstance(element, str):
+            return self._py_counting_bloom.estimate_count_str(element)
+        elif isinstance(element, bytes):
+            return self._py_counting_bloom.estimate_count(element)
+        else:
+            return self._py_counting_bloom.estimate_count_str(str(element))
+
+    def counter_at(self, index: int) -> int:
+        """
+        Get the underlying counter at index.
+
+        :param index: zero-based index of the counter slot; must not be negative.
+        :return: current value of the counter stored in that slot.
+        """
+        assert index >= 0, "index must be non-negative"  # slot 0 is valid: hash indices are taken modulo the filter size
+        return self._py_counting_bloom.counter_at(index)
+
     def config(self) -> FilterBuilder:
         """
         Returns the configuration/builder of the Bloom filter.
diff --git a/py_tests/test_counting_bloom_filter.py b/py_tests/test_counting_bloom_filter.py
index a309cfa..2b1c835 100644
--- a/py_tests/test_counting_bloom_filter.py
+++ b/py_tests/test_counting_bloom_filter.py
@@ -91,3 +91,24 @@ def test_hash_indices():
     assert not cbf2.contains_hash_indices(cbf2.get_hash_indices(31))
     assert not cbf2.contains_hash_indices(cbf2.get_hash_indices('world'))
     assert cbf2.contains_hash_indices(cbf2.get_hash_indices('Yan Kun'))
+
+
+def test_estimate_count():
+    builder = FilterBuilder(100_000, 0.01)
+    # enable repeat insert
+    builder.enable_repeat_insert(True)
+    cbf = builder.build_counting_bloom_filter()  # type: CountingBloomFilter
+
+    cbf.add(b'hello')
+
+    assert cbf.estimate_count(b'hello') == 1
+
+    for index in cbf.get_hash_indices(b'hello'):
+        assert cbf.counter_at(index) == 1
+
+    cbf.add(b'world')
+    for index in cbf.get_hash_indices(b'world'):
+        assert cbf.counter_at(index) <= 2
+
+    cbf.add(b'hello')
+    assert cbf.estimate_count(b'hello') == 2
diff --git a/src/pybloom.rs b/src/pybloom.rs
index 7d58053..b8c4098 100644
--- a/src/pybloom.rs
+++ b/src/pybloom.rs
@@ -245,6 +245,27 @@ impl PyCountingBloomFilter {
         Ok(self.counting_bloom_filter.get_hash_indices(bts.as_bytes()))
     }
 
+    /// Estimate how many times the integer `element` was added (hashed from its little-endian bytes).
+    pub fn estimate_count_int(&self, element: i64) -> PyResult<u32> {
+        Ok(self.counting_bloom_filter.estimate_count(&i64::to_le_bytes(element)) as u32)
+    }
+
+    /// Estimate how many times the string `element` was added (hashed from its UTF-8 bytes).
+    pub fn estimate_count_str(&self, element: &str) -> PyResult<u32> {
+        Ok(self.counting_bloom_filter.estimate_count(element.as_bytes()) as u32)
+    }
+
+    /// Estimate how many times the bytes `element` was added.
+    pub fn estimate_count(&self, element: &PyBytes) -> PyResult<u32> {
+        Ok(self.counting_bloom_filter.estimate_count(element.as_bytes()) as u32)
+    }
+
+    /// Get the raw counter value stored at slot `index`.
+    /// NOTE(review): a negative `index` wraps around in the `as u64` cast; callers are expected to pass `index >= 0`.
+    pub fn counter_at(&self, index: i64) -> PyResult<u64> {
+        Ok(self.counting_bloom_filter.counter_at(index as u64) as u64)
+    }
+
     #[staticmethod]
     pub fn from_bytes(array: &[u8], hashes: u32, enable_repeat_insert: bool) -> PyResult {
         Ok(PyCountingBloomFilter {