From 982c30f1ac09acb5b048d76e2b0be8cfe7dfd168 Mon Sep 17 00:00:00 2001 From: Eric Newberry Date: Fri, 18 Apr 2025 20:50:53 +0000 Subject: [PATCH 1/6] storage: user-mode storvsc implementation --- Cargo.lock | 26 + Cargo.toml | 2 + vm/devices/storage/storvsc_driver/Cargo.toml | 38 + vm/devices/storage/storvsc_driver/src/lib.rs | 884 ++++++++++++++++++ .../storvsc_driver/src/test_helpers.rs | 622 ++++++++++++ 5 files changed, 1572 insertions(+) create mode 100644 vm/devices/storage/storvsc_driver/Cargo.toml create mode 100644 vm/devices/storage/storvsc_driver/src/lib.rs create mode 100644 vm/devices/storage/storvsc_driver/src/test_helpers.rs diff --git a/Cargo.lock b/Cargo.lock index dda653c599..e629b352f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6563,6 +6563,32 @@ dependencies = [ "zerocopy 0.8.24", ] +[[package]] +name = "storvsc_driver" +version = "0.0.0" +dependencies = [ + "async-channel", + "futures", + "futures-concurrency", + "guestmem", + "inspect", + "pal_async", + "parking_lot", + "scsi_buffers", + "scsi_defs", + "storvsp_protocol", + "task_control", + "test_with_tracing", + "thiserror 2.0.12", + "tracing", + "tracing_helpers", + "vmbus_async", + "vmbus_channel", + "vmbus_ring", + "vmcore", + "zerocopy 0.8.23", +] + [[package]] name = "storvsp" version = "0.0.0" diff --git a/Cargo.toml b/Cargo.toml index 1083398cb4..45b3ae129c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ members = [ "vm/devices/storage/disk_nvme/nvme_driver/fuzz", "vm/devices/storage/ide/fuzz", "vm/devices/storage/scsi_buffers/fuzz", + "vm/devices/storage/storvsc_driver", # TODO: Remove "vm/devices/storage/storvsp/fuzz", "vm/vmcore/guestmem/fuzz", "vm/x86/x86emu/fuzz", @@ -277,6 +278,7 @@ scsi_core = { path = "vm/devices/storage/scsi_core" } scsi_defs = { path = "vm/devices/storage/scsi_defs" } scsidisk = { path = "vm/devices/storage/scsidisk" } scsidisk_resources = { path = "vm/devices/storage/scsidisk_resources" } +storvsc_driver = { path = "vm/devices/storage/storvsc_driver" } storvsp = { path = "vm/devices/storage/storvsp" } storvsp_protocol = { path = "vm/devices/storage/storvsp_protocol" } storvsp_resources = { path = "vm/devices/storage/storvsp_resources" } diff --git a/vm/devices/storage/storvsc_driver/Cargo.toml b/vm/devices/storage/storvsc_driver/Cargo.toml new file mode 100644 index 0000000000..77e0260091 --- /dev/null +++ b/vm/devices/storage/storvsc_driver/Cargo.toml @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +[package] +name = "storvsc_driver" +edition.workspace = true +rust-version.workspace = true + +[dependencies] +scsi_buffers.workspace = true +scsi_defs.workspace = true + +vmbus_async.workspace = true +vmbus_channel.workspace = true +vmbus_ring.workspace = true + +storvsp_protocol.workspace = true + +guestmem.workspace = true +inspect.workspace = true +pal_async.workspace = true +task_control.workspace = true +vmcore.workspace = true +async-channel.workspace = true +futures.workspace = true +futures-concurrency.workspace = true +parking_lot.workspace = true +thiserror.workspace = true +tracing.workspace = true +tracing_helpers.workspace = true +zerocopy.workspace = true + +[dev-dependencies] +pal_async.workspace = true +test_with_tracing.workspace = true + +[lints] +workspace = true diff --git a/vm/devices/storage/storvsc_driver/src/lib.rs b/vm/devices/storage/storvsc_driver/src/lib.rs new file mode 100644 index 0000000000..e4ee6a85f5 --- /dev/null +++ b/vm/devices/storage/storvsc_driver/src/lib.rs @@ -0,0 +1,884 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Storvsc driver for use as a disk backend. + +#[cfg(test)] +mod test_helpers; + +use async_channel::Receiver; +use async_channel::RecvError; +use async_channel::Sender; +use futures::FutureExt; +use futures_concurrency::future::Race; +use guestmem::AccessError; +use guestmem::MemoryRead; +use guestmem::ranges::PagedRange; +use parking_lot::Mutex; +use std::collections::HashMap; +use std::sync::atomic::AtomicU64; +use std::sync::atomic::Ordering; +use task_control::AsyncRun; +use task_control::StopTask; +use task_control::TaskControl; +use thiserror::Error; +use tracing_helpers::ErrorValueExt; +use vmbus_async::queue; +use vmbus_async::queue::CompletionPacket; +use vmbus_async::queue::ExternalDataError; +use vmbus_async::queue::IncomingPacket; +use vmbus_async::queue::OutgoingPacket; +use vmbus_async::queue::PacketRef; +use vmbus_async::queue::Queue; +use vmbus_channel::RawAsyncChannel; +use vmbus_ring::OutgoingPacketType; +use vmbus_ring::PAGE_SIZE; +use vmbus_ring::RingMem; +use vmcore::vm_task::VmTaskDriverSource; +use zerocopy::FromBytes; +use zerocopy::Immutable; +use zerocopy::IntoBytes; +use zerocopy::KnownLayout; + +/// Storvsc to provide a backend for SCSI devices over VMBus. +pub struct StorvscDriver { + storvsc: TaskControl>, + version: storvsp_protocol::ProtocolVersion, + driver_source: VmTaskDriverSource, + new_request_sender: Option>, +} + +/// Storvsc backend for SCSI devices. +struct Storvsc { + inner: StorvscInner, + version: storvsp_protocol::ProtocolVersion, + queue: Queue, + num_sub_channels: Option, + has_negotiated: bool, +} + +struct StorvscInner { + next_transaction_id: AtomicU64, + new_request_receiver: Receiver, + transactions: Mutex>, +} + +struct StorvscRequest { + request: storvsp_protocol::ScsiRequest, + buf_gpa: u64, + byte_len: usize, + completion_sender: Sender, +} + +/// Result of a Storvsc operation. If None, then operation was cancelled. +pub struct StorvscCompletion { + completion: Option, +} + +struct PendingOperation { + sender: Sender, +} + +impl PendingOperation { + fn new(sender: Sender) -> Self { + Self { sender } + } + + fn complete(&mut self, result: storvsp_protocol::ScsiRequest) -> Result<(), StorvscError> { + self.sender + .send_blocking(StorvscCompletion { + completion: Some(result), + }) + .map_err(|_err| StorvscError::NotificationError) + } + + fn cancel(&mut self) { + // Sending completion with an empty result indicates cancellation or other error. + self.sender + .send_blocking(StorvscCompletion { completion: None }) + .unwrap(); + } +} + +/// Errors resulting from storvsc. +#[derive(Debug, Error)] +pub enum StorvscError { + /// Packet error. + #[error("packet error")] + PacketError(#[source] PacketError), + /// Queue error. + #[error("queue error")] + Queue(#[source] queue::Error), + /// Queue out of space. + #[error("queue should have enough space but no longer does")] + NotEnoughSpace, + /// Unsupported protocol version. + #[error("requested protocol version unsupported by storvsp")] + UnsupportedProtocolVersion, + /// Unexpected protocol data or operation. + #[error("unexpected protocol data or operation")] + UnexpectedOperation, + /// Error notifying completion of operation. + #[error("error notifying completion of operation")] + NotificationError, + /// Error sending request to storvsc driver. + #[error("error sending request to storvsc")] + RequestError, + /// Error receiving new operations from channel. + #[error("error receiving new operations from channel")] + RequestReceiveError(#[source] RecvError), + /// Error waiting for completion of operation. + #[error("error waiting for completion of operation")] + CompletionError(#[source] RecvError), + /// Operation cancelled. + #[error("pending operation cancelled")] + Cancelled, + /// Storvsc driver not fully initialized. + #[error("driver not initialized")] + Uninitialized, +} + +/// Errors with packet parsing between storvsc and storvsp. +#[derive(Debug, Error)] +pub enum PacketError { + /// Not transactional. + #[error("Not transactional")] + NotTransactional, + /// Unenxpected transaction. + #[error("Unexpected transaction {0:?}")] + UnexpectedTransaction(u64), + /// Unexpected status. + #[error("Unexpected status {0:?}")] + UnexpectedStatus(storvsp_protocol::NtStatus), + /// Unrecognzied operation. + #[error("Unrecognized operation {0:?}")] + UnrecognizedOperation(storvsp_protocol::Operation), + /// Invalid packet type. + #[error("Invalid packet type")] + InvalidPacketType, + /// Invalid data transfer length. + #[error("Invalid data transfer length")] + InvalidDataTransferLength, + /// Access error. + #[error("Access error: {0}")] + Access(#[source] AccessError), + /// Range error. + #[error("Range error")] + Range(#[source] ExternalDataError), +} + +impl StorvscDriver { + /// Create a new driver instance connected to storvsp over VMBus. + pub fn new( + driver_source: &VmTaskDriverSource, + version: storvsp_protocol::ProtocolVersion, + ) -> Self { + Self { + storvsc: TaskControl::new(StorvscState), + version, + driver_source: driver_source.clone(), + new_request_sender: None, + } + } + + /// Start Storvsc. + pub fn run(&mut self, channel: RawAsyncChannel, target_vp: u32) -> Result<(), StorvscError> { + let driver = self + .driver_source + .builder() + .target_vp(target_vp) + .run_on_target(true) + .build("storvsc"); + let (new_request_sender, new_request_receiver) = + async_channel::unbounded::(); + let storvsc = Storvsc::new(channel, self.version, new_request_receiver)?; + self.new_request_sender = Some(new_request_sender); + + self.storvsc.insert(&driver, "storvsc", storvsc); + self.storvsc.start(); + Ok(()) + } + + /// Stop Storvsc. + pub async fn stop(&mut self) { + self.storvsc.stop().await; + self.storvsc.remove(); + } + + /// Send a SCSI request to storvsp over VMBus. + pub async fn send_request( + &mut self, + request: &storvsp_protocol::ScsiRequest, + buf_gpa: u64, + byte_len: usize, + ) -> Result { + let (sender, receiver) = async_channel::unbounded::(); + let storvsc_request = StorvscRequest { + request: *request, + buf_gpa, + byte_len, + completion_sender: sender, + }; + match &self.new_request_sender { + Some(request_sender) => { + request_sender + .send(storvsc_request) + .await + .map_err(|_err| StorvscError::RequestError)?; + Ok(()) + } + None => Err(StorvscError::Uninitialized), + }?; + + let resp = receiver + .recv() + .await + .map_err(StorvscError::CompletionError)?; + + if resp.completion.is_some() { + Ok(resp.completion.unwrap()) + } else { + Err(StorvscError::Cancelled) + } + } +} + +struct StorvscState; + +impl AsyncRun> for StorvscState { + async fn run( + &mut self, + stop: &mut StopTask<'_>, + worker: &mut Storvsc, + ) -> Result<(), task_control::Cancelled> { + let fut = async { + if !worker.has_negotiated { + worker.negotiate().await?; + } + worker.process_main().await + }; + + match stop.until_stopped(fut).await? { + Ok(_) => {} + Err(err) => tracing::error!(error = err.as_error(), "storvsc run error"), + } + Ok(()) + } +} + +impl Storvsc { + pub(crate) fn new( + channel: RawAsyncChannel, + version: storvsp_protocol::ProtocolVersion, + new_request_receiver: Receiver, + ) -> Result { + let queue = Queue::new(channel).map_err(StorvscError::Queue)?; + + Ok(Self { + inner: StorvscInner { + next_transaction_id: AtomicU64::new(1), + new_request_receiver, + transactions: Mutex::new(HashMap::new()), + }, + version, + queue, + num_sub_channels: None, + has_negotiated: false, + }) + } +} + +impl Storvsc { + async fn negotiate(&mut self) -> Result<(), StorvscError> { + // Negotiate protocol with storvsp instance on the other end of VMBus + // Step 1: BEGIN_INITIALIZATION + self.inner + .send_packet_and_expect_completion( + &mut self.queue, + storvsp_protocol::Operation::BEGIN_INITIALIZATION, + &(), + ) + .await?; + + // Step 2: QUERY_PROTOCOL_VERSION - request latest version + self.inner + .send_packet_and_expect_completion( + &mut self.queue, + storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION, + &self.version, + ) + .await + .map_err(|err| match err { + StorvscError::PacketError(PacketError::UnexpectedStatus( + storvsp_protocol::NtStatus::INVALID_DEVICE_STATE, + )) => StorvscError::UnsupportedProtocolVersion, + _ => err, + })?; + + // Step 3: QUERY_PROPERTIES + let properties_packet = self + .inner + .send_packet_and_expect_completion( + &mut self.queue, + storvsp_protocol::Operation::QUERY_PROPERTIES, + &(), + ) + .await?; + let properties = storvsp_protocol::ChannelProperties::ref_from_prefix( + &properties_packet.data[0..properties_packet.data_size], + ) + .map_err(|_err| StorvscError::UnexpectedOperation)? + .0 + .to_owned(); + + // Step 4: CREATE_SUB_CHANNELS (if supported) + if properties.maximum_sub_channel_count > 0 { + self.num_sub_channels = Some(properties.maximum_sub_channel_count); + // Decrease by 1 until we are able to negotiate (or give up if we reach 0) + while self.num_sub_channels.unwrap() > 0 { + match self + .inner + .send_packet_and_expect_completion( + &mut self.queue, + storvsp_protocol::Operation::CREATE_SUB_CHANNELS, + &self.num_sub_channels.unwrap(), + ) + .await + { + Ok(_packet) => break, + Err(_err) => { + self.num_sub_channels = Some(self.num_sub_channels.unwrap() - 1); + } + }; + } + } + + // Step 5: END_INITIALIZATION + self.inner + .send_packet_and_expect_completion( + &mut self.queue, + storvsp_protocol::Operation::END_INITIALIZATION, + &(), + ) + .await?; + + self.has_negotiated = true; + + tracing::info!( + version = self.version.major_minor, + num_sub_channels = self.num_sub_channels, + "Negotiated protocol" + ); + + Ok(()) + } + + /// Main loop to poll for and handle new operations and incoming completions for operations + async fn process_main(&mut self) -> Result<(), StorvscError> { + match self.inner.process_main(&mut self.queue).await { + Ok(_) => Ok(()), + Err(StorvscError::Queue(err2)) => { + if err2.is_closed_error() { + // This is expected, cancel any pending completions + self.inner.cancel_pending_completions().await; + Ok(()) + } else { + Err(StorvscError::Queue(err2)) + } + } + Err(err) => Err(err), + } + } +} + +impl StorvscInner { + async fn process_main(&mut self, queue: &mut Queue) -> Result<(), StorvscError> { + loop { + enum Event<'a, M: RingMem> { + NewRequestReceived(Result), + VmbusPacketReceived(Result, queue::Error>), + } + let (mut reader, mut writer) = queue.split(); + match ( + self.new_request_receiver + .recv() + .map(Event::NewRequestReceived), + reader.read().map(Event::VmbusPacketReceived), + ) + .race() + .await + { + Event::NewRequestReceived(result) => match result { + Ok(request) => { + match self.send_request( + &request.request, + request.buf_gpa, + request.byte_len, + &mut writer, + request.completion_sender, + ) { + Ok(()) => Ok(()), + Err(err) => { + tracing::error!( + "Unable to send new request to VMBus, err={:?}", + err + ); + Err(err) + } + } + } + Err(err) => { + tracing::error!("Unable to receive new request, err={:?}", err); + Err(StorvscError::RequestError) + } + }, + Event::VmbusPacketReceived(result) => match result { + Ok(packet_ref) => self.handle_packet(packet_ref.as_ref()), + Err(err) => { + tracing::error!("Error receiving VMBus packet, err={:?}", err); + Err(StorvscError::Queue(err)) + } + }, + }?; + } + } + + fn send_request( + &mut self, + request: &storvsp_protocol::ScsiRequest, + buf_gpa: u64, + byte_len: usize, + writer: &mut queue::WriteHalf<'_, M>, + completion_sender: Sender, + ) -> Result<(), StorvscError> { + // Fetch a transaction ID for this operation + let transaction_id = self.get_next_transaction_id(); + + // Create pending transaction record + { + self.transactions + .lock() + .insert(transaction_id, PendingOperation::new(completion_sender)); + } + + self.send_gpa_direct_packet( + writer, + storvsp_protocol::Operation::EXECUTE_SRB, + storvsp_protocol::NtStatus::SUCCESS, + transaction_id, + request, + buf_gpa, + byte_len, + ) + } + + async fn cancel_pending_completions(&mut self) { + for transaction in self.transactions.lock().values_mut() { + transaction.cancel(); + } + } + + fn handle_packet( + &mut self, + packet: &IncomingPacket<'_, M>, + ) -> Result<(), StorvscError> { + let packet = parse_packet(packet)?; + let completion = match packet { + //Packet::Data(_) => Err(StorvscError::UnexpectedOperation), + Packet::Completion(completion) => Ok(completion), + }?; + + // Parse ScsiRequest (contains response) from bytes + let result = storvsp_protocol::ScsiRequest::ref_from_bytes(completion.data.as_slice()) + .map_err(|_err| StorvscError::UnexpectedOperation)? + .to_owned(); + + // Match completion against pending transactions + { + match self.transactions.lock().get_mut(&completion.transaction_id) { + Some(t) => Ok(t), + None => Err(StorvscError::PacketError( + PacketError::UnexpectedTransaction(completion.transaction_id), + )), + } + }? + .complete(result)?; + + Ok(()) + } + + /// Awaits the next incoming packet. Increments the count of outstanding packets when returning `Ok(Packet)`. + async fn next_packet<'a, M: RingMem>( + &mut self, + reader: &'a mut queue::ReadHalf<'a, M>, + ) -> Result { + let packet = reader.read().await.map_err(StorvscError::Queue)?; + parse_packet(&packet) + } + + fn get_next_transaction_id(&mut self) -> u64 { + self.next_transaction_id.fetch_add(1, Ordering::AcqRel) + } + + /// Send a non-GPA Direct packet over VMBus. + fn send_packet( + &mut self, + writer: &mut queue::WriteHalf<'_, M>, + operation: storvsp_protocol::Operation, + status: storvsp_protocol::NtStatus, + transaction_id: u64, + payload: &P, + ) -> Result<(), StorvscError> { + let payload_bytes = payload.as_bytes(); + self.send_vmbus_packet( + &mut writer.batched(), + OutgoingPacketType::InBandWithCompletion, + payload_bytes.len(), + transaction_id, + operation, + status, + payload_bytes, + )?; + Ok(()) + } + + /// Send a GPA Direct packet over VMBus. + fn send_gpa_direct_packet( + &mut self, + writer: &mut queue::WriteHalf<'_, M>, + operation: storvsp_protocol::Operation, + status: storvsp_protocol::NtStatus, + transaction_id: u64, + payload: &P, + gpa_start: u64, + byte_len: usize, + ) -> Result<(), StorvscError> { + let payload_bytes = payload.as_bytes(); + let start_page: u64 = gpa_start / PAGE_SIZE as u64; + let end_page: u64 = (gpa_start + (byte_len + PAGE_SIZE - 1) as u64) / PAGE_SIZE as u64; + let gpas: Vec = (start_page..end_page).collect(); + let pages = + PagedRange::new(gpa_start as usize % PAGE_SIZE, byte_len, gpas.as_slice()).unwrap(); + self.send_vmbus_packet( + &mut writer.batched(), + OutgoingPacketType::GpaDirect(&[pages]), + payload_bytes.len(), + transaction_id, + operation, + status, + payload_bytes, + )?; + Ok(()) + } + + /// Send a VMBus packet. + fn send_vmbus_packet( + &mut self, + writer: &mut queue::WriteBatch<'_, M>, + packet_type: OutgoingPacketType<'_>, + request_size: usize, + transaction_id: u64, + operation: storvsp_protocol::Operation, + status: storvsp_protocol::NtStatus, + payload: &[u8], + ) -> Result<(), StorvscError> { + let header = storvsp_protocol::Packet { + operation, + flags: 0, + status, + }; + + let packet_size = size_of_val(&header) + request_size; + + // Zero pad or truncate the payload to the queue's packet size. This is + // necessary because Windows guests check that each packet's size is + // exactly the largest possible packet size for the negotiated protocol + // version. + let len = size_of_val(&header) + size_of_val(payload); + let padding = [0; storvsp_protocol::SCSI_REQUEST_LEN_MAX]; + let (payload_bytes, padding_bytes) = if len > packet_size { + (&payload[..packet_size - size_of_val(&header)], &[][..]) + } else { + (payload, &padding[..packet_size - len]) + }; + assert_eq!( + size_of_val(&header) + payload_bytes.len() + padding_bytes.len(), + packet_size + ); + writer + .try_write(&OutgoingPacket { + transaction_id, + packet_type, + payload: &[header.as_bytes(), payload_bytes, padding_bytes], + }) + .map_err(|err| match err { + queue::TryWriteError::Full(_) => StorvscError::NotEnoughSpace, + queue::TryWriteError::Queue(err) => StorvscError::Queue(err), + }) + } + + async fn send_packet_and_expect_completion< + M: RingMem, + P: IntoBytes + Immutable + KnownLayout, + >( + &mut self, + queue: &mut Queue, + operation: storvsp_protocol::Operation, + payload: &P, + ) -> Result { + let (mut reader, mut writer) = queue.split(); + let transaction_id = self.get_next_transaction_id(); + self.send_packet( + &mut writer, + operation, + storvsp_protocol::NtStatus::SUCCESS, + transaction_id, + payload, + )?; + // Wait for completion + let completion = match self.next_packet(&mut reader).await? { + Packet::Completion(packet) => Ok(packet), + //Packet::Data(_) => Err(StorvscError::PacketError(PacketError::InvalidPacketType)), + }?; + expect_success( + expect_transaction_id(completion, transaction_id).map_err(StorvscError::PacketError)?, + ) + .map_err(StorvscError::PacketError) + } +} + +enum Packet { + Completion(StorvscCompletionPacket), + //Data(StorvscDataPacket), +} + +#[derive(Debug)] +struct StorvscCompletionPacket { + transaction_id: u64, + status: storvsp_protocol::NtStatus, + data_size: usize, + data: [u8; storvsp_protocol::SCSI_REQUEST_LEN_MAX], +} + +/*#[derive(Debug)] +struct StorvscDataPacket { + transaction_id: u64, + request_size: usize, + operation: storvsp_protocol::Operation, + flags: u32, + status: storvsp_protocol::NtStatus, + data: [u8; storvsp_protocol::SCSI_REQUEST_LEN_MAX], +}*/ + +fn parse_packet(packet: &IncomingPacket<'_, T>) -> Result { + match packet { + IncomingPacket::Completion(completion) => { + parse_completion(completion).map_err(StorvscError::PacketError) + } + IncomingPacket::Data(_) => { + // TODO + Err(StorvscError::PacketError(PacketError::InvalidPacketType)) + //parse_data(data).map_err(StorvscError::PacketError) + } + } +} + +fn parse_completion(packet: &CompletionPacket<'_, T>) -> Result { + let mut reader = packet.reader(); + let header: storvsp_protocol::Packet = reader.read_plain().map_err(PacketError::Access)?; + if header.operation != storvsp_protocol::Operation::COMPLETE_IO { + return Err(PacketError::NotTransactional); + } + let data_size = reader.len(); + let mut data = [0_u8; storvsp_protocol::SCSI_REQUEST_LEN_MAX]; + let data_temp: Vec = reader.read_n(data_size).map_err(PacketError::Access)?; + data[..data_size].clone_from_slice(data_temp.as_slice()); + Ok(Packet::Completion(StorvscCompletionPacket { + transaction_id: packet.transaction_id(), + status: header.status, + data_size, + data, + })) +} + +/*fn parse_data(packet: &IncomingPacket<'_, T>) -> Result { + let packet = match packet { + IncomingPacket::Completion(_) => return Err(PacketError::InvalidPacketType), + IncomingPacket::Data(packet) => packet, + }; + let transaction_id = packet.transaction_id(); + + let mut reader = packet.reader(); + let header: storvsp_protocol::Packet = reader.read_plain().map_err(PacketError::Access)?; + // You would expect that this should be limited to the current protocol + // version's maximum packet size, but this is not what Hyper-V does, and + // Linux 6.1 relies on this behavior during protocol initialization. + let request_size = reader.len().min(storvsp_protocol::SCSI_REQUEST_LEN_MAX); + let operation = header.operation; + let flags = header.flags; + let status = header.status; + + let mut data = [0_u8; storvsp_protocol::SCSI_REQUEST_LEN_MAX]; + reader.read(&mut data).map_err(PacketError::Access)?; + + Ok(Packet { + transaction_id, + request_size, + operation, + flags, + status, + data, + }) +}*/ + +fn expect_success(packet: StorvscCompletionPacket) -> Result { + if packet.status != storvsp_protocol::NtStatus::SUCCESS { + return Err(PacketError::UnexpectedStatus(packet.status)); + } + Ok(packet) +} + +fn expect_transaction_id( + packet: StorvscCompletionPacket, + transaction_id: u64, +) -> Result { + if packet.transaction_id != transaction_id { + return Err(PacketError::UnexpectedTransaction(packet.transaction_id)); + } + Ok(packet) +} + +#[cfg(test)] +mod tests { + use crate::test_helpers::TestStorvscWorker; + use crate::test_helpers::TestStorvspWorker; + use guestmem::GuestMemory; + use pal_async::DefaultDriver; + use pal_async::async_test; + use pal_async::timer::PolledTimer; + use scsi_defs::ScsiOp; + use std::time; + use test_with_tracing::test; + use vmbus_async::queue::Queue; + use vmbus_channel::connected_async_channels; + use zerocopy::FromZeros; + use zerocopy::IntoBytes; + + // This function assumes the sector size is 512. + fn generate_write_packet( + target_id: u8, + path_id: u8, + lun: u8, + block: u32, + byte_len: usize, + ) -> storvsp_protocol::ScsiRequest { + let cdb = scsi_defs::Cdb10 { + operation_code: ScsiOp::WRITE, + logical_block: block.into(), + transfer_blocks: ((byte_len / 512) as u16).into(), + ..FromZeros::new_zeroed() + }; + + let mut scsi_req = storvsp_protocol::ScsiRequest { + target_id, + path_id, + lun, + length: storvsp_protocol::SCSI_REQUEST_LEN_V2 as u16, + cdb_length: size_of::() as u8, + data_transfer_length: byte_len as u32, + ..FromZeros::new_zeroed() + }; + + scsi_req.payload[0..10].copy_from_slice(cdb.as_bytes()); + scsi_req + } + + // This function assumes the sector size is 512. + fn generate_read_packet( + target_id: u8, + path_id: u8, + lun: u8, + block: u32, + byte_len: usize, + ) -> storvsp_protocol::ScsiRequest { + let cdb = scsi_defs::Cdb10 { + operation_code: ScsiOp::READ, + logical_block: block.into(), + transfer_blocks: ((byte_len / 512) as u16).into(), + ..FromZeros::new_zeroed() + }; + + let mut scsi_req = storvsp_protocol::ScsiRequest { + target_id, + path_id, + lun, + length: storvsp_protocol::SCSI_REQUEST_LEN_V2 as u16, + cdb_length: size_of::() as u8, + data_transfer_length: byte_len as u32, + ..FromZeros::new_zeroed() + }; + + scsi_req.payload[0..10].copy_from_slice(cdb.as_bytes()); + scsi_req + } + + #[async_test] + async fn test_negotiation(driver: DefaultDriver) { + let (guest, host) = connected_async_channels(16 * 1024); + let host_queue = Queue::new(host).unwrap(); + let test_guest_mem = GuestMemory::allocate(16384); + + let storvsp = TestStorvspWorker::start( + driver.clone(), + test_guest_mem.clone(), + host_queue, + Vec::new(), + ); + let mut storvsc = TestStorvscWorker::new(); + storvsc.start(driver.clone(), guest); + + let mut timer = PolledTimer::new(&driver); + timer.sleep(time::Duration::from_secs(1)).await; + + storvsc.teardown().await; + storvsp.teardown().await; + } + + #[async_test] + async fn test_request_response(driver: DefaultDriver) { + let (guest, host) = connected_async_channels(16 * 1024); + let host_queue = Queue::new(host).unwrap(); + let test_guest_mem = GuestMemory::allocate(16384); + + let storvsp = TestStorvspWorker::start( + driver.clone(), + test_guest_mem.clone(), + host_queue, + Vec::new(), + ); + let mut storvsc = TestStorvscWorker::new(); + storvsc.start(driver.clone(), guest); + + let mut timer = PolledTimer::new(&driver); + timer.sleep(time::Duration::from_secs(1)).await; + + // Send SCSI write request + let write_buf = [7u8; 4096]; + test_guest_mem.write_at(4096, &write_buf).unwrap(); + storvsc + .send_request(&generate_write_packet(0, 1, 2, 4096, 4096), 4096, 4096) + .await + .unwrap(); + + // Send SCSI read request + let write_buf = [7u8; 4096]; + test_guest_mem.write_at(4096, &write_buf).unwrap(); + storvsc + .send_request(&generate_read_packet(0, 1, 2, 4096, 4096), 4096, 4096) + .await + .unwrap(); + + storvsc.teardown().await; + storvsp.teardown().await; + } +} diff --git a/vm/devices/storage/storvsc_driver/src/test_helpers.rs b/vm/devices/storage/storvsc_driver/src/test_helpers.rs new file mode 100644 index 0000000000..1476d949ac --- /dev/null +++ b/vm/devices/storage/storvsc_driver/src/test_helpers.rs @@ -0,0 +1,622 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Helpers for unit tests. + +#![cfg_attr(not(test), expect(dead_code))] + +use crate::PacketError; +use crate::Storvsc; +use crate::StorvscCompletion; +use crate::StorvscError; +use crate::StorvscRequest; +use crate::StorvscState; +use async_channel::Sender; +use guestmem::GuestMemory; +use guestmem::MemoryRead; +use guestmem::ranges::PagedRange; +use inspect::Inspect; +use pal_async::task::Spawn; +use pal_async::task::Task; +use scsi_buffers::RequestBuffers; +use std::future::poll_fn; +use std::sync::Arc; +use std::task::Context; +use std::task::Poll; +use task_control::TaskControl; +use thiserror::Error; +use vmbus_async::queue; +use vmbus_async::queue::IncomingPacket; +use vmbus_async::queue::OutgoingPacket; +use vmbus_async::queue::Queue; +use vmbus_channel::RawAsyncChannel; +use vmbus_ring::FlatRingMem; +use vmbus_ring::OutgoingPacketType; +use vmbus_ring::RingMem; +use vmbus_ring::gparange::GpnList; +use vmbus_ring::gparange::MultiPagedRangeBuf; +use zerocopy::FromZeros; +use zerocopy::Immutable; +use zerocopy::IntoBytes; +use zerocopy::KnownLayout; + +const MAX_VMBUS_PACKET_SIZE: usize = vmbus_ring::PacketSize::in_band( + size_of::() + storvsp_protocol::SCSI_REQUEST_LEN_MAX, +); + +#[derive(Debug)] +struct StorvspPacket { + data: StorvspPacketData, + transaction_id: u64, + request_size: usize, +} + +#[derive(Debug, Clone)] +enum StorvspPacketData { + BeginInitialization, + EndInitialization, + QueryProtocolVersion(u16), + QueryProperties, + CreateSubChannels(u16), + ExecuteScsi(Arc), + ResetBus, + ResetAdapter, + ResetLun, +} + +#[repr(u16)] +#[derive(Copy, Clone, Debug, Inspect, PartialEq, Eq, PartialOrd, Ord)] +enum Version { + Win6 = storvsp_protocol::VERSION_WIN6, + Win7 = storvsp_protocol::VERSION_WIN7, + Win8 = storvsp_protocol::VERSION_WIN8, + Blue = storvsp_protocol::VERSION_BLUE, +} + +#[derive(Debug, Error)] +#[error("protocol version {0:#x} not supported")] +struct UnsupportedVersion(u16); + +impl Version { + fn parse(major_minor: u16) -> Result { + let version = match major_minor { + storvsp_protocol::VERSION_WIN6 => Self::Win6, + storvsp_protocol::VERSION_WIN7 => Self::Win7, + storvsp_protocol::VERSION_WIN8 => Self::Win8, + storvsp_protocol::VERSION_BLUE => Self::Blue, + version => return Err(UnsupportedVersion(version)), + }; + assert_eq!(version as u16, major_minor); + Ok(version) + } + + fn max_request_size(&self) -> usize { + match self { + Version::Win8 | Version::Blue => storvsp_protocol::SCSI_REQUEST_LEN_V2, + Version::Win6 | Version::Win7 => storvsp_protocol::SCSI_REQUEST_LEN_V1, + } + } +} + +#[allow(dead_code)] +#[derive(Debug, Default, Clone)] +struct Range { + buf: MultiPagedRangeBuf, + len: usize, + is_write: bool, +} + +#[allow(dead_code)] +impl Range { + fn new( + buf: MultiPagedRangeBuf, + request: &storvsp_protocol::ScsiRequest, + ) -> Option { + let len = request.data_transfer_length as usize; + let is_write = request.data_in != 0; + // Ensure there is exactly one range and it's large enough, or there are + // zero ranges and there is no associated SCSI buffer. + if buf.range_count() > 1 || (len > 0 && buf.first()?.len() < len) { + return None; + } + Some(Self { buf, len, is_write }) + } + + fn buffer<'a>(&'a self, guest_memory: &'a GuestMemory) -> RequestBuffers<'a> { + let mut range = self.buf.first().unwrap_or_else(PagedRange::empty); + range.truncate(self.len); + RequestBuffers::new(guest_memory, range, self.is_write) + } +} + +#[allow(dead_code)] +#[derive(Debug, Clone)] +pub(crate) struct ScsiRequestAndRange { + external_data: Range, + request: storvsp_protocol::ScsiRequest, + request_size: usize, +} + +fn parse_storvsp_packet( + packet: &IncomingPacket<'_, T>, + pool: &mut Vec>, +) -> Result { + let packet = match packet { + IncomingPacket::Completion(_) => return Err(PacketError::InvalidPacketType), + IncomingPacket::Data(packet) => packet, + }; + let transaction_id = packet + .transaction_id() + .ok_or(PacketError::NotTransactional)?; + + let mut reader = packet.reader(); + let header: storvsp_protocol::Packet = reader.read_plain().map_err(PacketError::Access)?; + // You would expect that this should be limited to the current protocol + // version's maximum packet size, but this is not what Hyper-V does, and + // Linux 6.1 relies on this behavior during protocol initialization. + let request_size = reader.len().min(storvsp_protocol::SCSI_REQUEST_LEN_MAX); + let data = match header.operation { + storvsp_protocol::Operation::BEGIN_INITIALIZATION => StorvspPacketData::BeginInitialization, + storvsp_protocol::Operation::END_INITIALIZATION => StorvspPacketData::EndInitialization, + storvsp_protocol::Operation::QUERY_PROTOCOL_VERSION => { + let mut version = storvsp_protocol::ProtocolVersion::new_zeroed(); + reader + .read(version.as_mut_bytes()) + .map_err(PacketError::Access)?; + StorvspPacketData::QueryProtocolVersion(version.major_minor) + } + storvsp_protocol::Operation::QUERY_PROPERTIES => StorvspPacketData::QueryProperties, + storvsp_protocol::Operation::EXECUTE_SRB => { + let mut full_request = pool.pop().unwrap_or_else(|| { + Arc::new(ScsiRequestAndRange { + external_data: Range::default(), + request: storvsp_protocol::ScsiRequest::new_zeroed(), + request_size, + }) + }); + + { + let full_request = Arc::get_mut(&mut full_request).unwrap(); + let request_buf = &mut full_request.request.as_mut_bytes()[..request_size]; + reader.read(request_buf).map_err(PacketError::Access)?; + + let buf = packet.read_external_ranges().map_err(PacketError::Range)?; + + full_request.external_data = Range::new(buf, &full_request.request) + .ok_or(PacketError::InvalidDataTransferLength)?; + } + + StorvspPacketData::ExecuteScsi(full_request) + } + storvsp_protocol::Operation::RESET_LUN => StorvspPacketData::ResetLun, + storvsp_protocol::Operation::RESET_ADAPTER => StorvspPacketData::ResetAdapter, + storvsp_protocol::Operation::RESET_BUS => StorvspPacketData::ResetBus, + storvsp_protocol::Operation::CREATE_SUB_CHANNELS => { + let mut sub_channel_count: u16 = 0; + reader + .read(sub_channel_count.as_mut_bytes()) + .map_err(PacketError::Access)?; + StorvspPacketData::CreateSubChannels(sub_channel_count) + } + _ => return Err(PacketError::UnrecognizedOperation(header.operation)), + }; + + Ok(StorvspPacket { + data, + request_size, + transaction_id, + }) +} + +pub(crate) struct TestStorvscWorker { + task: TaskControl>, + new_request_sender: Option>, +} + +impl TestStorvscWorker { + pub fn new() -> Self { + Self { + task: TaskControl::new(StorvscState), + new_request_sender: None, + } + } + + pub fn start(&mut self, spawner: impl Spawn, channel: RawAsyncChannel) { + let (new_request_sender, new_request_receiver) = + async_channel::unbounded::(); + let storvsc = Storvsc::new( + channel, + storvsp_protocol::ProtocolVersion { + major_minor: storvsp_protocol::VERSION_BLUE, + reserved: 0, + }, + new_request_receiver, + ) + .unwrap(); + self.new_request_sender = Some(new_request_sender); + + self.task.insert(spawner, "storvsc", storvsc); + self.task.start(); + } + + pub async fn teardown(&mut self) { + self.task.stop().await; + self.task.remove(); + } + + /// Send a SCSI request to storvsp over VMBus. + pub async fn send_request( + &mut self, + request: &storvsp_protocol::ScsiRequest, + buf_gpa: u64, + byte_len: usize, + ) -> Result { + let (sender, receiver) = async_channel::unbounded::(); + let storvsc_request = StorvscRequest { + request: *request, + buf_gpa, + byte_len, + completion_sender: sender, + }; + match &self.new_request_sender { + Some(request_sender) => { + request_sender + .send(storvsc_request) + .await + .map_err(|_err| StorvscError::RequestError)?; + Ok(()) + } + None => Err(StorvscError::Uninitialized), + }?; + + let resp = receiver + .recv() + .await + .map_err(StorvscError::CompletionError)?; + + if resp.completion.is_some() { + Ok(resp.completion.unwrap()) + } else { + Err(StorvscError::Cancelled) + } + } +} + +pub(crate) struct TestStorvspWorker { + task: Task<()>, +} + +struct TestStorvsp { + _mem: GuestMemory, + queue: Queue, + full_request_pool: Vec>, + version: storvsp_protocol::ProtocolVersion, + subchannel_count: u16, + inner: TestStorvspInner, +} + +struct TestStorvspInner { + request_size: usize, +} + +impl TestStorvspWorker { + pub fn start( + spawner: impl Spawn, + mem: GuestMemory, + queue: Queue, + full_request_pool: Vec>, + ) -> Self { + let task = spawner.spawn("test_storvsp", async move { + let mut worker = TestStorvsp::new(mem, queue, full_request_pool); + worker.run().await; + }); + + Self { task } + } + + pub async fn teardown(self) { + self.task.cancel().await; + } +} + +impl TestStorvsp { + fn new( + mem: GuestMemory, + queue: Queue, + full_request_pool: Vec>, + ) -> Self { + TestStorvsp { + _mem: mem, + queue, + full_request_pool, + subchannel_count: 0, + version: storvsp_protocol::ProtocolVersion { + major_minor: 0, + reserved: 0, + }, + inner: TestStorvspInner { + request_size: storvsp_protocol::SCSI_REQUEST_LEN_V1, + }, + } + } + + pub async fn run(&mut self) { + self.negotiate().await.unwrap(); + self.process_packets().await.unwrap(); // It's normal to exit here when the channel closes + tracing::error!("TestStorvsp shouldn't have reached here!"); + } + + async fn negotiate(&mut self) -> Result<(), StorvscError> { + let mut has_begin_initialization = false; + let mut has_query_protocol_version = false; + let mut has_query_properties = false; + let mut has_end_initialization = false; + while !has_end_initialization { + tracing::trace!("Waiting for next initialization packet"); + let (mut reader, mut writer) = self.queue.split(); + let packet = reader.read().await.map_err(StorvscError::Queue).unwrap(); + let stor_packet = parse_storvsp_packet(&packet, &mut self.full_request_pool) + .map_err(StorvscError::PacketError) + .unwrap(); + + match stor_packet.data { + StorvspPacketData::BeginInitialization => { + tracing::debug!("Received BeginInitialization"); + + // Ensure that subsequent calls to `send_completion` won't + // fail due to lack of ring space, to avoid keeping (and saving/restoring) interim states. + poll_fn(|cx| self.inner.poll_for_ring_space(cx, &mut writer)).await?; + + if !has_begin_initialization + && !has_query_protocol_version + && !has_query_properties + && !has_end_initialization + { + has_begin_initialization = true; + self.inner.send_completion( + &mut writer, + &stor_packet, + storvsp_protocol::NtStatus::SUCCESS, + &(), + )?; + } else { + tracing::warn!(data = ?stor_packet.data, "Unexpected initialization packet order"); + self.inner.send_completion( + &mut writer, + &stor_packet, + storvsp_protocol::NtStatus::INVALID_DEVICE_STATE, + &(), + )?; + } + Ok(()) + } + StorvspPacketData::QueryProtocolVersion(major_minor) => { + tracing::debug!(major_minor = major_minor, "Received QueryProtocolVersion"); + if has_begin_initialization + && !has_query_protocol_version + && !has_query_properties + && !has_end_initialization + { + has_query_protocol_version = true; + + if let Ok(version) = Version::parse(major_minor) { + self.inner.send_completion( + &mut writer, + &stor_packet, + storvsp_protocol::NtStatus::SUCCESS, + &storvsp_protocol::ProtocolVersion { + major_minor, + reserved: 0, + }, + )?; + self.inner.request_size = version.max_request_size(); + } else { + self.inner.send_completion( + &mut writer, + &stor_packet, + storvsp_protocol::NtStatus::REVISION_MISMATCH, + &storvsp_protocol::ProtocolVersion { + major_minor, + reserved: 0, + }, + )?; + } + } else { + tracing::warn!(data = ?stor_packet.data, "Unexpected initialization packet order"); + self.inner.send_completion( + &mut writer, + &stor_packet, + storvsp_protocol::NtStatus::INVALID_DEVICE_STATE, + &(), + )?; + } + Ok(()) + } + StorvspPacketData::QueryProperties => { + tracing::debug!("Received QueryProperties"); + if has_begin_initialization + && has_query_protocol_version + && !has_query_properties + && !has_end_initialization + { + has_query_properties = true; + self.inner.send_completion( + &mut writer, + &stor_packet, + storvsp_protocol::NtStatus::SUCCESS, + &storvsp_protocol::ChannelProperties { + max_transfer_bytes: 0x40000, // 256KB + flags: storvsp_protocol::STORAGE_CHANNEL_SUPPORTS_MULTI_CHANNEL, + maximum_sub_channel_count: 16, + reserved: 0, + reserved2: 0, + reserved3: [0, 0], + }, + )?; + } else { + tracing::warn!(data = ?stor_packet.data, "Unexpected initialization packet order"); + self.inner.send_completion( + &mut writer, + &stor_packet, + storvsp_protocol::NtStatus::INVALID_DEVICE_STATE, + &(), + )?; + } + Ok(()) + } + StorvspPacketData::CreateSubChannels(sub_channel_count) => { + tracing::debug!( + sub_channel_count = sub_channel_count, + "Received CreateSubChannels" + ); + self.subchannel_count = sub_channel_count; + self.inner.send_completion( + &mut writer, + &stor_packet, + storvsp_protocol::NtStatus::SUCCESS, + &(), + )?; + Ok(()) + } + StorvspPacketData::EndInitialization => { + tracing::debug!("Received EndInitialization"); + if has_begin_initialization + && has_query_protocol_version + && has_query_properties + && !has_end_initialization + { + has_end_initialization = true; + self.inner.send_completion( + &mut writer, + &stor_packet, + storvsp_protocol::NtStatus::SUCCESS, + &(), + )?; + } else { + tracing::warn!(data = ?stor_packet.data, "Unexpected initialization packet order"); + self.inner.send_completion( + &mut writer, + &stor_packet, + storvsp_protocol::NtStatus::INVALID_DEVICE_STATE, + &(), + )?; + } + Ok(()) + } + _ => { + tracing::warn!(data = ?stor_packet.data, "Unexpected packet received during initialization"); + self.inner.send_completion( + &mut writer, + &stor_packet, + storvsp_protocol::NtStatus::INVALID_DEVICE_STATE, + &(), + )?; + Ok(()) + } + }?; + } + + tracing::info!( + version = self.version.major_minor, + subchannel_count = self.subchannel_count, + "storvsp negoiated" + ); + + Ok(()) + } + + async fn process_packets(&mut self) -> Result<(), StorvscError> { + loop { + let (mut reader, mut writer) = self.queue.split(); + let packet = reader.read().await.map_err(StorvscError::Queue)?; + let stor_packet = parse_storvsp_packet(&packet, &mut self.full_request_pool) + .map_err(StorvscError::PacketError)?; + tracing::info!("storvsp received request packet"); + + match stor_packet.data.clone() { + StorvspPacketData::ExecuteScsi(_request) => { + tracing::info!("storvsp responding to EXECUTE_SRB"); + self.inner.send_completion( + &mut writer, + &stor_packet, + storvsp_protocol::NtStatus::SUCCESS, + &(), + )?; + } + _ => { + tracing::info!("storvsp received unexpected request packet type"); + self.inner.send_completion( + &mut writer, + &stor_packet, + storvsp_protocol::NtStatus::INVALID_DEVICE_STATE, + &(), + )?; + } + } + } + } +} + +impl TestStorvspInner { + fn send_completion( + &mut self, + writer: &mut queue::WriteHalf<'_, M>, + packet: &StorvspPacket, + status: storvsp_protocol::NtStatus, + payload: &P, + ) -> Result<(), StorvscError> { + self.send_vmbus_packet( + &mut writer.batched(), + OutgoingPacketType::Completion, + packet.request_size, + packet.transaction_id, + storvsp_protocol::Operation::COMPLETE_IO, + status, + payload.as_bytes(), + ) + } + + fn send_vmbus_packet( + &mut self, + writer: &mut queue::WriteBatch<'_, M>, + packet_type: OutgoingPacketType<'_>, + _request_size: usize, // Unused, but kept for compatibility with similar APIs + transaction_id: u64, + operation: storvsp_protocol::Operation, + status: storvsp_protocol::NtStatus, + payload: &[u8], + ) -> Result<(), StorvscError> { + let header = storvsp_protocol::Packet { + operation, + flags: 0, + status, + }; + + writer + .try_write(&OutgoingPacket { + transaction_id, + packet_type, + payload: &[header.as_bytes(), payload], + }) + .map_err(|err| match err { + queue::TryWriteError::Full(_) => StorvscError::NotEnoughSpace, + queue::TryWriteError::Queue(err) => StorvscError::Queue(err), + }) + } + + /// Polls for enough ring space in the outgoing ring to send a packet. + /// + /// This is used to ensure there is enough space in the ring before + /// committing to sending a packet. This avoids the need to save pending + /// packets on the side if queue processing is interrupted while the ring is + /// full. + fn poll_for_ring_space( + &mut self, + cx: &mut Context<'_>, + writer: &mut queue::WriteHalf<'_, M>, + ) -> Poll> { + writer + .poll_ready(cx, MAX_VMBUS_PACKET_SIZE) + .map_err(StorvscError::Queue) + } +} From 718a452afce2a8f11145c9d5a12175206e4913d9 Mon Sep 17 00:00:00 2001 From: Eric Newberry Date: Tue, 6 May 2025 22:22:53 +0000 Subject: [PATCH 2/6] Remove unused code --- vm/devices/storage/storvsc_driver/src/lib.rs | 42 -------------------- 1 file changed, 42 deletions(-) diff --git a/vm/devices/storage/storvsc_driver/src/lib.rs b/vm/devices/storage/storvsc_driver/src/lib.rs index e4ee6a85f5..d29013537e 100644 --- a/vm/devices/storage/storvsc_driver/src/lib.rs +++ b/vm/devices/storage/storvsc_driver/src/lib.rs @@ -660,25 +660,13 @@ struct StorvscCompletionPacket { data: [u8; storvsp_protocol::SCSI_REQUEST_LEN_MAX], } -/*#[derive(Debug)] -struct StorvscDataPacket { - transaction_id: u64, - request_size: usize, - operation: storvsp_protocol::Operation, - flags: u32, - status: storvsp_protocol::NtStatus, - data: [u8; storvsp_protocol::SCSI_REQUEST_LEN_MAX], -}*/ - fn parse_packet(packet: &IncomingPacket<'_, T>) -> Result { match packet { IncomingPacket::Completion(completion) => { parse_completion(completion).map_err(StorvscError::PacketError) } IncomingPacket::Data(_) => { - // TODO Err(StorvscError::PacketError(PacketError::InvalidPacketType)) - //parse_data(data).map_err(StorvscError::PacketError) } } } @@ -701,36 +689,6 @@ fn parse_completion(packet: &CompletionPacket<'_, T>) -> Result(packet: &IncomingPacket<'_, T>) -> Result { - let packet = match packet { - IncomingPacket::Completion(_) => return Err(PacketError::InvalidPacketType), - IncomingPacket::Data(packet) => packet, - }; - let transaction_id = packet.transaction_id(); - - let mut reader = packet.reader(); - let header: storvsp_protocol::Packet = reader.read_plain().map_err(PacketError::Access)?; - // You would expect that this should be limited to the current protocol - // version's maximum packet size, but this is not what Hyper-V does, and - // Linux 6.1 relies on this behavior during protocol initialization. - let request_size = reader.len().min(storvsp_protocol::SCSI_REQUEST_LEN_MAX); - let operation = header.operation; - let flags = header.flags; - let status = header.status; - - let mut data = [0_u8; storvsp_protocol::SCSI_REQUEST_LEN_MAX]; - reader.read(&mut data).map_err(PacketError::Access)?; - - Ok(Packet { - transaction_id, - request_size, - operation, - flags, - status, - data, - }) -}*/ - fn expect_success(packet: StorvscCompletionPacket) -> Result { if packet.status != storvsp_protocol::NtStatus::SUCCESS { return Err(PacketError::UnexpectedStatus(packet.status)); From 72f6a83090091249c93b085abbecee74096d7174 Mon Sep 17 00:00:00 2001 From: Eric Newberry Date: Tue, 6 May 2025 23:04:30 +0000 Subject: [PATCH 3/6] Fix formatting --- vm/devices/storage/storvsc_driver/src/lib.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vm/devices/storage/storvsc_driver/src/lib.rs b/vm/devices/storage/storvsc_driver/src/lib.rs index d29013537e..f6900d4bb3 100644 --- a/vm/devices/storage/storvsc_driver/src/lib.rs +++ b/vm/devices/storage/storvsc_driver/src/lib.rs @@ -665,9 +665,7 @@ fn parse_packet(packet: &IncomingPacket<'_, T>) -> Result { parse_completion(completion).map_err(StorvscError::PacketError) } - IncomingPacket::Data(_) => { - Err(StorvscError::PacketError(PacketError::InvalidPacketType)) - } + IncomingPacket::Data(_) => Err(StorvscError::PacketError(PacketError::InvalidPacketType)), } } From 36c4bcebe8b46763ac4bec648d436527084ec469 Mon Sep 17 00:00:00 2001 From: Eric Newberry Date: Tue, 6 May 2025 23:08:02 +0000 Subject: [PATCH 4/6] Fix Cargo.lock --- Cargo.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index e629b352f6..60dd09616c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6586,7 +6586,7 @@ dependencies = [ "vmbus_channel", "vmbus_ring", "vmcore", - "zerocopy 0.8.23", + "zerocopy 0.8.24", ] [[package]] From 3788f7239b1d87c334b0347a6e7b746201a60db6 Mon Sep 17 00:00:00 2001 From: Eric Newberry Date: Mon, 12 May 2025 19:13:52 +0000 Subject: [PATCH 5/6] Add integration test for storvsc with storvsp --- Cargo.lock | 18 +++ Cargo.toml | 2 +- vm/devices/storage/storage_tests/Cargo.toml | 24 ++++ .../storage/storage_tests/tests/storvsc.rs | 134 ++++++++++++++++++ .../storage/storage_tests/tests/tests.rs | 5 + vm/devices/storage/storvsc_driver/src/lib.rs | 46 ++++-- .../storvsc_driver/src/test_helpers.rs | 14 +- vm/devices/storage/storvsp/src/lib.rs | 5 +- .../storage/storvsp/src/test_helpers.rs | 1 - 9 files changed, 227 insertions(+), 22 deletions(-) create mode 100644 vm/devices/storage/storage_tests/Cargo.toml create mode 100644 vm/devices/storage/storage_tests/tests/storvsc.rs create mode 100644 vm/devices/storage/storage_tests/tests/tests.rs diff --git a/Cargo.lock b/Cargo.lock index 60dd09616c..cfa14f5d87 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6563,6 +6563,24 @@ dependencies = [ "zerocopy 0.8.24", ] +[[package]] +name = "storage_tests" +version = "0.0.0" +dependencies = [ + "disklayer_ram", + "guestmem", + "pal_async", + "scsi_defs", + "scsidisk", + "storvsc_driver", + "storvsp", + "storvsp_protocol", + "storvsp_resources", + "test_with_tracing", + "vmbus_channel", + "zerocopy 0.8.24", +] + [[package]] name = "storvsc_driver" version = "0.0.0" diff --git a/Cargo.toml b/Cargo.toml index 45b3ae129c..b82a2d792b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ members = [ "vm/devices/storage/disk_nvme/nvme_driver/fuzz", "vm/devices/storage/ide/fuzz", "vm/devices/storage/scsi_buffers/fuzz", - "vm/devices/storage/storvsc_driver", # TODO: Remove + "vm/devices/storage/storage_tests", "vm/devices/storage/storvsp/fuzz", "vm/vmcore/guestmem/fuzz", "vm/x86/x86emu/fuzz", diff --git a/vm/devices/storage/storage_tests/Cargo.toml b/vm/devices/storage/storage_tests/Cargo.toml new file mode 100644 index 0000000000..ba1e2437c6 --- /dev/null +++ b/vm/devices/storage/storage_tests/Cargo.toml @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +[package] +name = "storage_tests" +edition.workspace = true +rust-version.workspace = true + +[dev-dependencies] +disklayer_ram.workspace = true +guestmem.workspace = true +pal_async.workspace = true +scsi_defs.workspace = true +scsidisk.workspace = true +storvsc_driver.workspace = true +storvsp.workspace = true +storvsp_protocol.workspace = true +storvsp_resources.workspace = true +test_with_tracing.workspace = true +vmbus_channel.workspace = true +zerocopy.workspace = true + +[lints] +workspace = true diff --git a/vm/devices/storage/storage_tests/tests/storvsc.rs b/vm/devices/storage/storage_tests/tests/storvsc.rs new file mode 100644 index 0000000000..7de280b352 --- /dev/null +++ b/vm/devices/storage/storage_tests/tests/storvsc.rs @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Tests of user-mode storvsc implementation with user-mode storvsp. + +use guestmem::GuestMemory; +use pal_async::DefaultDriver; +use pal_async::async_test; +use pal_async::timer::PolledTimer; +use scsi_defs::ScsiOp; +use std::sync::Arc; +use std::time; +use storvsc_driver::test_helpers::TestStorvscWorker; +use storvsp::ScsiController; +use storvsp::ScsiControllerDisk; +use storvsp::test_helpers::TestWorker; +use storvsp_resources::ScsiPath; +use test_with_tracing::test; +use vmbus_channel::connected_async_channels; +use zerocopy::FromZeros; +use zerocopy::IntoBytes; + +// This function assumes the sector size is 512. +fn generate_write_packet( + target_id: u8, + path_id: u8, + lun: u8, + block: u32, + byte_len: usize, +) -> storvsp_protocol::ScsiRequest { + let cdb = scsi_defs::Cdb10 { + operation_code: ScsiOp::WRITE, + logical_block: block.into(), + transfer_blocks: ((byte_len / 512) as u16).into(), + ..FromZeros::new_zeroed() + }; + + let mut scsi_req = storvsp_protocol::ScsiRequest { + target_id, + path_id, + lun, + length: storvsp_protocol::SCSI_REQUEST_LEN_V2 as u16, + cdb_length: size_of::() as u8, + data_transfer_length: byte_len as u32, + ..FromZeros::new_zeroed() + }; + + scsi_req.payload[0..10].copy_from_slice(cdb.as_bytes()); + scsi_req +} + +// This function assumes the sector size is 512. +fn generate_read_packet( + target_id: u8, + path_id: u8, + lun: u8, + block: u32, + byte_len: usize, +) -> storvsp_protocol::ScsiRequest { + let cdb = scsi_defs::Cdb10 { + operation_code: ScsiOp::READ, + logical_block: block.into(), + transfer_blocks: ((byte_len / 512) as u16).into(), + ..FromZeros::new_zeroed() + }; + + let mut scsi_req = storvsp_protocol::ScsiRequest { + target_id, + path_id, + lun, + length: storvsp_protocol::SCSI_REQUEST_LEN_V2 as u16, + cdb_length: size_of::() as u8, + data_transfer_length: byte_len as u32, + ..FromZeros::new_zeroed() + }; + + scsi_req.payload[0..10].copy_from_slice(cdb.as_bytes()); + scsi_req +} + +#[async_test] +async fn test_request_response(driver: DefaultDriver) { + let (host, guest) = connected_async_channels(16 * 1024); + + let test_guest_mem = GuestMemory::allocate(16384); + let controller = ScsiController::new(); + let disk = scsidisk::SimpleScsiDisk::new( + disklayer_ram::ram_disk(10 * 1024 * 1024, false).unwrap(), + Default::default(), + ); + controller + .attach( + ScsiPath { + path: 0, + target: 0, + lun: 0, + }, + ScsiControllerDisk::new(Arc::new(disk)), + ) + .unwrap(); + + let storvsp = TestWorker::start( + controller, + driver.clone(), + test_guest_mem.clone(), + host, + None, + ); + + let mut storvsc = TestStorvscWorker::new(); + storvsc.start(driver.clone(), guest); + + let mut timer = PolledTimer::new(&driver); + timer.sleep(time::Duration::from_secs(1)).await; + + // Send SCSI write request + let write_buf = [7u8; 4096]; + test_guest_mem.write_at(4096, &write_buf).unwrap(); + storvsc + .send_request(&generate_write_packet(0, 1, 2, 4096, 4096), 4096, 4096) + .await + .unwrap(); + + // Send SCSI read request + let write_buf = [7u8; 4096]; + test_guest_mem.write_at(4096, &write_buf).unwrap(); + storvsc + .send_request(&generate_read_packet(0, 1, 2, 4096, 4096), 4096, 4096) + .await + .unwrap(); + + storvsc.teardown().await; + storvsp.teardown_ignore().await; +} diff --git a/vm/devices/storage/storage_tests/tests/tests.rs b/vm/devices/storage/storage_tests/tests/tests.rs new file mode 100644 index 0000000000..e706a1886e --- /dev/null +++ b/vm/devices/storage/storage_tests/tests/tests.rs @@ -0,0 +1,5 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//! Tests for storage devices that don't qualify as unit tests, including integration tests. +mod storvsc; diff --git a/vm/devices/storage/storvsc_driver/src/lib.rs b/vm/devices/storage/storvsc_driver/src/lib.rs index f6900d4bb3..ede5d6d067 100644 --- a/vm/devices/storage/storvsc_driver/src/lib.rs +++ b/vm/devices/storage/storvsc_driver/src/lib.rs @@ -3,8 +3,7 @@ //! Storvsc driver for use as a disk backend. -#[cfg(test)] -mod test_helpers; +pub mod test_helpers; use async_channel::Receiver; use async_channel::RecvError; @@ -19,6 +18,7 @@ use std::collections::HashMap; use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering; use task_control::AsyncRun; +use task_control::InspectTask; use task_control::StopTask; use task_control::TaskControl; use thiserror::Error; @@ -266,6 +266,15 @@ impl AsyncRun> for StorvscState { } } +impl InspectTask> for StorvscState { + fn inspect(&self, req: inspect::Request<'_>, worker: Option<&Storvsc>) { + if let Some(worker) = worker { + let mut resp = req.respond(); + resp.field("has_negotiated", worker.has_negotiated); + } + } +} + impl Storvsc { pub(crate) fn new( channel: RawAsyncChannel, @@ -533,7 +542,6 @@ impl StorvscInner { self.send_vmbus_packet( &mut writer.batched(), OutgoingPacketType::InBandWithCompletion, - payload_bytes.len(), transaction_id, operation, status, @@ -562,7 +570,6 @@ impl StorvscInner { self.send_vmbus_packet( &mut writer.batched(), OutgoingPacketType::GpaDirect(&[pages]), - payload_bytes.len(), transaction_id, operation, status, @@ -576,7 +583,6 @@ impl StorvscInner { &mut self, writer: &mut queue::WriteBatch<'_, M>, packet_type: OutgoingPacketType<'_>, - request_size: usize, transaction_id: u64, operation: storvsp_protocol::Operation, status: storvsp_protocol::NtStatus, @@ -588,22 +594,25 @@ impl StorvscInner { status, }; - let packet_size = size_of_val(&header) + request_size; - - // Zero pad or truncate the payload to the queue's packet size. This is - // necessary because Windows guests check that each packet's size is - // exactly the largest possible packet size for the negotiated protocol - // version. + // storvsp limits the size of the completion packet to the size of the request packet, + // so we need to pad the payload to the maximum size to ensure we get a complete response. + // The maximum size includes header + payload. let len = size_of_val(&header) + size_of_val(payload); let padding = [0; storvsp_protocol::SCSI_REQUEST_LEN_MAX]; - let (payload_bytes, padding_bytes) = if len > packet_size { - (&payload[..packet_size - size_of_val(&header)], &[][..]) + let (payload_bytes, padding_bytes) = if len > storvsp_protocol::SCSI_REQUEST_LEN_MAX { + ( + &payload[..storvsp_protocol::SCSI_REQUEST_LEN_MAX - size_of_val(&header)], + &[][..], + ) } else { - (payload, &padding[..packet_size - len]) + ( + payload, + &padding[..storvsp_protocol::SCSI_REQUEST_LEN_MAX - len], + ) }; assert_eq!( size_of_val(&header) + payload_bytes.len() + padding_bytes.len(), - packet_size + storvsp_protocol::SCSI_REQUEST_LEN_MAX ); writer .try_write(&OutgoingPacket { @@ -796,6 +805,9 @@ mod tests { let mut timer = PolledTimer::new(&driver); timer.sleep(time::Duration::from_secs(1)).await; + storvsc.stop().await; + assert!(storvsc.get_mut().has_negotiated); + storvsc.teardown().await; storvsp.teardown().await; } @@ -818,6 +830,10 @@ mod tests { let mut timer = PolledTimer::new(&driver); timer.sleep(time::Duration::from_secs(1)).await; + storvsc.stop().await; + assert!(storvsc.get_mut().has_negotiated); + storvsc.resume().await; + // Send SCSI write request let write_buf = [7u8; 4096]; test_guest_mem.write_at(4096, &write_buf).unwrap(); diff --git a/vm/devices/storage/storvsc_driver/src/test_helpers.rs b/vm/devices/storage/storvsc_driver/src/test_helpers.rs index 1476d949ac..a8a3c18ab4 100644 --- a/vm/devices/storage/storvsc_driver/src/test_helpers.rs +++ b/vm/devices/storage/storvsc_driver/src/test_helpers.rs @@ -208,7 +208,7 @@ fn parse_storvsp_packet( }) } -pub(crate) struct TestStorvscWorker { +pub struct TestStorvscWorker { task: TaskControl>, new_request_sender: Option>, } @@ -239,6 +239,18 @@ impl TestStorvscWorker { self.task.start(); } + pub async fn stop(&mut self) { + self.task.stop().await; + } + + pub async fn resume(&mut self) { + self.task.start(); + } + + pub fn get_mut(&mut self) -> &Storvsc { + self.task.get_mut().1.unwrap() + } + pub async fn teardown(&mut self) { self.task.stop().await; self.task.remove(); diff --git a/vm/devices/storage/storvsp/src/lib.rs b/vm/devices/storage/storvsp/src/lib.rs index 87004a9686..3869086a60 100644 --- a/vm/devices/storage/storvsp/src/lib.rs +++ b/vm/devices/storage/storvsp/src/lib.rs @@ -7,12 +7,9 @@ #[cfg(feature = "ioperf")] pub mod ioperf; -#[cfg(feature = "fuzz_helpers")] +// Needs to be pub for use with storage tests in a separate crate. pub mod test_helpers; -#[cfg(not(feature = "fuzz_helpers"))] -mod test_helpers; - pub mod resolver; mod save_restore; diff --git a/vm/devices/storage/storvsp/src/test_helpers.rs b/vm/devices/storage/storvsp/src/test_helpers.rs index b156349946..b9681a4348 100644 --- a/vm/devices/storage/storvsp/src/test_helpers.rs +++ b/vm/devices/storage/storvsp/src/test_helpers.rs @@ -49,7 +49,6 @@ impl TestWorker { /// Like `teardown`, but ignore the result. Nice for the fuzzer, /// so that the `storvsp` crate doesn't need to expose `WorkerError` /// as pub. - #[cfg(feature = "fuzz_helpers")] pub async fn teardown_ignore(self) { let _ = self.task.await; } From a81cdad02d406e2192887f3e527ca2ba7ae4cf5f Mon Sep 17 00:00:00 2001 From: Eric Newberry Date: Mon, 12 May 2025 20:37:55 +0000 Subject: [PATCH 6/6] Fix missing docs and other errors --- vm/devices/storage/storvsc_driver/src/test_helpers.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vm/devices/storage/storvsc_driver/src/test_helpers.rs b/vm/devices/storage/storvsc_driver/src/test_helpers.rs index a8a3c18ab4..f38773618b 100644 --- a/vm/devices/storage/storvsc_driver/src/test_helpers.rs +++ b/vm/devices/storage/storvsc_driver/src/test_helpers.rs @@ -4,6 +4,7 @@ //! Helpers for unit tests. #![cfg_attr(not(test), expect(dead_code))] +#![expect(missing_docs)] use crate::PacketError; use crate::Storvsc; @@ -247,7 +248,7 @@ impl TestStorvscWorker { self.task.start(); } - pub fn get_mut(&mut self) -> &Storvsc { + pub(crate) fn get_mut(&mut self) -> &Storvsc { self.task.get_mut().1.unwrap() }