diff --git a/drv/auxflash-api/src/lib.rs b/drv/auxflash-api/src/lib.rs index 7db3f0b1f..9cb3076d2 100644 --- a/drv/auxflash-api/src/lib.rs +++ b/drv/auxflash-api/src/lib.rs @@ -77,7 +77,8 @@ pub struct AuxFlashBlob { /// Extension trait to do auxflash operations on anything that /// implements `TlvcRead`. pub trait TlvcReadAuxFlash { - fn read_checksum(self) -> Result; + fn calculate_checksum(self) -> Result; + fn read_stored_checksum(self) -> Result; fn get_blob_by_tag( self, slot: u32, @@ -89,29 +90,31 @@ impl TlvcReadAuxFlash for R where R: TlvcRead, { - fn read_checksum(self) -> Result { + fn read_stored_checksum(self) -> Result { let mut reader = TlvcReader::begin(self) .map_err(|_| AuxFlashError::TlvcReaderBeginFailed)?; - let mut chck_expected = None; - let mut chck_actual = None; while let Ok(Some(chunk)) = reader.next() { if &chunk.header().tag == b"CHCK" { - if chck_expected.is_some() { - return Err(AuxFlashError::MultipleChck); - } else if chunk.len() != 32 { + if chunk.len() != 32 { return Err(AuxFlashError::BadChckSize); } let mut out = [0; 32]; chunk .read_exact(0, &mut out) .map_err(|_| AuxFlashError::ChunkReadFail)?; - chck_expected = Some(out); - } else if &chunk.header().tag == b"AUXI" { - if chck_actual.is_some() { - return Err(AuxFlashError::MultipleAuxi); - } + return Ok(AuxFlashChecksum(out)); + } + } + Err(AuxFlashError::MissingChck) + } + + fn calculate_checksum(self) -> Result { + let mut reader = TlvcReader::begin(self) + .map_err(|_| AuxFlashError::TlvcReaderBeginFailed)?; + while let Ok(Some(chunk)) = reader.next() { + if &chunk.header().tag == b"AUXI" { // Read data and calculate the checksum using a scratch buffer let mut sha = Sha3_256::new(); let mut scratch = [0u8; 256]; @@ -126,23 +129,12 @@ where } let sha_out = sha.finalize(); - // Save the checksum in `chck_actual` let mut out = [0; 32]; out.copy_from_slice(sha_out.as_slice()); - chck_actual = Some(out); - } - } - match (chck_expected, chck_actual) { - (None, _) => Err(AuxFlashError::MissingChck), - (_, None) => Err(AuxFlashError::MissingAuxi), - (Some(a), Some(b)) => { - if a != b { - Err(AuxFlashError::ChckMismatch) - } else { - Ok(AuxFlashChecksum(chck_expected.unwrap())) - } + return Ok(AuxFlashChecksum(out)); } } + Err(AuxFlashError::MissingAuxi) } fn get_blob_by_tag( diff --git a/drv/auxflash-server/src/main.rs b/drv/auxflash-server/src/main.rs index a93b61bdc..2e1efbb12 100644 --- a/drv/auxflash-server/src/main.rs +++ b/drv/auxflash-server/src/main.rs @@ -161,7 +161,7 @@ impl ServerImpl { &self, slot: u32, ) -> Result { - read_slot_checksum(&self.qspi, slot) + read_and_check_slot_checksum(&self.qspi, slot) } /// Checks that the matched slot in this even/odd pair also has valid data. @@ -434,16 +434,36 @@ impl NotificationHandler for ServerImpl { fn scan_for_active_slot(qspi: &Qspi) -> Option { for i in 0..SLOT_COUNT { - if let Ok(chck) = read_slot_checksum(qspi, i) { - if chck.0 == AUXI_CHECKSUM { - return Some(i); - } + let handle = SlotReader { + qspi, + base: i * SLOT_SIZE as u32, + }; + + let Ok(chck) = handle.read_stored_checksum() else { + // Just skip to the next slot if it's empty or invalid. + continue; + }; + + if chck.0 != AUXI_CHECKSUM { + // If it's not the chunk we're interested in, don't bother hashing + // it. + continue; + } + + let Ok(actual) = handle.calculate_checksum() else { + // TODO: this ignores I/O errors, but, this is how the code has + // always been structured... + continue; + }; + + if chck == actual { + return Some(i); } } None } -fn read_slot_checksum( +fn read_and_check_slot_checksum( qspi: &Qspi, slot: u32, ) -> Result { @@ -454,7 +474,13 @@ fn read_slot_checksum( qspi, base: slot * SLOT_SIZE as u32, }; - handle.read_checksum() + let claimed = handle.read_stored_checksum()?; + let actual = handle.calculate_checksum()?; + if claimed == actual { + Ok(actual) + } else { + Err(AuxFlashError::ChckMismatch) + } } ////////////////////////////////////////////////////////////////////////////////