diff --git a/protocol/protos/NetRemoteDataStream.proto b/protocol/protos/NetRemoteDataStream.proto index d0585dc5..b60dd9de 100644 --- a/protocol/protos/NetRemoteDataStream.proto +++ b/protocol/protos/NetRemoteDataStream.proto @@ -8,7 +8,7 @@ enum DataStreamOperationStatusCode DataStreamOperationStatusCodeUnknown = 0; DataStreamOperationStatusCodeSucceeded = 1; DataStreamOperationStatusCodeFailed = 2; - DataStreamOperationStatusCodeCancelled = 3; + DataStreamOperationStatusCodeCanceled = 3; DataStreamOperationStatusCodeTimedOut = 4; } @@ -23,8 +23,54 @@ message DataStreamUploadData bytes Data = 1; } +message DataStreamDownloadData +{ + DataStreamOperationStatus Status = 1; + bytes Data = 2; + uint32 SequenceNumber = 3; +} + message DataStreamUploadResult { - uint32 NumberOfDataBlocksReceived = 1; - DataStreamOperationStatus Status = 2; + DataStreamOperationStatus Status = 1; + uint32 NumberOfDataBlocksReceived = 2; +} + +enum DataStreamType +{ + DataStreamTypeUnknown = 0; + DataStreamTypeFixed = 1; + DataStreamTypeContinuous = 2; +} + +enum DataStreamPattern +{ + DataStreamPatternUnknown = 0; + DataStreamPatternConstant = 1; +} + +message DataStreamFixedTypeProperties +{ + uint32 NumberOfDataBlocksToStream = 1; +} + +message DataStreamContinuousTypeProperties +{ + +} + +message DataStreamProperties +{ + DataStreamType Type = 1; + DataStreamPattern Pattern = 2; + oneof Properties + { + DataStreamFixedTypeProperties Fixed = 3; + DataStreamContinuousTypeProperties Continuous = 4; + } +} + +message DataStreamDownloadRequest +{ + DataStreamProperties Properties = 1; } diff --git a/protocol/protos/NetRemoteDataStreamingService.proto b/protocol/protos/NetRemoteDataStreamingService.proto index fea41e1a..99c23ca3 100644 --- a/protocol/protos/NetRemoteDataStreamingService.proto +++ b/protocol/protos/NetRemoteDataStreamingService.proto @@ -9,4 +9,5 @@ import "NetRemoteDataStream.proto"; service NetRemoteDataStreaming { rpc DataStreamUpload (stream Microsoft.Net.Remote.DataStream.DataStreamUploadData) returns (Microsoft.Net.Remote.DataStream.DataStreamUploadResult); + rpc DataStreamDownload (Microsoft.Net.Remote.DataStream.DataStreamDownloadRequest) returns (stream Microsoft.Net.Remote.DataStream.DataStreamDownloadData); } diff --git a/src/common/service/NetRemoteDataStreamingReactors.cxx b/src/common/service/NetRemoteDataStreamingReactors.cxx index f2977439..75ecd33f 100644 --- a/src/common/service/NetRemoteDataStreamingReactors.cxx +++ b/src/common/service/NetRemoteDataStreamingReactors.cxx @@ -1,18 +1,66 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + #include "NetRemoteDataStreamingReactors.hxx" +using logging::FunctionTracer; + +namespace Microsoft::Net::Remote::Service::Reactors::Helpers +{ +DataGenerator::DataGenerator() +{ + m_generator.seed(std::random_device{}()); +} + +std::string +DataGenerator::GenerateRandomData(const std::size_t length) +{ + std::string result; + result.reserve(length); + + for (std::size_t i = 0; i < length; i++) { + result.push_back(static_cast(GetRandomByte())); + } + + return result; +} + +uint8_t +DataGenerator::GetRandomByte() +{ + std::uniform_int_distribution distribution(0, std::numeric_limits::max()); + return static_cast(distribution(m_generator)); +} +} // namespace Microsoft::Net::Remote::Service::Reactors::Helpers + using namespace Microsoft::Net::Remote::DataStream; using namespace Microsoft::Net::Remote::Service::Reactors; DataStreamReader::DataStreamReader(DataStreamUploadResult* result) : m_result(result) { + const FunctionTracer traceMe{}; StartRead(&m_data); } void DataStreamReader::OnReadDone(bool isOk) { + const FunctionTracer traceMe{}; + if (isOk) { m_numberOfDataBlocksReceived++; m_readStatus.set_code(DataStreamOperationStatusCode::DataStreamOperationStatusCodeSucceeded); @@ -30,9 +78,11 @@ DataStreamReader::OnReadDone(bool isOk) void DataStreamReader::OnCancel() { + const FunctionTracer traceMe{}; + m_result->set_numberofdatablocksreceived(m_numberOfDataBlocksReceived); - m_readStatus.set_code(DataStreamOperationStatusCode::DataStreamOperationStatusCodeCancelled); - m_readStatus.set_message("RPC cancelled"); + m_readStatus.set_code(DataStreamOperationStatusCode::DataStreamOperationStatusCodeCanceled); + m_readStatus.set_message("RPC canceled"); *m_result->mutable_status() = std::move(m_readStatus); Finish(grpc::Status::CANCELLED); } @@ -40,5 +90,148 @@ DataStreamReader::OnCancel() void DataStreamReader::OnDone() { + const FunctionTracer traceMe{}; delete this; } + +DataStreamWriter::DataStreamWriter(const DataStreamDownloadRequest* request) : + m_dataStreamProperties(request->properties()) +{ + const FunctionTracer traceMe{}; + + switch (m_dataStreamProperties.type()) { + case DataStreamType::DataStreamTypeFixed: { + if (m_dataStreamProperties.Properties_case() == DataStreamProperties::kFixed) { + m_numberOfDataBlocksToStream = m_dataStreamProperties.fixed().numberofdatablockstostream(); + } else { + HandleFailure("Invalid properties for this streaming type. Expected Fixed for DataStreamTypeFixed"); + return; + } + + break; + } + case DataStreamType::DataStreamTypeContinuous: { + if (m_dataStreamProperties.Properties_case() == DataStreamProperties::kContinuous) { + m_numberOfDataBlocksToStream = 0; + } else { + HandleFailure("Invalid properties for this streaming type. Expected Continuous for DataStreamTypeContinuous"); + return; + } + + break; + } + default: { + HandleFailure(std::format("Invalid streaming type: {}", magic_enum::enum_name(m_dataStreamProperties.type()))); + return; + } + }; + + m_writeStatus.set_code(DataStreamOperationStatusCode::DataStreamOperationStatusCodeUnknown); + m_writeStatus.set_message("No data sent yet"); + NextWrite(); +} + +void +DataStreamWriter::OnWriteDone(bool isOk) +{ + const FunctionTracer traceMe{}; + + // Client may have canceled the RPC, so check for cancelation to prevent writing more data + // when we shouldn't. + if (m_isCanceled.load(std::memory_order_relaxed)) { + LOGD << "RPC canceled, returning early"; + return; + } + + // Check for a failed status code from HandleWriteFailure since that invoked a final write, thus causing this callback to be invoked. + if (m_writeStatus.code() == DataStreamOperationStatusCode::DataStreamOperationStatusCodeFailed) { + Finish(::grpc::Status::OK); + return; + } + + // Continue writing if previous write was successful, otherwise handle the failure. + if (isOk) { + if (m_dataStreamProperties.type() == DataStreamType::DataStreamTypeFixed) { + m_numberOfDataBlocksToStream--; + } + m_writeStatus.set_code(DataStreamOperationStatusCode::DataStreamOperationStatusCodeSucceeded); + m_writeStatus.set_message("Data write successful"); + NextWrite(); + } else { + HandleFailure("Data write failed"); + } +} + +void +DataStreamWriter::OnCancel() +{ + const FunctionTracer traceMe{}; + + // The RPC is canceled by the client, so call Finish to complete it from the server perspective. + bool isCanceledExpected{ false }; + if (m_isCanceled.compare_exchange_strong(isCanceledExpected, true, std::memory_order_relaxed, std::memory_order_relaxed)) { + Finish(grpc::Status::CANCELLED); + } +} + +void +DataStreamWriter::OnDone() +{ + const FunctionTracer traceMe{}; + delete this; +} + +void +DataStreamWriter::NextWrite() +{ + const FunctionTracer traceMe{}; + + // Client may have canceled the RPC, so check for cancelation to prevent writing more data + // when we shouldn't. + if (m_isCanceled.load(std::memory_order_relaxed)) { + LOGD << "RPC canceled, aborting write"; + return; + } + + if (m_dataStreamProperties.type() == DataStreamType::DataStreamTypeContinuous || + (m_dataStreamProperties.type() == DataStreamType::DataStreamTypeFixed && m_numberOfDataBlocksToStream > 0)) { + // Check the requested data streaming pattern. + const auto pattern = m_dataStreamProperties.pattern(); + + switch (pattern) { + // Generate data with the constant default size and write to the client at a constant speed. + case DataStreamPattern::DataStreamPatternConstant: { + // Create data to write to the client. + const auto data = m_dataGenerator.GenerateRandomData(); + m_numberOfDataBlocksWritten++; + + // Write data to the client. + m_data.set_data(data); + m_data.set_sequencenumber(m_numberOfDataBlocksWritten); + *m_data.mutable_status() = m_writeStatus; + StartWrite(&m_data); + + break; + } + default: + HandleFailure(std::format("Unexpected data stream pattern {}", magic_enum::enum_name(pattern))); + }; + } else { + // No more data to write. + Finish(::grpc::Status::OK); + } +} + +void +DataStreamWriter::HandleFailure(const std::string& errorMessage) +{ + const FunctionTracer traceMe{}; + + m_writeStatus.set_code(DataStreamOperationStatusCode::DataStreamOperationStatusCodeFailed); + m_writeStatus.set_message(errorMessage); + *m_data.mutable_status() = m_writeStatus; + + // Write a final message to the client. The OnWriteDone() callback will check for the + // DataStreamOperationStatusCodeFailed status code set here to know to complete the RPC. + StartWrite(&m_data); +} diff --git a/src/common/service/NetRemoteDataStreamingReactors.hxx b/src/common/service/NetRemoteDataStreamingReactors.hxx index 990e0597..2fbd2655 100644 --- a/src/common/service/NetRemoteDataStreamingReactors.hxx +++ b/src/common/service/NetRemoteDataStreamingReactors.hxx @@ -2,11 +2,53 @@ #ifndef NET_REMOTE_DATA_STREAMING_REACTORS_HXX #define NET_REMOTE_DATA_STREAMING_REACTORS_HXX +#include #include +#include +#include #include #include +namespace Microsoft::Net::Remote::Service::Reactors::Helpers +{ +/** + * @brief A simple random data generator. + */ +class DataGenerator +{ +public: + static constexpr std::size_t DefaultDataLength{ 8 }; + + /** + * @brief Construct a DataGenerator object. + * + */ + explicit DataGenerator(); + + /** + * @brief Generate a random data string of the specified length. + * + * @param length The length of the random data string. + * @return std::string + */ + std::string + GenerateRandomData(const std::size_t length = DefaultDataLength); + +private: + /** + * @brief Generate a random byte of data. + * + * @return uint8_t + */ + uint8_t + GetRandomByte(); + +private: + std::mt19937 m_generator{}; +}; +} // namespace Microsoft::Net::Remote::Service::Reactors::Helpers + namespace Microsoft::Net::Remote::Service::Reactors { /** @@ -32,7 +74,7 @@ public: OnReadDone(bool isOk) override; /** - * @brief Callback that is executed when an RPC is cancelled before successfully sending a status to the client. + * @brief Callback that is executed when an RPC is canceled before successfully sending a status to the client. */ void OnCancel() override; @@ -49,6 +91,65 @@ private: uint32_t m_numberOfDataBlocksReceived{}; Microsoft::Net::Remote::DataStream::DataStreamOperationStatus m_readStatus{}; }; + +/** + * @brief Implementation of the gRPC ServerWriteReactor for server-side data stream writing. + */ +class DataStreamWriter : + public grpc::ServerWriteReactor +{ +public: + /** + * @brief Construct a new DataStreamWriter object with the specified download request. + * + * @param request The download request from the client. + */ + explicit DataStreamWriter(const Microsoft::Net::Remote::DataStream::DataStreamDownloadRequest* request); + + /** + * @brief Callback that is executed when a write operation is completed. + * + * @param isOk Indicates whether a write was successfully sent. + */ + void + OnWriteDone(bool isOk) override; + + /** + * @brief Callback that is executed when an RPC is canceled before successfully sending a status to the client. + */ + void + OnCancel() override; + + /** + * @brief Callback that is executed when all RPC operations are completed for a given RPC. + */ + void + OnDone() override; + +private: + /** + * @brief Facilitate the next write operation. + */ + void + NextWrite(); + + /** + * @brief Handle a failed operation. + * + * @param errorMessage The error message associated with the failed operation. + */ + void + HandleFailure(const std::string& errorMessage); + +private: + Microsoft::Net::Remote::DataStream::DataStreamDownloadData m_data{}; + Microsoft::Net::Remote::DataStream::DataStreamProperties m_dataStreamProperties{}; + uint32_t m_numberOfDataBlocksToStream{}; + uint32_t m_numberOfDataBlocksWritten{}; + Microsoft::Net::Remote::DataStream::DataStreamOperationStatus m_writeStatus{}; + std::atomic m_isCanceled{}; + Microsoft::Net::Remote::Service::Reactors::Helpers::DataGenerator m_dataGenerator{}; +}; } // namespace Microsoft::Net::Remote::Service::Reactors #endif // NET_REMOTE_DATA_STREAMING_REACTORS_HXX diff --git a/src/common/service/NetRemoteDataStreamingService.cxx b/src/common/service/NetRemoteDataStreamingService.cxx index 8be58875..0ce6384e 100644 --- a/src/common/service/NetRemoteDataStreamingService.cxx +++ b/src/common/service/NetRemoteDataStreamingService.cxx @@ -1,12 +1,30 @@ -#include "NetRemoteDataStreamingReactors.hxx" +#include + +#include +#include #include +#include + +#include "NetRemoteApiTrace.hxx" +#include "NetRemoteDataStreamingReactors.hxx" using namespace Microsoft::Net::Remote::DataStream; using namespace Microsoft::Net::Remote::Service; +using namespace Microsoft::Net::Remote::Service::Tracing; grpc::ServerReadReactor* NetRemoteDataStreamingService::DataStreamUpload([[maybe_unused]] grpc::CallbackServerContext* context, DataStreamUploadResult* result) { - return new Reactors::DataStreamReader(result); + const NetRemoteApiTrace traceMe{}; + + return std::make_unique(result).release(); +} + +grpc::ServerWriteReactor* +NetRemoteDataStreamingService::DataStreamDownload([[maybe_unused]] grpc::CallbackServerContext* context, const DataStreamDownloadRequest* request) +{ + const NetRemoteApiTrace traceMe{}; + + return std::make_unique(request).release(); } diff --git a/src/common/service/include/microsoft/net/remote/NetRemoteDataStreamingService.hxx b/src/common/service/include/microsoft/net/remote/NetRemoteDataStreamingService.hxx index af107275..e902bc8a 100644 --- a/src/common/service/include/microsoft/net/remote/NetRemoteDataStreamingService.hxx +++ b/src/common/service/include/microsoft/net/remote/NetRemoteDataStreamingService.hxx @@ -2,6 +2,8 @@ #ifndef NET_REMOTE_DATA_STREAMING_SERVICE_HXX #define NET_REMOTE_DATA_STREAMING_SERVICE_HXX +#include +#include #include #include @@ -29,6 +31,16 @@ private: */ grpc::ServerReadReactor* DataStreamUpload(grpc::CallbackServerContext* context, Microsoft::Net::Remote::DataStream::DataStreamUploadResult* result) override; + + /** + * @brief Stream data from the server to the client. + * + * @param context + * @param request + * @return grpc::ServerWriteReactor* + */ + grpc::ServerWriteReactor* + DataStreamDownload(grpc::CallbackServerContext* context, const Microsoft::Net::Remote::DataStream::DataStreamDownloadRequest* request) override; }; } // namespace Microsoft::Net::Remote::Service diff --git a/tests/unit/TestNetRemoteDataStreamingReactors.cxx b/tests/unit/TestNetRemoteDataStreamingReactors.cxx index eb9a8bac..3f7d2d7d 100644 --- a/tests/unit/TestNetRemoteDataStreamingReactors.cxx +++ b/tests/unit/TestNetRemoteDataStreamingReactors.cxx @@ -1,6 +1,13 @@ -#include +#include #include +#include +#include + +#include +#include +#include +#include #include "TestNetRemoteDataStreamingReactors.hxx" @@ -8,8 +15,6 @@ using namespace Microsoft::Net::Remote::DataStream; using namespace Microsoft::Net::Remote::Service; using namespace Microsoft::Net::Remote::Test; -using namespace std::chrono_literals; - DataStreamWriter::DataStreamWriter(NetRemoteDataStreaming::Stub* client, uint32_t numberOfDataBlocksToWrite) : m_numberOfDataBlocksToWrite(numberOfDataBlocksToWrite) { @@ -31,7 +36,7 @@ DataStreamWriter::OnWriteDone(bool isOk) void DataStreamWriter::OnDone(const grpc::Status& status) { - std::unique_lock lock(m_writeStatusGate); + const std::unique_lock lock(m_writeStatusGate); m_status = status; m_done = true; @@ -42,9 +47,8 @@ grpc::Status DataStreamWriter::Await(DataStreamUploadResult* result) { std::unique_lock lock(m_writeStatusGate); - static constexpr auto timeoutValue = 10s; - const auto isDone = m_writesDone.wait_for(lock, timeoutValue, [this] { + const auto isDone = m_writesDone.wait_for(lock, m_writesDoneTimeoutValue, [this] { return m_done; }); @@ -70,3 +74,68 @@ DataStreamWriter::NextWrite() StartWritesDone(); } } + +DataStreamReader::DataStreamReader(NetRemoteDataStreaming::Stub* client, DataStreamDownloadRequest* request) +{ + client->async()->DataStreamDownload(&m_clientContext, request, this); + StartCall(); + StartRead(&m_data); +} + +void +DataStreamReader::OnReadDone(bool isOk) +{ + if (isOk) { + m_numberOfDataBlocksReceived++; + StartRead(&m_data); + } + // If read fails, then there is likely no more data to be read, so do nothing. +} + +void +DataStreamReader::OnDone(const grpc::Status& status) +{ + const std::unique_lock lock(m_readStatusGate); + + m_status = status; + m_done = true; + m_readsDone.notify_one(); +} + +grpc::Status +DataStreamReader::Await(uint32_t* numberOfDataBlocksReceived, DataStreamOperationStatus* operationStatus) +{ + std::unique_lock lock(m_readStatusGate); + + const auto isDone = m_readsDone.wait_for(lock, m_readsDoneTimeoutValue, [this] { + return m_done; + }); + + // Handle timeout from waiting for reads to be completed. + if (!isDone) { + DataStreamOperationStatus status{}; + status.set_code(DataStreamOperationStatusCode::DataStreamOperationStatusCodeTimedOut); + status.set_message("Timeout occurred while waiting for all reads to be completed"); + *m_data.mutable_status() = std::move(status); + } + + // Handle mismatched sequence number and number of data blocks received. + if (m_data.sequencenumber() != m_numberOfDataBlocksReceived) { + DataStreamOperationStatus status{}; + status.set_code(DataStreamOperationStatusCode::DataStreamOperationStatusCodeFailed); + status.set_message(std::format("Sequence number {} does not match the number of data blocks received {}", m_data.sequencenumber(), m_numberOfDataBlocksReceived)); + *m_data.mutable_status() = std::move(status); + } + + *numberOfDataBlocksReceived = m_numberOfDataBlocksReceived; + *operationStatus = m_data.status(); + + return m_status; +} + +void +DataStreamReader::Cancel() +{ + LOGD << "Attempting to cancel RPC"; + m_clientContext.TryCancel(); +} diff --git a/tests/unit/TestNetRemoteDataStreamingReactors.hxx b/tests/unit/TestNetRemoteDataStreamingReactors.hxx index 624790a4..729fea9c 100644 --- a/tests/unit/TestNetRemoteDataStreamingReactors.hxx +++ b/tests/unit/TestNetRemoteDataStreamingReactors.hxx @@ -2,6 +2,7 @@ #ifndef TEST_NET_REMOTE_DATA_STREAMING_REACTORS_HXX #define TEST_NET_REMOTE_DATA_STREAMING_REACTORS_HXX +#include #include #include #include @@ -9,6 +10,8 @@ #include #include +using namespace std::chrono_literals; + namespace Microsoft::Net::Remote::Test { /** @@ -59,6 +62,8 @@ private: NextWrite(); private: + static inline constexpr auto DefaultTimeoutValue{ 10s }; + grpc::ClientContext m_clientContext{}; Microsoft::Net::Remote::DataStream::DataStreamUploadData m_data{}; Microsoft::Net::Remote::DataStream::DataStreamUploadResult m_result{}; @@ -68,6 +73,67 @@ private: std::mutex m_writeStatusGate{}; std::condition_variable m_writesDone{}; bool m_done{ false }; + std::chrono::duration m_writesDoneTimeoutValue{ DefaultTimeoutValue }; +}; + +/** + * @brief Implementation of the gRPC ClientReadReactor for client-side data stream reading. + */ +class DataStreamReader : + public grpc::ClientReadReactor +{ +public: + /** + * @brief Construct a new DataStreamReader object with the client stub and specified number of data blocks to write. + * + * @param client The data streaming client stub. + * @param request The download request to be sent to the server. + */ + explicit DataStreamReader(Microsoft::Net::Remote::Service::NetRemoteDataStreaming::Stub* client, Microsoft::Net::Remote::DataStream::DataStreamDownloadRequest* request); + + /** + * @brief Callback that is executed when a read operation is completed. + * + * @param isOk Indicates whether a message was read as expected. + */ + void + OnReadDone(bool isOk) override; + + /** + * @brief Callback that is executed when all RPC operations are completed for a given RPC. + * + * @param status The status of the RPC sent by the server or provided by the library to indicate a failure. + */ + void + OnDone(const grpc::Status& status) override; + + /** + * @brief Wait for all operations to complete and transfer the resulting information to the output parameters. + * + * @param numberOfDataBlocksReceived The number of data blocks received by the client. + * @param operationStatus The status of the data stream read operation. + * @return grpc::Status + */ + grpc::Status + Await(uint32_t* numberOfDataBlocksReceived, Microsoft::Net::Remote::DataStream::DataStreamOperationStatus* operationStatus); + + /** + * @brief Cancel the ongoing RPC. + */ + void + Cancel(); + +private: + static inline constexpr auto DefaultTimeoutValue{ 10s }; + + grpc::ClientContext m_clientContext{}; + Microsoft::Net::Remote::DataStream::DataStreamDownloadData m_data{}; + uint32_t m_numberOfDataBlocksReceived{}; + grpc::Status m_status{}; + std::mutex m_readStatusGate{}; + std::condition_variable m_readsDone{}; + bool m_done{ false }; + std::chrono::duration m_readsDoneTimeoutValue{ DefaultTimeoutValue }; }; } // namespace Microsoft::Net::Remote::Test diff --git a/tests/unit/TestNetRemoteDataStreamingServiceClient.cxx b/tests/unit/TestNetRemoteDataStreamingServiceClient.cxx index 8b17da45..70f3413c 100644 --- a/tests/unit/TestNetRemoteDataStreamingServiceClient.cxx +++ b/tests/unit/TestNetRemoteDataStreamingServiceClient.cxx @@ -1,14 +1,19 @@ -#include +#include +#include #include -#include #include #include #include #include #include +#include +#include +#include #include +#include +#include #include #include "TestNetRemoteCommon.hxx" @@ -23,7 +28,7 @@ TEST_CASE("DataStreamUpload API", "[basic][rpc][client][remote][stream]") using Microsoft::Net::Remote::Test::DataStreamWriter; using Microsoft::Net::Remote::Test::RemoteServiceAddressHttp; - NetRemoteServerConfiguration Configuration{ + const NetRemoteServerConfiguration Configuration{ .ServerAddress = RemoteServiceAddressHttp, }; @@ -40,7 +45,7 @@ TEST_CASE("DataStreamUpload API", "[basic][rpc][client][remote][stream]") auto dataStreamWriter = std::make_unique(client.get(), numberOfDataBlocksToWrite); DataStreamUploadResult result{}; - grpc::Status status = dataStreamWriter->Await(&result); + const grpc::Status status = dataStreamWriter->Await(&result); REQUIRE(status.ok()); REQUIRE(result.numberofdatablocksreceived() == numberOfDataBlocksToWrite); REQUIRE(result.status().code() == DataStreamOperationStatusCodeSucceeded); @@ -59,7 +64,7 @@ TEST_CASE("DataStreamUpload API", "[basic][rpc][client][remote][stream]") clientThreads.emplace_back([dataStreamWriter = std::move(dataStreamWriter)]() { DataStreamUploadResult result{}; - grpc::Status status = dataStreamWriter->Await(&result); + const grpc::Status status = dataStreamWriter->Await(&result); REQUIRE(status.ok()); REQUIRE(result.numberofdatablocksreceived() == numberOfDataBlocksToWrite); REQUIRE(result.status().code() == DataStreamOperationStatusCodeSucceeded); @@ -71,3 +76,75 @@ TEST_CASE("DataStreamUpload API", "[basic][rpc][client][remote][stream]") } } } + +TEST_CASE("DataStreamDownload API", "[basic][rpc][client][remote][stream]") +{ + using namespace Microsoft::Net::Remote; + using namespace Microsoft::Net::Remote::DataStream; + using namespace Microsoft::Net::Remote::Service; + + using Microsoft::Net::Remote::Test::DataStreamReader; + using Microsoft::Net::Remote::Test::RemoteServiceAddressHttp; + + const NetRemoteServerConfiguration Configuration{ + .ServerAddress = RemoteServiceAddressHttp, + }; + + NetRemoteServer server{ Configuration }; + server.Run(); + + auto channel = grpc::CreateChannel(RemoteServiceAddressHttp, grpc::InsecureChannelCredentials()); + auto client = NetRemoteDataStreaming::NewStub(channel); + + static constexpr auto fixedNumberOfDataBlocksToStream = 10; + + SECTION("Can be called with DataStreamTypeFixed and DataStreamPatternConstant") + { + DataStreamFixedTypeProperties fixedTypeProperties{}; + fixedTypeProperties.set_numberofdatablockstostream(fixedNumberOfDataBlocksToStream); + + DataStreamProperties properties{}; + properties.set_type(DataStreamType::DataStreamTypeFixed); + properties.set_pattern(DataStreamPattern::DataStreamPatternConstant); + *properties.mutable_fixed() = std::move(fixedTypeProperties); + + DataStreamDownloadRequest request{}; + *request.mutable_properties() = std::move(properties); + + DataStreamReader dataStreamReader{ client.get(), &request }; + + uint32_t numberOfDataBlocksReceived{}; + DataStreamOperationStatus operationStatus{}; + const grpc::Status status = dataStreamReader.Await(&numberOfDataBlocksReceived, &operationStatus); + REQUIRE(status.ok()); + REQUIRE(numberOfDataBlocksReceived == fixedNumberOfDataBlocksToStream); + REQUIRE(operationStatus.code() == DataStreamOperationStatusCodeSucceeded); + } + + SECTION("Can be called with DataStreamTypeContinuous and DataStreamPatternConstant") + { + static constexpr auto StreamingDelayTime = 5s; + + DataStreamContinuousTypeProperties continuousTypeProperties{}; + + DataStreamProperties properties{}; + properties.set_type(DataStreamType::DataStreamTypeContinuous); + properties.set_pattern(DataStreamPattern::DataStreamPatternConstant); + *properties.mutable_continuous() = std::move(continuousTypeProperties); + + DataStreamDownloadRequest request{}; + *request.mutable_properties() = std::move(properties); + + DataStreamReader dataStreamReader{ client.get(), &request }; + + // Allow some time of continuous streaming by the server, then cancel the RPC. + std::this_thread::sleep_for(StreamingDelayTime); + dataStreamReader.Cancel(); + + uint32_t numberOfDataBlocksReceived{}; + DataStreamOperationStatus operationStatus{}; + const grpc::Status status = dataStreamReader.Await(&numberOfDataBlocksReceived, &operationStatus); + REQUIRE(status.error_code() == grpc::StatusCode::CANCELLED); + REQUIRE(operationStatus.code() == DataStreamOperationStatusCodeSucceeded); + } +}