From c41401d8b83a78432ae04e560053a83b46686d9d Mon Sep 17 00:00:00 2001 From: Takeru Ohta Date: Fri, 1 Dec 2023 08:16:06 +0900 Subject: [PATCH] Update Node::{insert(), get()} to handle String key correctly --- src/map.rs | 2 +- src/node.rs | 15 +++++++-------- src/tree.rs | 2 +- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/map.rs b/src/map.rs index 9c5445d..0c093be 100644 --- a/src/map.rs +++ b/src/map.rs @@ -145,7 +145,7 @@ impl GenericPatriciaMap { /// assert_eq!(map.get("bar"), None); /// ``` pub fn get>(&self, key: Q) -> Option<&V> { - self.tree.get(key.as_ref().as_bytes()) + self.tree.get(key.as_ref()) } /// Returns a mutable reference to the value corresponding to the key. diff --git a/src/node.rs b/src/node.rs index 9c4475d..1b04447 100644 --- a/src/node.rs +++ b/src/node.rs @@ -405,21 +405,21 @@ impl Node { } } - pub(crate) fn get(&self, key: &[u8]) -> Option<&V> { - let common_prefix_len = self.skip_common_prefix(key); - let next = &key[common_prefix_len..]; + pub(crate) fn get(&self, key: &K) -> Option<&V> { + let (next, common_prefix_len) = key.strip_common_prefix_and_len(self.label()); if common_prefix_len == self.label().len() { - if next.is_empty() { + if next.as_bytes().is_empty() { self.value() } else { self.child().and_then(|child| child.get(next)) } - } else if common_prefix_len == 0 && self.label().first() <= key.first() { + } else if common_prefix_len == 0 && key.cmp_first_item(self.label()).is_ge() { self.sibling().and_then(|sibling| sibling.get(next)) } else { None } } + pub(crate) fn get_mut(&mut self, key: &[u8]) -> Option<&mut V> { let common_prefix_len = self.skip_common_prefix(key); let next = &key[common_prefix_len..]; @@ -583,7 +583,7 @@ impl Node { } } pub(crate) fn insert(&mut self, key: &K, value: V) -> Option { - if self.label().first() > key.as_bytes().first() { + if key.cmp_first_item(self.label()).is_lt() { let this = Node { ptr: self.ptr, _value: PhantomData, @@ -594,8 +594,7 @@ impl Node { return None; } - let next = key.strip_common_prefix(self.label()); - let common_prefix_len = key.as_bytes().len() - next.as_bytes().len(); + let (next, common_prefix_len) = key.strip_common_prefix_and_len(self.label()); let is_label_matched = common_prefix_len == self.label().len(); if next.as_bytes().is_empty() { if is_label_matched { diff --git a/src/tree.rs b/src/tree.rs index b754fb2..bf06f94 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -32,7 +32,7 @@ impl PatriciaTree { None } } - pub fn get(&self, key: &[u8]) -> Option<&V> { + pub fn get(&self, key: &K) -> Option<&V> { self.root.get(key) } pub fn get_mut(&mut self, key: &[u8]) -> Option<&mut V> {