diff --git a/chain/client/src/stateless_validation/partial_witness/partial_witness_actor.rs b/chain/client/src/stateless_validation/partial_witness/partial_witness_actor.rs index 34f9d0278f6..ac50090a697 100644 --- a/chain/client/src/stateless_validation/partial_witness/partial_witness_actor.rs +++ b/chain/client/src/stateless_validation/partial_witness/partial_witness_actor.rs @@ -20,7 +20,9 @@ use near_network::state_witness::{ use near_network::types::{NetworkRequests, PeerManagerAdapter, PeerManagerMessageRequest}; use near_parameters::RuntimeConfig; use near_performance_metrics_macros::perf; -use near_primitives::reed_solomon::{ReedSolomonEncoder, ReedSolomonEncoderCache}; +use near_primitives::reed_solomon::{ + IncrementalEncoder, ReedSolomonEncoder, ReedSolomonEncoderCache, ReedSolomonEncoderSerialize, +}; use near_primitives::sharding::ShardChunkHeader; use near_primitives::stateless_validation::contract_distribution::{ ChunkContractAccesses, ChunkContractDeploys, CodeBytes, CodeHash, ContractCodeRequest, @@ -249,7 +251,36 @@ impl PartialWitnessActor { } // Function to generate the parts of the state witness and return them as a tuple of chunk_validator and part. 
- fn generate_state_witness_parts( + fn generate_next_state_witness_part( + &mut self, + epoch_id: EpochId, + chunk_header: &ShardChunkHeader, + chunk_validator: &AccountId, + part_ord: usize, + signer: &ValidatorSigner, + incremental_encoder: &mut IncrementalEncoder, + ) -> Option<(AccountId, PartialEncodedStateWitness)> { + tracing::debug!( + target: "client", + chunk_hash=?chunk_header.chunk_hash(), + ?chunk_validator, + "generate_next_state_witness_part", + ); + incremental_encoder.encode_next().map(|part| { + let partial_witness = PartialEncodedStateWitness::new( + epoch_id, + chunk_header.clone(), + part_ord, + part.into_vec(), + incremental_encoder.encoded_length(), + signer, + ); + (chunk_validator.clone(), partial_witness) + }) + } + + // Function to generate the parts of the state witness and return them as a tuple of chunk_validator and part. + fn _generate_state_witness_parts( &mut self, epoch_id: EpochId, chunk_header: &ShardChunkHeader, @@ -338,32 +369,60 @@ impl PartialWitnessActor { let chunk_hash = chunk_header.chunk_hash(); let witness_size_in_bytes = witness_bytes.size_bytes(); - // Record time taken to encode the state witness parts. - let shard_id_label = chunk_header.shard_id().to_string(); - let encode_timer = metrics::PARTIAL_WITNESS_ENCODE_TIME - .with_label_values(&[shard_id_label.as_str()]) - .start_timer(); - let validator_witness_tuple = self.generate_state_witness_parts( - epoch_id, - chunk_header, - witness_bytes, - chunk_validators, - signer, - )?; - encode_timer.observe_duration(); - // Record the witness in order to match the incoming acks for measuring round-trip times. // See process_chunk_state_witness_ack for the handling of the ack messages. self.state_witness_tracker.record_witness_sent( chunk_hash, witness_size_in_bytes, - validator_witness_tuple.len(), + chunk_validators.len(), ); - // Send the parts to the corresponding chunk validator owners. 
- self.network_adapter.send(PeerManagerMessageRequest::NetworkRequests( - NetworkRequests::PartialEncodedStateWitness(validator_witness_tuple), - )); + // Record time taken to encode and send the state witness parts. + let shard_id_label = chunk_header.shard_id().to_string(); + let encode_timer = metrics::PARTIAL_WITNESS_ENCODE_TIME + .with_label_values(&[shard_id_label.as_str()]) + .start_timer(); + + let encoder = self.witness_encoders.entry(chunk_validators.len()); + if let Ok(mut incremental_encoder) = encoder.incremental_encoder(&witness_bytes) { + for (part_ord, chunk_validator) in chunk_validators.iter().enumerate() { + let validator_witness_tuple = self + .generate_next_state_witness_part( + epoch_id, + chunk_header, + chunk_validator, + part_ord, + signer, + &mut incremental_encoder, + ) + .unwrap(); + + // Send the part to the corresponding chunk validator owner. + self.network_adapter.send(PeerManagerMessageRequest::NetworkRequests( + NetworkRequests::PartialEncodedStateWitness(validator_witness_tuple), + )); + } + } else { + let bytes = witness_bytes.serialize_single_part().unwrap(); + let size = bytes.len(); + let partial_witness = PartialEncodedStateWitness::new( + epoch_id, + chunk_header.clone(), + 0, + bytes, + size, + signer, + ); + // Send the part to the corresponding chunk validator owner. + self.network_adapter.send(PeerManagerMessageRequest::NetworkRequests( + NetworkRequests::PartialEncodedStateWitness(( + chunk_validators[0].clone(), + partial_witness, + )), + )); + }; + encode_timer.observe_duration(); + Ok(()) } @@ -598,7 +657,7 @@ impl PartialWitnessActor { /// Sends the contract accesses to the same chunk validators /// (except for the chunk producers that track the same shard), - /// which will receive the state witness for the new chunk. + /// which will receive the state witness for the new chunk. 
fn send_contract_accesses_to_chunk_validators( &self, key: ChunkProductionKey, diff --git a/chain/client/src/test_utils/setup.rs b/chain/client/src/test_utils/setup.rs index 22ca09bde99..db20dd384fc 100644 --- a/chain/client/src/test_utils/setup.rs +++ b/chain/client/src/test_utils/setup.rs @@ -756,7 +756,16 @@ fn process_peer_manager_message_default( } } } - NetworkRequests::PartialEncodedStateWitness(partial_witnesses) => { + NetworkRequests::PartialEncodedStateWitness((account, partial_witness)) => { + for (i, name) in validators.iter().enumerate() { + if name == account { + connectors[i] + .partial_witness_sender + .send(PartialEncodedStateWitnessMessage(partial_witness.clone())); + } + } + } + NetworkRequests::PartialEncodedStateWitnesses(partial_witnesses) => { for (account, partial_witness) in partial_witnesses { for (i, name) in validators.iter().enumerate() { if name == account { diff --git a/chain/network/src/peer_manager/peer_manager_actor.rs b/chain/network/src/peer_manager/peer_manager_actor.rs index 1f4a00cf8d8..f198bf16689 100644 --- a/chain/network/src/peer_manager/peer_manager_actor.rs +++ b/chain/network/src/peer_manager/peer_manager_actor.rs @@ -1069,7 +1069,7 @@ impl PeerManagerActor { ); NetworkResponses::NoResponse } - NetworkRequests::PartialEncodedStateWitness(validator_witness_tuple) => { + NetworkRequests::PartialEncodedStateWitnesses(validator_witness_tuple) => { for (chunk_validator, partial_witness) in validator_witness_tuple { self.state.send_message_to_account( &self.clock, @@ -1079,6 +1079,15 @@ impl PeerManagerActor { } NetworkResponses::NoResponse } + NetworkRequests::PartialEncodedStateWitness(validator_witness_tuple) => { + let (chunk_validator, partial_witness) = validator_witness_tuple; + self.state.send_message_to_account( + &self.clock, + &chunk_validator, + RoutedMessageBody::PartialEncodedStateWitness(partial_witness), + ); + NetworkResponses::NoResponse + } NetworkRequests::PartialEncodedStateWitnessForward( 
chunk_validators, partial_witness, diff --git a/chain/network/src/test_loop.rs b/chain/network/src/test_loop.rs index 08e01385487..44d812109ea 100644 --- a/chain/network/src/test_loop.rs +++ b/chain/network/src/test_loop.rs @@ -347,8 +347,14 @@ fn network_message_to_partial_witness_handler( .send(ChunkStateWitnessAckMessage(witness_ack)); None } - - NetworkRequests::PartialEncodedStateWitness(validator_witness_tuple) => { + NetworkRequests::PartialEncodedStateWitness((target, partial_witness)) => { + shared_state + .senders_for_account(&target) + .partial_witness_sender + .send(PartialEncodedStateWitnessMessage(partial_witness)); + None + } + NetworkRequests::PartialEncodedStateWitnesses(validator_witness_tuple) => { for (target, partial_witness) in validator_witness_tuple.into_iter() { shared_state .senders_for_account(&target) diff --git a/chain/network/src/types.rs b/chain/network/src/types.rs index cd83e57fbd3..a5cc5664ec2 100644 --- a/chain/network/src/types.rs +++ b/chain/network/src/types.rs @@ -287,8 +287,10 @@ pub enum NetworkRequests { ChunkStateWitnessAck(AccountId, ChunkStateWitnessAck), /// Message for a chunk endorsement, sent by a chunk validator to the block producer. ChunkEndorsement(AccountId, ChunkEndorsement), + /// Message from chunk producer to a chunk validator to send a state witness part. + PartialEncodedStateWitness((AccountId, PartialEncodedStateWitness)), /// Message from chunk producer to set of chunk validators to send state witness part. - PartialEncodedStateWitness(Vec<(AccountId, PartialEncodedStateWitness)>), + PartialEncodedStateWitnesses(Vec<(AccountId, PartialEncodedStateWitness)>), /// Message from chunk validator to all other chunk validators to forward state witness part. 
PartialEncodedStateWitnessForward(Vec<AccountId>, PartialEncodedStateWitness), /// Requests an epoch sync diff --git a/core/primitives/src/reed_solomon.rs b/core/primitives/src/reed_solomon.rs index f6830a4f1d3..a31e3beeedc 100644 --- a/core/primitives/src/reed_solomon.rs +++ b/core/primitives/src/reed_solomon.rs @@ -1,6 +1,6 @@ use borsh::{BorshDeserialize, BorshSerialize}; use itertools::Itertools; -use reed_solomon_erasure::galois_8::ReedSolomon; +use reed_solomon_erasure::{galois_8::ReedSolomon, ShardByShard}; use std::collections::HashMap; use std::io::Error; use std::sync::Arc; @@ -150,6 +150,17 @@ impl ReedSolomonEncoder { } } } + + /// Creates an `IncrementalEncoder` using the same `ReedSolomon` instance + pub fn incremental_encoder<'a, T: ReedSolomonEncoderSerialize>( + &'a self, + data: &T, + ) -> Result<IncrementalEncoder<'a>, std::io::Error> { + let rs = self.rs.as_ref().ok_or_else(|| { + std::io::Error::new(std::io::ErrorKind::Other, "No encoder available") + })?; + Ok(IncrementalEncoder::new(rs, data)) + } } pub struct ReedSolomonEncoderCache { @@ -245,3 +256,116 @@ impl ReedSolomonPartsTracker { } } } + +/// Struct for incremental encoding +pub struct IncrementalEncoder<'a> { + sbs: ShardByShard<'a, reed_solomon_erasure::galois_8::Field>, + parts: Vec<Box<[u8]>>, + total_parts: usize, + data_parts: usize, + current_input_index: usize, + // Keep track of how many shards (data + parity) we've returned.
+ shards_returned: usize, + encoded_length: usize, +} + +impl<'a> IncrementalEncoder<'a> { + pub fn new<T: ReedSolomonEncoderSerialize>(rs: &'a ReedSolomon, data: &T) -> Self { + let mut bytes = borsh::to_vec(&data).unwrap(); + let encoded_length = bytes.len(); + + let data_parts = rs.data_shard_count(); + let parity_parts = rs.parity_shard_count(); + let total_parts = rs.total_shard_count(); + let part_length = reed_solomon_part_length(encoded_length, data_parts); + + bytes.resize(data_parts * part_length, 0); + + // Prepare data shards + let mut parts: Vec<Box<[u8]>> = bytes + .chunks_exact(part_length) + .map(|chunk| chunk.to_vec().into_boxed_slice()) + .collect(); + + // Prepare parity shards as zero + for _ in 0..parity_parts { + parts.push(vec![0; part_length].into_boxed_slice()); + } + + let sbs = ShardByShard::new(&rs); + + Self { + sbs, + parts, + total_parts, + data_parts, + current_input_index: 0, + shards_returned: 0, + encoded_length, + } + } + + /// If we still have data shards left, we call `encode` once per data shard. + /// After finishing the data shards, parity is ready and the shards get returned in order.
+ pub fn encode_next(&mut self) -> ReedSolomonPart { + if self.shards_returned >= self.total_parts { + return None; // All shards returned + } + + // Process the next shard + if self.current_input_index < self.data_parts { + let mut shard_refs: Vec<&mut [u8]> = + self.parts.iter_mut().map(|p| p.as_mut()).collect(); + // Process one data shard into parity + if let Err(e) = self.sbs.encode(&mut shard_refs) { + panic!("Encoding failed: {:?}", e); + } + self.current_input_index += 1; + } + + let part = self.parts[self.shards_returned].clone(); + self.shards_returned += 1; + + Some(part) + } + + pub fn encoded_length(&self) -> usize { + self.encoded_length + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::stateless_validation::state_witness::EncodedChunkStateWitness; + use rand::Rng; + + // Test that the parts generated by the incremental encoder are the same as the parts + // generated by the normal encoder. + #[test] + fn test_reed_solomon_incremental_generated_parts() { + const SIZE: usize = 5000; + let total_parts = 50; + let ratio = 0.6; + let mut rng = rand::thread_rng(); + let bytes: Vec<u8> = (0..SIZE).map(|_| rng.gen::<u8>()).collect(); + let data = EncodedChunkStateWitness::from(bytes.into_boxed_slice()); + let encoder = ReedSolomonEncoder::new(total_parts, ratio); + let (bytes_encoded1, n) = encoder.encode(&data); + let mut incremental_encoder = encoder.incremental_encoder(&data).unwrap(); + let mut bytes_encoded2 = Vec::new(); + for i in 0..total_parts { + let part = incremental_encoder.encode_next(); + assert_eq!(part, bytes_encoded1[i]); + bytes_encoded2.push(part); + } + + assert_eq!(n, incremental_encoder.encoded_length()); + assert_eq!(bytes_encoded1, bytes_encoded2); + + let (bytes_encoded3, n) = + encoder.rs.as_ref().map(|rs| reed_solomon_encode(rs, &data)).unwrap(); + assert_eq!(n, incremental_encoder.encoded_length()); + assert_eq!(bytes_encoded3, bytes_encoded2); + } +}