diff --git a/src/common/service/NetRemoteDataStreamingReactors.cxx b/src/common/service/NetRemoteDataStreamingReactors.cxx index 2335edda..a044294a 100644 --- a/src/common/service/NetRemoteDataStreamingReactors.cxx +++ b/src/common/service/NetRemoteDataStreamingReactors.cxx @@ -303,11 +303,10 @@ DataStreamReaderWriter::OnCancel() { const FunctionTracer traceMe{}; - // The RPC is canceled by the client, so call Finish to complete it from the server perspective. + // Because the RPC is canceled by the client, there will be no more data to read. Thus, OnReadDone will fail and will + // call Finish(), so don't call Finish() here. bool isCanceledExpected{ false }; - if (m_isCanceled.compare_exchange_strong(isCanceledExpected, true, std::memory_order_relaxed, std::memory_order_relaxed)) { - Finish(grpc::Status::CANCELLED); - } + m_isCanceled.compare_exchange_strong(isCanceledExpected, true, std::memory_order_relaxed, std::memory_order_relaxed); } void diff --git a/tests/unit/TestNetRemoteDataStreamingReactors.cxx b/tests/unit/TestNetRemoteDataStreamingReactors.cxx index 4f11763d..5fd34a6b 100644 --- a/tests/unit/TestNetRemoteDataStreamingReactors.cxx +++ b/tests/unit/TestNetRemoteDataStreamingReactors.cxx @@ -1,5 +1,5 @@ -#include +#include #include #include #include @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -143,9 +144,36 @@ DataStreamReader::Cancel() m_clientContext.TryCancel(); } -DataStreamReaderWriter::DataStreamReaderWriter(NetRemoteDataStreaming::Stub* client, uint32_t numberOfDataBlocksToWrite) : - m_numberOfDataBlocksToWrite(numberOfDataBlocksToWrite) +DataStreamReaderWriter::DataStreamReaderWriter(NetRemoteDataStreaming::Stub* client, DataStreamProperties dataStreamProperties) : + m_dataStreamProperties(std::move(dataStreamProperties)) { + switch (m_dataStreamProperties.type()) { + case DataStreamType::DataStreamTypeFixed: { + if (m_dataStreamProperties.Properties_case() == DataStreamProperties::kFixed) { + m_numberOfDataBlocksToWrite = m_dataStreamProperties.fixed().numberofdatablockstostream(); + } else { + LOGE << "Invalid properties for this streaming type. Expected Fixed for DataStreamTypeFixed"; + return; + } + + break; + } + case DataStreamType::DataStreamTypeContinuous: { + if (m_dataStreamProperties.Properties_case() == DataStreamProperties::kContinuous) { + m_numberOfDataBlocksToWrite = 0; + } else { + LOGE << "Invalid properties for this streaming type. Expected Continuous for DataStreamTypeContinuous"; + return; + } + + break; + } + default: { + LOGE << std::format("Invalid streaming type: {}", magic_enum::enum_name(m_dataStreamProperties.type())); + return; + } + }; + client->async()->DataStreamBidirectional(&m_clientContext, this); StartCall(); StartRead(&m_readData); @@ -173,9 +201,16 @@ void DataStreamReaderWriter::OnWriteDone(bool isOk) { if (isOk) { + if (m_dataStreamProperties.type() == DataStreamType::DataStreamTypeFixed) { + m_numberOfDataBlocksToWrite--; + } NextWrite(); } else { - StartWritesDone(); + // If StopWrites() was called and continuous data streaming is used, then StartWritesDone() + // was already called. + if (!m_writesStopped.load(std::memory_order_relaxed)) { + StartWritesDone(); + } } } @@ -213,12 +248,24 @@ DataStreamReaderWriter::Await(uint32_t* numberOfDataBlocksReceived, DataStreamOp return m_operationStatus; } +void +DataStreamReaderWriter::StopWrites() +{ + if (m_dataStreamProperties.type() == DataStreamType::DataStreamTypeContinuous) { + bool writesStoppedExpected{ false }; + if (m_writesStopped.compare_exchange_strong(writesStoppedExpected, true, std::memory_order_relaxed, std::memory_order_relaxed)) { + LOGD << "Stopping all write operations"; + StartWritesDone(); + } + } +} + void DataStreamReaderWriter::NextWrite() { - if (m_numberOfDataBlocksToWrite > 0) { + if (m_dataStreamProperties.type() == DataStreamType::DataStreamTypeContinuous || + (m_dataStreamProperties.type() == DataStreamType::DataStreamTypeFixed && m_numberOfDataBlocksToWrite > 0)) { m_writeData.set_data(std::format("Data #{}", ++m_numberOfDataBlocksWritten)); - m_numberOfDataBlocksToWrite--; StartWrite(&m_writeData); } else { StartWritesDone(); diff --git a/tests/unit/TestNetRemoteDataStreamingReactors.hxx b/tests/unit/TestNetRemoteDataStreamingReactors.hxx index 5de95646..9134ab92 100644 --- a/tests/unit/TestNetRemoteDataStreamingReactors.hxx +++ b/tests/unit/TestNetRemoteDataStreamingReactors.hxx @@ -151,9 +151,9 @@ public: * @brief Construct a new DataStreamReaderWriter object with the client stub and specified number of data blocks to write. * * @param client The data streaming client stub. - * @param numberOfDataBlocksToWrite The number of data blocks to write to the server. + * @param dataStreamProperties The properties associated with the type of data streaming. */ - explicit DataStreamReaderWriter(Microsoft::Net::Remote::Service::NetRemoteDataStreaming::Stub* client, uint32_t numberOfDataBlocksToWrite); + explicit DataStreamReaderWriter(Microsoft::Net::Remote::Service::NetRemoteDataStreaming::Stub* client, Microsoft::Net::Remote::DataStream::DataStreamProperties dataStreamProperties); /** * @brief Callback that is executed when a read operation is completed. @@ -190,6 +190,12 @@ public: grpc::Status Await(uint32_t* numberOfDataBlocksReceived, Microsoft::Net::Remote::DataStream::DataStreamOperationStatus* operationStatus, std::span lostDataBlockSequenceNumbers); + /** + * @brief Stops writing data to the client. Should only be called with DataStreamTypeContinuous. + */ + void + StopWrites(); + private: /** * @brief Facilitate the next write operation. @@ -203,6 +209,7 @@ private: grpc::ClientContext m_clientContext{}; Microsoft::Net::Remote::DataStream::DataStreamDownloadData m_readData{}; Microsoft::Net::Remote::DataStream::DataStreamUploadData m_writeData{}; + Microsoft::Net::Remote::DataStream::DataStreamProperties m_dataStreamProperties{}; uint32_t m_numberOfDataBlocksToWrite{}; uint32_t m_numberOfDataBlocksWritten{}; uint32_t m_numberOfDataBlocksReceived{}; @@ -211,6 +218,7 @@ private: std::condition_variable m_operationsDone{}; bool m_done{ false }; std::vector m_lostDataBlockSequenceNumbers{}; + std::atomic m_writesStopped{}; }; } // namespace Microsoft::Net::Remote::Test diff --git a/tests/unit/TestNetRemoteDataStreamingServiceClient.cxx b/tests/unit/TestNetRemoteDataStreamingServiceClient.cxx index ba69a92e..4c69a35d 100644 --- a/tests/unit/TestNetRemoteDataStreamingServiceClient.cxx +++ b/tests/unit/TestNetRemoteDataStreamingServiceClient.cxx @@ -173,11 +173,45 @@ TEST_CASE("DataStreamBidirectional API", "[basic][rpc][client][remote][stream]") auto channel = grpc::CreateChannel(RemoteServiceAddressHttp, grpc::InsecureChannelCredentials()); auto client = NetRemoteDataStreaming::NewStub(channel); - static constexpr auto numberOfDataBlocksToStream = 10; + static constexpr auto fixedNumberOfDataBlocksToStream = 10; + + SECTION("Can be called with DataStreamTypeFixed") + { + DataStreamFixedTypeProperties fixedTypeProperties{}; + fixedTypeProperties.set_numberofdatablockstostream(fixedNumberOfDataBlocksToStream); + + DataStreamProperties properties{}; + properties.set_type(DataStreamType::DataStreamTypeFixed); + *properties.mutable_fixed() = std::move(fixedTypeProperties); + + DataStreamReaderWriter dataStreamReaderWriter{ client.get(), std::move(properties) }; + + uint32_t numberOfDataBlocksReceived{}; + DataStreamOperationStatus operationStatus{}; + std::vector lostDataBlockSequenceNumbers{}; + const grpc::Status status = dataStreamReaderWriter.Await(&numberOfDataBlocksReceived, &operationStatus, lostDataBlockSequenceNumbers); + REQUIRE(status.ok()); + REQUIRE(numberOfDataBlocksReceived > 0); + REQUIRE(operationStatus.code() == DataStreamOperationStatusCodeSucceeded); + REQUIRE(lostDataBlockSequenceNumbers.empty()); + } - SECTION("Can be called") + SECTION("Can be called with DataStreamTypeContinuous") { - DataStreamReaderWriter dataStreamReaderWriter{ client.get(), numberOfDataBlocksToStream }; + static constexpr auto StreamingDelayTime = 5s; + + DataStreamContinuousTypeProperties continuousTypeProperties{}; + + DataStreamProperties properties{}; + properties.set_type(DataStreamType::DataStreamTypeContinuous); + *properties.mutable_continuous() = std::move(continuousTypeProperties); + + DataStreamReaderWriter dataStreamReaderWriter{ client.get(), std::move(properties) }; + + // Allow some time of continuous streaming, then stop writing data. This will prompt the + // server to stop writing too, eliminating the need to cancel the RPC. + std::this_thread::sleep_for(StreamingDelayTime); + dataStreamReaderWriter.StopWrites(); uint32_t numberOfDataBlocksReceived{}; DataStreamOperationStatus operationStatus{};