diff --git a/rust/types/src/data_record.rs b/rust/types/src/data_record.rs
index ec1f866ad6f..ca21d036bdd 100644
--- a/rust/types/src/data_record.rs
+++ b/rust/types/src/data_record.rs
@@ -10,6 +10,29 @@ pub struct DataRecord<'a> {
     pub document: Option<&'a str>,
 }
 
+#[derive(Debug, Clone)]
+pub struct OwnedDataRecord {
+    pub id: String,
+    pub embedding: Vec<f32>,
+    pub metadata: Option<Metadata>,
+    pub document: Option<String>,
+}
+
+impl<'a> From<&DataRecord<'a>> for OwnedDataRecord {
+    fn from(data_record: &DataRecord<'a>) -> Self {
+        let id = data_record.id.to_string();
+        let embedding = data_record.embedding.to_vec();
+        let metadata = data_record.metadata.clone();
+        let document = data_record.document.map(|doc| doc.to_string());
+        OwnedDataRecord {
+            id,
+            embedding,
+            metadata,
+            document,
+        }
+    }
+}
+
 impl DataRecord<'_> {
     pub fn get_size(&self) -> usize {
         let id_size = self.id.len();
@@ -28,4 +51,8 @@ impl DataRecord<'_> {
         };
         id_size + embedding_size + metadata_size + document_size
     }
+
+    pub fn to_owned(&self) -> OwnedDataRecord {
+        self.into()
+    }
 }
diff --git a/rust/worker/src/execution/operators/filter.rs b/rust/worker/src/execution/operators/filter.rs
index 635636b1137..5a3fdb53a7d 100644
--- a/rust/worker/src/execution/operators/filter.rs
+++ b/rust/worker/src/execution/operators/filter.rs
@@ -106,7 +106,7 @@ pub(crate) struct MetadataLogReader<'me> {
 }
 
 impl<'me> MetadataLogReader<'me> {
-    pub(crate) fn new(logs: &'me Chunk<MaterializedLogRecord<'me>>) -> Self {
+    pub(crate) fn new(logs: &'me Chunk<MaterializedLogRecord>) -> Self {
         let mut compact_metadata: HashMap<_, BTreeMap<&MetadataValue, RoaringBitmap>> =
             HashMap::new();
         let mut document = HashMap::new();
diff --git a/rust/worker/src/execution/operators/hnsw_knn.rs b/rust/worker/src/execution/operators/hnsw_knn.rs
index c5e57467d5d..5d8623a1f95 100644
--- a/rust/worker/src/execution/operators/hnsw_knn.rs
+++ b/rust/worker/src/execution/operators/hnsw_knn.rs
@@ -65,7 +65,7 @@ impl ChromaError for HnswKnnOperatorError {
 impl HnswKnnOperator {
     async fn get_disallowed_ids<'referred_data>(
         &self,
-        logs: Chunk<MaterializedLogRecord<'referred_data>>,
+        logs: Chunk<MaterializedLogRecord>,
         record_segment_reader: &RecordSegmentReader<'_>,
     ) -> Result<Vec<u32>, Box<dyn ChromaError>> {
         let mut disallowed_ids = Vec::new();
diff --git a/rust/worker/src/segment/distributed_hnsw_segment.rs b/rust/worker/src/segment/distributed_hnsw_segment.rs
index bff4a208b86..cb616393c8b 100644
--- a/rust/worker/src/segment/distributed_hnsw_segment.rs
+++ b/rust/worker/src/segment/distributed_hnsw_segment.rs
@@ -238,10 +238,10 @@ impl DistributedHNSWSegmentWriter {
     }
 }
 
-impl<'a> SegmentWriter<'a> for DistributedHNSWSegmentWriter {
+impl SegmentWriter for DistributedHNSWSegmentWriter {
     async fn apply_materialized_log_chunk(
         &self,
-        records: chroma_types::Chunk<MaterializedLogRecord<'a>>,
+        records: chroma_types::Chunk<MaterializedLogRecord>,
     ) -> Result<(), ApplyMaterializedLogError> {
         for (record, _) in records.iter() {
             match record.final_operation {
diff --git a/rust/worker/src/segment/metadata_segment.rs b/rust/worker/src/segment/metadata_segment.rs
index c5410479cf9..7cc280ce7c8 100644
--- a/rust/worker/src/segment/metadata_segment.rs
+++ b/rust/worker/src/segment/metadata_segment.rs
@@ -530,16 +530,20 @@ impl<'me> MetadataSegmentWriter<'me> {
     }
 }
 
-impl<'log_records> SegmentWriter<'log_records> for MetadataSegmentWriter<'_> {
+impl SegmentWriter for MetadataSegmentWriter<'_> {
     async fn apply_materialized_log_chunk(
         &self,
-        records: Chunk<MaterializedLogRecord<'log_records>>,
+        records: Chunk<MaterializedLogRecord>,
     ) -> Result<(), ApplyMaterializedLogError> {
         let mut count = 0u64;
         let full_text_writer_batch = records.iter().filter_map(|record| {
             let offset_id = record.0.offset_id;
-            let old_document = record.0.data_record.as_ref().and_then(|r| r.document);
-            let new_document = record.0.final_document;
+            let old_document = record
+                .0
+                .data_record
+                .as_ref()
+                .and_then(|r| r.document.as_deref());
+            let new_document = &record.0.final_document;
 
             if matches!(
                 record.0.final_operation,
diff --git a/rust/worker/src/segment/record_segment.rs b/rust/worker/src/segment/record_segment.rs
index 379cc918749..701774b5350 100644
--- a/rust/worker/src/segment/record_segment.rs
+++ b/rust/worker/src/segment/record_segment.rs
@@ -61,9 +61,9 @@ pub enum RecordSegmentWriterCreationError {
 }
 
 impl RecordSegmentWriter {
-    async fn construct_and_set_data_record<'a>(
+    async fn construct_and_set_data_record(
         &self,
-        mat_record: &MaterializedLogRecord<'a>,
+        mat_record: &MaterializedLogRecord,
         user_id: &str,
         offset_id: u32,
     ) -> Result<(), ApplyMaterializedLogError> {
@@ -337,10 +337,10 @@ impl ChromaError for ApplyMaterializedLogError {
     }
 }
 
-impl<'a> SegmentWriter<'a> for RecordSegmentWriter {
+impl SegmentWriter for RecordSegmentWriter {
     async fn apply_materialized_log_chunk(
         &self,
-        records: Chunk<MaterializedLogRecord<'a>>,
+        records: Chunk<MaterializedLogRecord>,
     ) -> Result<(), ApplyMaterializedLogError> {
         // The max new offset id introduced by materialized logs is initialized as zero
         // Since offset id should start from 1, we use this to indicate no new offset id
@@ -357,7 +357,11 @@ impl<'a> SegmentWriter<'a> for RecordSegmentWriter {
                         .user_id_to_id
                         .as_ref()
                         .unwrap()
-                        .set::<&str, u32>("", log_record.user_id.unwrap(), log_record.offset_id)
+                        .set::<&str, u32>(
+                            "",
+                            log_record.user_id.as_ref().unwrap(),
+                            log_record.offset_id,
+                        )
                         .await
                     {
                         Ok(()) => (),
@@ -370,7 +374,11 @@ impl<'a> SegmentWriter<'a> for RecordSegmentWriter {
                         .id_to_user_id
                         .as_ref()
                         .unwrap()
-                        .set::<u32, String>("", log_record.offset_id, log_record.user_id.unwrap().to_string())
+                        .set::<u32, String>(
+                            "",
+                            log_record.offset_id,
+                            log_record.user_id.clone().unwrap(),
+                        )
                         .await
                     {
                         Ok(()) => (),
@@ -382,7 +390,7 @@ impl<'a> SegmentWriter<'a> for RecordSegmentWriter {
                     match self
                         .construct_and_set_data_record(
                             log_record,
-                            log_record.user_id.unwrap(),
+                            log_record.user_id.as_ref().unwrap(),
                             log_record.offset_id,
                         )
                         .await
@@ -415,7 +423,7 @@ impl<'a> SegmentWriter<'a> for RecordSegmentWriter {
                     match self
                         .construct_and_set_data_record(
                             log_record,
-                            log_record.data_record.as_ref().unwrap().id,
+                            &log_record.data_record.as_ref().unwrap().id,
                             log_record.offset_id,
                         )
                         .await
@@ -432,7 +440,7 @@ impl<'a> SegmentWriter<'a> for RecordSegmentWriter {
                         .user_id_to_id
                         .as_ref()
                         .unwrap()
-                        .delete::<&str, u32>("", log_record.data_record.as_ref().unwrap().id)
+                        .delete::<&str, u32>("", &log_record.data_record.as_ref().unwrap().id)
                         .await
                     {
                         Ok(()) => (),
diff --git a/rust/worker/src/segment/types.rs b/rust/worker/src/segment/types.rs
index 11e1845741f..f822b8ce98a 100644
--- a/rust/worker/src/segment/types.rs
+++ b/rust/worker/src/segment/types.rs
@@ -3,7 +3,7 @@ use chroma_error::{ChromaError, ErrorCodes};
 use chroma_types::{
     Chunk, DataRecord, DeletedMetadata, LogRecord, MaterializedLogOperation, Metadata,
     MetadataDelta, MetadataValue, MetadataValueConversionError, Operation, OperationRecord,
-    UpdateMetadata, UpdateMetadataValue,
+    OwnedDataRecord, UpdateMetadata, UpdateMetadataValue,
 };
 use std::collections::{HashMap, HashSet};
 use std::sync::atomic::AtomicU32;
@@ -113,10 +113,10 @@ impl ChromaError for LogMaterializerError {
 }
 
 #[derive(Debug, Clone)]
-pub struct MaterializedLogRecord<'referred_data> {
+pub struct MaterializedLogRecord {
     // This is the data record read from the record segment for this id.
     // None if the record exists only in the log.
-    pub(crate) data_record: Option<DataRecord<'referred_data>>,
+    pub(crate) data_record: Option<OwnedDataRecord>,
     // If present in the record segment then it is the offset id
     // in the record segment at which the record was found.
     // If not present in the segment then it is the offset id
@@ -125,7 +125,7 @@ pub struct MaterializedLogRecord<'referred_data> {
     // Set only for the records that are being inserted for the first time
     // in the log since data_record will be None in such cases. For other
     // cases, just read from data record.
-    pub(crate) user_id: Option<&'referred_data str>,
+    pub(crate) user_id: Option<String>,
     // There can be several entries in the log for an id. This is the final
     // operation that needs to be done on it. For e.g.
     // If log has [Update, Update, Delete] then final operation is Delete.
@@ -149,15 +149,15 @@ pub struct MaterializedLogRecord<'referred_data> {
     // This is the final document obtained from the last non null operation.
     // E.g. if log has [Insert(str0), Update(str1), Update(str2), Update()] then this will contain
     // str2. None if final operation is Delete.
-    pub(crate) final_document: Option<&'referred_data str>,
+    pub(crate) final_document: Option<String>,
     // Similar to above, this is the final embedding obtained
     // from the last non null operation.
     // E.g. if log has [Insert(emb0), Update(emb1), Update(emb2), Update()]
     // then this will contain emb2. None if final operation is Delete.
-    pub(crate) final_embedding: Option<&'referred_data [f32]>,
+    pub(crate) final_embedding: Option<Vec<f32>>,
 }
 
-impl<'referred_data> MaterializedLogRecord<'referred_data> {
+impl MaterializedLogRecord {
     // Performs a deep copy of the document so only use it if really
     // needed. If you only need a reference then use merged_document_ref
     // defined below.
@@ -165,12 +165,12 @@
         if self.final_operation == MaterializedLogOperation::OverwriteExisting
             || self.final_operation == MaterializedLogOperation::AddNew
         {
-            return self.final_document.map(|doc| doc.to_string());
+            return self.final_document.clone();
         }
-        return match self.final_document {
-            Some(doc) => Some(doc.to_string()),
+        return match self.final_document.clone() {
+            Some(doc) => Some(doc),
             None => match self.data_record.as_ref() {
-                Some(data_record) => data_record.document.map(|doc| doc.to_string()),
+                Some(data_record) => data_record.document.clone(),
                 None => None,
             },
         };
@@ -180,18 +180,12 @@
         if self.final_operation == MaterializedLogOperation::OverwriteExisting
             || self.final_operation == MaterializedLogOperation::AddNew
         {
-            return match self.final_document {
-                Some(doc) => Some(doc),
-                None => None,
-            };
+            return self.final_document.as_deref();
         }
-        return match self.final_document {
+        return match &self.final_document {
             Some(doc) => Some(doc),
             None => match self.data_record.as_ref() {
-                Some(data_record) => match data_record.document {
-                    Some(doc) => Some(doc),
-                    None => None,
-                },
+                Some(data_record) => data_record.document.as_deref(),
                 None => None,
             },
         };
@@ -200,8 +194,8 @@
     // Performs a deep copy of the user id so only use it if really
     // needed. If you only need reference then use merged_user_id_ref below.
     pub(crate) fn merged_user_id(&self) -> String {
-        match self.user_id {
-            Some(id) => id.to_string(),
+        match &self.user_id {
+            Some(id) => id.clone(),
             None => match &self.data_record {
                 Some(data_record) => data_record.id.to_string(),
                 None => panic!("Expected at least one user id to be set"),
@@ -209,11 +203,12 @@
         }
     }
 
+    // todo: needed?
     pub(crate) fn merged_user_id_ref(&self) -> &str {
-        match self.user_id {
-            Some(id) => id,
+        match &self.user_id {
+            Some(id) => id.as_str(),
             None => match &self.data_record {
-                Some(data_record) => data_record.id,
+                Some(data_record) => &data_record.id,
                 None => panic!("Expected at least one user id to be set"),
             },
         }
@@ -249,7 +244,7 @@
         final_metadata
     }
 
-    pub(crate) fn metadata_delta(&'referred_data self) -> MetadataDelta<'referred_data> {
+    pub(crate) fn metadata_delta(&self) -> MetadataDelta<'_> {
         let mut metadata_delta = MetadataDelta::new();
         let mut base_metadata: HashMap<&str, &MetadataValue> = HashMap::new();
         if let Some(data_record) = &self.data_record {
@@ -321,29 +316,27 @@
         if self.final_operation == MaterializedLogOperation::OverwriteExisting
             || self.final_operation == MaterializedLogOperation::AddNew
         {
-            return match self.final_embedding {
+            return match &self.final_embedding {
                 Some(embed) => embed,
                 None => panic!("Expected source of embedding"),
             };
         }
-        return match self.final_embedding {
+        return match &self.final_embedding {
             Some(embed) => embed,
             None => match self.data_record.as_ref() {
-                Some(data_record) => data_record.embedding,
+                Some(data_record) => &data_record.embedding,
                 None => panic!("Expected at least one source of embedding"),
             },
         };
     }
 }
 
-impl<'referred_data> From<(DataRecord<'referred_data>, u32)>
-    for MaterializedLogRecord<'referred_data>
-{
+impl<'referred_data> From<(DataRecord<'referred_data>, u32)> for MaterializedLogRecord {
     fn from(data_record_info: (DataRecord<'referred_data>, u32)) -> Self {
         let data_record = data_record_info.0;
         let offset_id = data_record_info.1;
         Self {
-            data_record: Some(data_record),
+            data_record: Some(data_record.to_owned()),
             offset_id,
             user_id: None,
             final_operation: MaterializedLogOperation::Initial,
@@ -359,7 +352,7 @@ impl<'referred_data> From<(DataRecord<'referred_data>, u32)>
 // in the log (OperationRecord), offset id in storage where it will be stored (u32)
 // and user id (str).
 impl<'referred_data> TryFrom<(&'referred_data OperationRecord, u32, &'referred_data str)>
-    for MaterializedLogRecord<'referred_data>
+    for MaterializedLogRecord
 {
     type Error = LogMaterializerError;
@@ -387,9 +380,8 @@
             }
         };
 
-        let document = log_record.document.as_deref();
         let embedding = match &log_record.embedding {
-            Some(embedding) => Some(embedding.as_slice()),
+            Some(embedding) => Some(embedding.clone()),
             None => {
                 return Err(LogMaterializerError::EmbeddingMaterialization);
             }
         };
@@ -398,11 +390,11 @@ impl<'referred_data> TryFrom<(&'referred_d
         Ok(Self {
             data_record: None,
             offset_id,
-            user_id: Some(user_id),
+            user_id: Some(user_id.to_string()),
             final_operation: MaterializedLogOperation::AddNew,
             metadata_to_be_merged: merged_metadata,
             metadata_to_be_deleted: deleted_metadata,
-            final_document: document,
+            final_document: log_record.document.clone(),
             final_embedding: embedding,
         })
     }
@@ -417,7 +409,7 @@ pub async fn materialize_logs<'me>(
     // for materializing. Writers pass this value to the materializer
     // because they need to share this across all log partitions.
     next_offset_id: Option<Arc<AtomicU32>>,
-) -> Result<Chunk<MaterializedLogRecord<'me>>, LogMaterializerError> {
+) -> Result<Chunk<MaterializedLogRecord>, LogMaterializerError> {
     // Trace the total_len since len() iterates over the entire chunk
     // and we don't want to do that just to trace the length.
     tracing::info!("Total length of logs in materializer: {}", logs.total_len());
@@ -610,11 +602,11 @@
                         return Err(LogMaterializerError::MetadataMaterialization(e));
                     }
                 };
-                if let Some(doc) = log_record.record.document.as_ref() {
-                    record_from_map.final_document = Some(doc);
+                if let Some(doc) = &log_record.record.document {
+                    record_from_map.final_document = Some(doc.clone());
                 }
-                if let Some(emb) = log_record.record.embedding.as_ref() {
-                    record_from_map.final_embedding = Some(emb.as_slice());
+                if let Some(emb) = &log_record.record.embedding {
+                    record_from_map.final_embedding = Some(emb.clone());
                 }
                 match record_from_map.final_operation {
                     MaterializedLogOperation::Initial => {
@@ -673,11 +665,11 @@
                         return Err(LogMaterializerError::MetadataMaterialization(e));
                     }
                 };
-                if let Some(doc) = log_record.record.document.as_ref() {
-                    record_from_map.final_document = Some(doc);
+                if let Some(doc) = &log_record.record.document {
+                    record_from_map.final_document = Some(doc.clone());
                 }
-                if let Some(emb) = log_record.record.embedding.as_ref() {
-                    record_from_map.final_embedding = Some(emb.as_slice());
+                if let Some(emb) = &log_record.record.embedding {
+                    record_from_map.final_embedding = Some(emb.clone());
                 }
                 match record_from_map.final_operation {
                     MaterializedLogOperation::Initial => {
@@ -715,11 +707,11 @@
                        return Err(LogMaterializerError::MetadataMaterialization(e));
                     }
                 };
-                if let Some(doc) = log_record.record.document.as_ref() {
-                    record_from_map.final_document = Some(doc);
+                if let Some(doc) = &log_record.record.document {
+                    record_from_map.final_document = Some(doc.clone());
                 }
-                if let Some(emb) = log_record.record.embedding.as_ref() {
-                    record_from_map.final_embedding = Some(emb.as_slice());
+                if let Some(emb) = &log_record.record.embedding {
+                    record_from_map.final_embedding = Some(emb.clone());
                 }
                 // This record is not present on storage yet hence final operation is
                 // AddNew and not UpdateExisting.
@@ -763,10 +755,10 @@ pub async fn materialize_logs<'me>(
 
 // This needs to be public for testing
 #[allow(async_fn_in_trait)]
-pub trait SegmentWriter<'a> {
+pub trait SegmentWriter {
     async fn apply_materialized_log_chunk(
         &self,
-        records: Chunk<MaterializedLogRecord<'a>>,
+        records: Chunk<MaterializedLogRecord>,
     ) -> Result<(), ApplyMaterializedLogError>;
     async fn commit(self) -> Result<impl SegmentFlusher, Box<dyn ChromaError>>;
 }
@@ -1836,10 +1828,10 @@ mod tests {
                 // Embedding 3.
                 if log.user_id.is_some() {
                     id3_found += 1;
-                    assert_eq!("embedding_id_3", log.user_id.unwrap());
+                    assert_eq!("embedding_id_3", log.user_id.clone().unwrap());
                     assert!(log.data_record.is_none());
-                    assert_eq!("doc3", log.final_document.unwrap());
-                    assert_eq!(vec![7.0, 8.0, 9.0], log.final_embedding.unwrap());
+                    assert_eq!("doc3", log.final_document.clone().unwrap());
+                    assert_eq!(vec![7.0, 8.0, 9.0], log.final_embedding.clone().unwrap());
                     assert_eq!(3, log.offset_id);
                     assert_eq!(MaterializedLogOperation::AddNew, log.final_operation);
                     let mut hello_found = 0;
@@ -1895,7 +1887,10 @@ mod tests {
                     assert_eq!(hello_found, 1);
                     assert_eq!(hello_again_found, 1);
                     assert!(log.data_record.is_some());
-                    assert_eq!(log.data_record.as_ref().unwrap().document, Some("doc1"));
+                    assert_eq!(
+                        log.data_record.as_ref().unwrap().document,
+                        Some("doc1".to_string())
+                    );
                     assert_eq!(
                         log.data_record.as_ref().unwrap().embedding,
                         vec![1.0, 2.0, 3.0].as_slice()
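
Note: taken together, this patch trades borrowed views for deep copies. MaterializedLogRecord now stores an OwnedDataRecord plus owned String/Vec<f32> fields instead of references tied to the record segment's lifetime, which is why every SegmentWriter impl drops its lifetime parameter. Below is a minimal, self-contained sketch of the borrowed-to-owned conversion pattern the patch applies; the Metadata alias and the main function are simplifications for illustration only (the real chroma_types::Metadata maps to MetadataValue), not code from this change.

use std::collections::HashMap;

// Simplified stand-in for chroma_types::Metadata, for illustration only.
type Metadata = HashMap<String, String>;

// Borrowed view over bytes owned by the record segment (the old shape).
pub struct DataRecord<'a> {
    pub id: &'a str,
    pub embedding: &'a [f32],
    pub metadata: Option<Metadata>,
    pub document: Option<&'a str>,
}

// Owned counterpart, mirroring the OwnedDataRecord added in rust/types/src/data_record.rs.
#[derive(Debug, Clone)]
pub struct OwnedDataRecord {
    pub id: String,
    pub embedding: Vec<f32>,
    pub metadata: Option<Metadata>,
    pub document: Option<String>,
}

impl<'a> From<&DataRecord<'a>> for OwnedDataRecord {
    fn from(r: &DataRecord<'a>) -> Self {
        // Deep-copy every borrowed field into owned storage.
        OwnedDataRecord {
            id: r.id.to_string(),
            embedding: r.embedding.to_vec(),
            metadata: r.metadata.clone(),
            document: r.document.map(|d| d.to_string()),
        }
    }
}

fn main() {
    let embedding = [1.0_f32, 2.0, 3.0];
    let borrowed = DataRecord {
        id: "embedding_id_1",
        embedding: &embedding,
        metadata: None,
        document: Some("doc1"),
    };
    // The conversion copies every field, so the owned record can outlive
    // the buffers the borrowed DataRecord pointed into.
    let owned: OwnedDataRecord = (&borrowed).into();
    drop(borrowed); // the owned copy no longer depends on the borrow
    assert_eq!(owned.document.as_deref(), Some("doc1"));
    println!("{:?}", owned);
}

The cost of this design choice is visible in the materializer hunks above: every document and embedding is now cloned out of the log (doc.clone(), emb.clone()) rather than referenced, buying simpler lifetimes at the price of extra allocation per record.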