Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] feat: incremental partial witness generation #12631

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ use near_network::state_witness::{
use near_network::types::{NetworkRequests, PeerManagerAdapter, PeerManagerMessageRequest};
use near_parameters::RuntimeConfig;
use near_performance_metrics_macros::perf;
use near_primitives::reed_solomon::{ReedSolomonEncoder, ReedSolomonEncoderCache};
use near_primitives::reed_solomon::{
IncrementalEncoder, ReedSolomonEncoder, ReedSolomonEncoderCache, ReedSolomonEncoderSerialize,
};
use near_primitives::sharding::ShardChunkHeader;
use near_primitives::stateless_validation::contract_distribution::{
ChunkContractAccesses, ChunkContractDeploys, CodeBytes, CodeHash, ContractCodeRequest,
Expand Down Expand Up @@ -249,7 +251,36 @@ impl PartialWitnessActor {
}

// Function to generate the parts of the state witness and return them as a tuple of chunk_validator and part.
fn generate_state_witness_parts(
/// Produces the next encoded part of the state witness for one chunk validator.
///
/// Pulls a single part out of `incremental_encoder` and wraps it into a
/// `PartialEncodedStateWitness` with ordinal `part_ord`, built from `epoch_id`,
/// a clone of `chunk_header`, the encoder's `encoded_length()` and `signer`.
///
/// Returns `Some((chunk_validator, partial_witness))`, or `None` when the
/// encoder yields no further part.
fn generate_next_state_witness_part(
    &mut self,
    epoch_id: EpochId,
    chunk_header: &ShardChunkHeader,
    chunk_validator: &AccountId,
    part_ord: usize,
    signer: &ValidatorSigner,
    incremental_encoder: &mut IncrementalEncoder,
) -> Option<(AccountId, PartialEncodedStateWitness)> {
    tracing::debug!(
        target: "client",
        chunk_hash=?chunk_header.chunk_hash(),
        ?chunk_validator,
        "generate_next_state_witness_part",
    );

    // Bail out with `None` as soon as the encoder has nothing left to return.
    let part = incremental_encoder.encode_next()?;
    let encoded_length = incremental_encoder.encoded_length();

    let partial_witness = PartialEncodedStateWitness::new(
        epoch_id,
        chunk_header.clone(),
        part_ord,
        part.into_vec(),
        encoded_length,
        signer,
    );
    Some((chunk_validator.clone(), partial_witness))
}

// Function to generate the parts of the state witness and return them as a tuple of chunk_validator and part.
fn _generate_state_witness_parts(
&mut self,
epoch_id: EpochId,
chunk_header: &ShardChunkHeader,
Expand Down Expand Up @@ -338,32 +369,60 @@ impl PartialWitnessActor {
let chunk_hash = chunk_header.chunk_hash();
let witness_size_in_bytes = witness_bytes.size_bytes();

// Record time taken to encode the state witness parts.
let shard_id_label = chunk_header.shard_id().to_string();
let encode_timer = metrics::PARTIAL_WITNESS_ENCODE_TIME
.with_label_values(&[shard_id_label.as_str()])
.start_timer();
let validator_witness_tuple = self.generate_state_witness_parts(
epoch_id,
chunk_header,
witness_bytes,
chunk_validators,
signer,
)?;
encode_timer.observe_duration();

// Record the witness in order to match the incoming acks for measuring round-trip times.
// See process_chunk_state_witness_ack for the handling of the ack messages.
self.state_witness_tracker.record_witness_sent(
chunk_hash,
witness_size_in_bytes,
validator_witness_tuple.len(),
chunk_validators.len(),
);

// Send the parts to the corresponding chunk validator owners.
self.network_adapter.send(PeerManagerMessageRequest::NetworkRequests(
NetworkRequests::PartialEncodedStateWitness(validator_witness_tuple),
));
// Record time taken to encode and send the state witness parts.
let shard_id_label = chunk_header.shard_id().to_string();
let encode_timer = metrics::PARTIAL_WITNESS_ENCODE_TIME
.with_label_values(&[shard_id_label.as_str()])
.start_timer();

let encoder = self.witness_encoders.entry(chunk_validators.len());
if let Ok(mut incremental_encoder) = encoder.incremental_encoder(&witness_bytes) {
for (part_ord, chunk_validator) in chunk_validators.iter().enumerate() {
let validator_witness_tuple = self
.generate_next_state_witness_part(
epoch_id,
chunk_header,
chunk_validator,
part_ord,
signer,
&mut incremental_encoder,
)
.unwrap();

// Send the part to the corresponding chunk validator owner.
self.network_adapter.send(PeerManagerMessageRequest::NetworkRequests(
NetworkRequests::PartialEncodedStateWitness(validator_witness_tuple),
));
}
} else {
let bytes = witness_bytes.serialize_single_part().unwrap();
let size = bytes.len();
let partial_witness = PartialEncodedStateWitness::new(
epoch_id,
chunk_header.clone(),
0,
bytes,
size,
signer,
);
// Send the part to the corresponding chunk validator owner.
self.network_adapter.send(PeerManagerMessageRequest::NetworkRequests(
NetworkRequests::PartialEncodedStateWitness((
chunk_validators[0].clone(),
partial_witness,
)),
));
};
encode_timer.observe_duration();

Ok(())
}

Expand Down Expand Up @@ -598,7 +657,7 @@ impl PartialWitnessActor {

/// Sends the contract accesses to the same chunk validators
/// (except for the chunk producers that track the same shard),
/// which will receive the state witness for the new chunk.
/// which will receive the state witness for the new chunk.
fn send_contract_accesses_to_chunk_validators(
&self,
key: ChunkProductionKey,
Expand Down
11 changes: 10 additions & 1 deletion chain/client/src/test_utils/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,16 @@ fn process_peer_manager_message_default(
}
}
}
NetworkRequests::PartialEncodedStateWitness(partial_witnesses) => {
NetworkRequests::PartialEncodedStateWitness((account, partial_witness)) => {
for (i, name) in validators.iter().enumerate() {
if name == account {
connectors[i]
.partial_witness_sender
.send(PartialEncodedStateWitnessMessage(partial_witness.clone()));
}
}
}
NetworkRequests::PartialEncodedStateWitnesses(partial_witnesses) => {
for (account, partial_witness) in partial_witnesses {
for (i, name) in validators.iter().enumerate() {
if name == account {
Expand Down
11 changes: 10 additions & 1 deletion chain/network/src/peer_manager/peer_manager_actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,7 @@ impl PeerManagerActor {
);
NetworkResponses::NoResponse
}
NetworkRequests::PartialEncodedStateWitness(validator_witness_tuple) => {
NetworkRequests::PartialEncodedStateWitnesses(validator_witness_tuple) => {
for (chunk_validator, partial_witness) in validator_witness_tuple {
self.state.send_message_to_account(
&self.clock,
Expand All @@ -1079,6 +1079,15 @@ impl PeerManagerActor {
}
NetworkResponses::NoResponse
}
NetworkRequests::PartialEncodedStateWitness(validator_witness_tuple) => {
let (chunk_validator, partial_witness) = validator_witness_tuple;
self.state.send_message_to_account(
&self.clock,
&chunk_validator,
RoutedMessageBody::PartialEncodedStateWitness(partial_witness),
);
NetworkResponses::NoResponse
}
NetworkRequests::PartialEncodedStateWitnessForward(
chunk_validators,
partial_witness,
Expand Down
10 changes: 8 additions & 2 deletions chain/network/src/test_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,14 @@ fn network_message_to_partial_witness_handler(
.send(ChunkStateWitnessAckMessage(witness_ack));
None
}

NetworkRequests::PartialEncodedStateWitness(validator_witness_tuple) => {
NetworkRequests::PartialEncodedStateWitness((target, partial_witness)) => {
shared_state
.senders_for_account(&target)
.partial_witness_sender
.send(PartialEncodedStateWitnessMessage(partial_witness));
None
}
NetworkRequests::PartialEncodedStateWitnesses(validator_witness_tuple) => {
for (target, partial_witness) in validator_witness_tuple.into_iter() {
shared_state
.senders_for_account(&target)
Expand Down
4 changes: 3 additions & 1 deletion chain/network/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,10 @@ pub enum NetworkRequests {
ChunkStateWitnessAck(AccountId, ChunkStateWitnessAck),
/// Message for a chunk endorsement, sent by a chunk validator to the block producer.
ChunkEndorsement(AccountId, ChunkEndorsement),
/// Message from chunk producer to a chunk validator to send a state witness part.
PartialEncodedStateWitness((AccountId, PartialEncodedStateWitness)),
/// Message from chunk producer to a set of chunk validators to send state witness parts.
PartialEncodedStateWitness(Vec<(AccountId, PartialEncodedStateWitness)>),
PartialEncodedStateWitnesses(Vec<(AccountId, PartialEncodedStateWitness)>),
/// Message from chunk validator to all other chunk validators to forward state witness part.
PartialEncodedStateWitnessForward(Vec<AccountId>, PartialEncodedStateWitness),
/// Requests an epoch sync
Expand Down
126 changes: 125 additions & 1 deletion core/primitives/src/reed_solomon.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use borsh::{BorshDeserialize, BorshSerialize};
use itertools::Itertools;
use reed_solomon_erasure::galois_8::ReedSolomon;
use reed_solomon_erasure::{galois_8::ReedSolomon, ShardByShard};
use std::collections::HashMap;
use std::io::Error;
use std::sync::Arc;
Expand Down Expand Up @@ -150,6 +150,17 @@ impl ReedSolomonEncoder {
}
}
}

/// Builds an [`IncrementalEncoder`] over `data` that borrows this encoder's
/// `ReedSolomon` instance.
///
/// # Errors
/// Returns a `std::io::Error` of kind `Other` ("No encoder available") when
/// this encoder holds no `ReedSolomon` instance.
pub fn incremental_encoder<'a, T: ReedSolomonEncoderSerialize>(
    &'a self,
    data: &T,
) -> Result<IncrementalEncoder<'a>, std::io::Error> {
    match self.rs.as_ref() {
        Some(rs) => Ok(IncrementalEncoder::new(rs, data)),
        None => Err(std::io::Error::new(std::io::ErrorKind::Other, "No encoder available")),
    }
}
}

pub struct ReedSolomonEncoderCache {
Expand Down Expand Up @@ -245,3 +256,116 @@ impl<T: ReedSolomonEncoderDeserialize> ReedSolomonPartsTracker<T> {
}
}
}

/// Encoder that hands out Reed-Solomon parts one call at a time.
///
/// `new` serializes the input and splits it into data shards followed by
/// zero-filled parity shards; `encode_next` returns one shard per call.
pub struct IncrementalEncoder<'a> {
// Shard-by-shard encoding state from `reed_solomon_erasure`, borrowing the `ReedSolomon` codec.
sbs: ShardByShard<'a, reed_solomon_erasure::galois_8::Field>,
// All shards: the data shards first, then the parity shards (zero-filled at construction).
parts: Vec<Box<[u8]>>,
// Total shard count (data + parity) reported by the codec.
total_parts: usize,
// Data shard count reported by the codec.
data_parts: usize,
// Number of `sbs.encode` calls made so far; it stops growing once it reaches `data_parts`.
current_input_index: usize,
// Number of shards (data + parity) already returned by `encode_next`;
// also the index of the next shard to return.
shards_returned: usize,
// Length in bytes of the borsh-serialized input, before zero-padding to a whole number of shards.
encoded_length: usize,
}

impl<'a> IncrementalEncoder<'a> {
    /// Prepares an incremental encoding of `data` using the codec `rs`.
    ///
    /// `data` is borsh-serialized, zero-padded up to `data_parts * part_length`
    /// bytes and cut into `part_length`-sized data shards. One zero-filled
    /// shard of the same length is appended per parity shard of the codec.
    ///
    /// # Panics
    /// Panics if borsh serialization of `data` fails.
    pub fn new<T: ReedSolomonEncoderSerialize>(rs: &'a ReedSolomon, data: &T) -> Self {
        let mut serialized = borsh::to_vec(&data).unwrap();
        let encoded_length = serialized.len();

        let data_parts = rs.data_shard_count();
        let parity_parts = rs.parity_shard_count();
        let total_parts = rs.total_shard_count();
        let part_length = reed_solomon_part_length(encoded_length, data_parts);

        // Pad so the buffer splits evenly into `data_parts` shards.
        serialized.resize(data_parts * part_length, 0);

        let mut parts: Vec<Box<[u8]>> = Vec::with_capacity(total_parts);
        // Data shards: consecutive `part_length`-sized slices of the padded buffer.
        parts.extend(serialized.chunks_exact(part_length).map(|shard| Box::<[u8]>::from(shard)));
        // Parity shards start zeroed and are filled in by the shard-by-shard encoder.
        parts.extend(
            std::iter::repeat_with(|| vec![0u8; part_length].into_boxed_slice())
                .take(parity_parts),
        );

        Self {
            sbs: ShardByShard::new(rs),
            parts,
            total_parts,
            data_parts,
            current_input_index: 0,
            shards_returned: 0,
            encoded_length,
        }
    }

    /// Returns the next shard in order, or `None` once all `total_parts`
    /// shards have been returned.
    ///
    /// While fewer than `data_parts` encode steps have run, each call first
    /// runs one `ShardByShard::encode` step over all shards. Every call then
    /// returns a clone of the shard at the current output position.
    ///
    /// # Panics
    /// Panics with "Encoding failed: ..." if the encode step reports an error.
    pub fn encode_next(&mut self) -> ReedSolomonPart {
        let next_index = self.shards_returned;
        if next_index >= self.total_parts {
            // Every shard has already been handed out.
            return None;
        }

        if self.current_input_index < self.data_parts {
            // Run one more encode step over all shards.
            let mut shard_refs: Vec<&mut [u8]> =
                self.parts.iter_mut().map(|shard| &mut **shard).collect();
            self.sbs
                .encode(&mut shard_refs)
                .unwrap_or_else(|e| panic!("Encoding failed: {:?}", e));
            self.current_input_index += 1;
        }

        let part = self.parts[next_index].clone();
        self.shards_returned = next_index + 1;
        Some(part)
    }

    /// Length in bytes of the borsh-serialized input, before padding.
    pub fn encoded_length(&self) -> usize {
        self.encoded_length
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::stateless_validation::state_witness::EncodedChunkStateWitness;
    use rand::Rng;

    /// Checks that the incremental encoder yields, part by part, exactly what
    /// `ReedSolomonEncoder::encode` and `reed_solomon_encode` produce, and
    /// that all three agree on the encoded length.
    #[test]
    fn test_reed_solomon_incremental_generated_parts() {
        const SIZE: usize = 5000;
        let total_parts = 50;
        let ratio = 0.6;

        // Random payload wrapped as an encoded chunk state witness.
        let mut rng = rand::thread_rng();
        let payload: Vec<u8> = (0..SIZE).map(|_| rng.gen::<u8>()).collect();
        let data = EncodedChunkStateWitness::from(payload.into_boxed_slice());

        // Reference result from the all-at-once encoder.
        let encoder = ReedSolomonEncoder::new(total_parts, ratio);
        let (expected_parts, expected_length) = encoder.encode(&data);

        // Pull the parts one by one and compare each with the reference.
        let mut incremental_encoder = encoder.incremental_encoder(&data).unwrap();
        let incremental_parts: Vec<_> = (0..total_parts)
            .map(|i| {
                let part = incremental_encoder.encode_next();
                assert_eq!(part, expected_parts[i]);
                part
            })
            .collect();

        assert_eq!(expected_length, incremental_encoder.encoded_length());
        assert_eq!(expected_parts, incremental_parts);

        // Cross-check against the free function using the same codec.
        let rs = encoder.rs.as_ref().unwrap();
        let (direct_parts, direct_length) = reed_solomon_encode(rs, &data);
        assert_eq!(direct_length, incremental_encoder.encoded_length());
        assert_eq!(direct_parts, incremental_parts);
    }
}
Loading