Skip to content

Commit

Permalink
Merge pull request #208 from microsoft/user/corbinphipps/add-cancelat…
Browse files Browse the repository at this point in the history
…ion-test

Continuous client-side streaming using DataStreamBidirectional API
  • Loading branch information
corbin-phipps authored Mar 4, 2024
2 parents 79e1f27 + 829183b commit 8395123
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 15 deletions.
7 changes: 3 additions & 4 deletions src/common/service/NetRemoteDataStreamingReactors.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,10 @@ DataStreamReaderWriter::OnCancel()
{
const FunctionTracer traceMe{};

// The RPC is canceled by the client, so call Finish to complete it from the server perspective.
// Because the RPC is canceled by the client, there will be no more data to read. Thus, OnReadDone will fail and will
// call Finish(), so don't call Finish() here.
bool isCanceledExpected{ false };
if (m_isCanceled.compare_exchange_strong(isCanceledExpected, true, std::memory_order_relaxed, std::memory_order_relaxed)) {
Finish(grpc::Status::CANCELLED);
}
m_isCanceled.compare_exchange_strong(isCanceledExpected, true, std::memory_order_relaxed, std::memory_order_relaxed);
}

void
Expand Down
59 changes: 53 additions & 6 deletions tests/unit/TestNetRemoteDataStreamingReactors.cxx
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@

#include <cstddef>
#include <atomic>
#include <cstdint>
#include <format>
#include <mutex>
#include <span>
#include <utility>

#include <grpcpp/impl/codegen/status.h>
#include <magic_enum.hpp>
#include <microsoft/net/remote/protocol/NetRemoteDataStream.pb.h>
#include <microsoft/net/remote/protocol/NetRemoteDataStreamingService.grpc.pb.h>
#include <plog/Log.h>
Expand Down Expand Up @@ -143,9 +144,36 @@ DataStreamReader::Cancel()
m_clientContext.TryCancel();
}

DataStreamReaderWriter::DataStreamReaderWriter(NetRemoteDataStreaming::Stub* client, uint32_t numberOfDataBlocksToWrite) :
m_numberOfDataBlocksToWrite(numberOfDataBlocksToWrite)
DataStreamReaderWriter::DataStreamReaderWriter(NetRemoteDataStreaming::Stub* client, DataStreamProperties dataStreamProperties) :
m_dataStreamProperties(std::move(dataStreamProperties))
{
switch (m_dataStreamProperties.type()) {
case DataStreamType::DataStreamTypeFixed: {
if (m_dataStreamProperties.Properties_case() == DataStreamProperties::kFixed) {
m_numberOfDataBlocksToWrite = m_dataStreamProperties.fixed().numberofdatablockstostream();
} else {
LOGE << "Invalid properties for this streaming type. Expected Fixed for DataStreamTypeFixed";
return;
}

break;
}
case DataStreamType::DataStreamTypeContinuous: {
if (m_dataStreamProperties.Properties_case() == DataStreamProperties::kContinuous) {
m_numberOfDataBlocksToWrite = 0;
} else {
LOGE << "Invalid properties for this streaming type. Expected Continuous for DataStreamTypeContinuous";
return;
}

break;
}
default: {
LOGE << std::format("Invalid streaming type: {}", magic_enum::enum_name(m_dataStreamProperties.type()));
return;
}
};

client->async()->DataStreamBidirectional(&m_clientContext, this);
StartCall();
StartRead(&m_readData);
Expand Down Expand Up @@ -173,9 +201,16 @@ void
DataStreamReaderWriter::OnWriteDone(bool isOk)
{
if (isOk) {
if (m_dataStreamProperties.type() == DataStreamType::DataStreamTypeFixed) {
m_numberOfDataBlocksToWrite--;
}
NextWrite();
} else {
StartWritesDone();
// If StopWrites() was called and continuous data streaming is used, then StartWritesDone()
// was already called.
if (!m_writesStopped.load(std::memory_order_relaxed)) {
StartWritesDone();
}
}
}

Expand Down Expand Up @@ -213,12 +248,24 @@ DataStreamReaderWriter::Await(uint32_t* numberOfDataBlocksReceived, DataStreamOp
return m_operationStatus;
}

void
DataStreamReaderWriter::StopWrites()
{
if (m_dataStreamProperties.type() == DataStreamType::DataStreamTypeContinuous) {
bool writesStoppedExpected{ false };
if (m_writesStopped.compare_exchange_strong(writesStoppedExpected, true, std::memory_order_relaxed, std::memory_order_relaxed)) {
LOGD << "Stopping all write operations";
StartWritesDone();
}
}
}

void
DataStreamReaderWriter::NextWrite()
{
if (m_numberOfDataBlocksToWrite > 0) {
if (m_dataStreamProperties.type() == DataStreamType::DataStreamTypeContinuous ||
(m_dataStreamProperties.type() == DataStreamType::DataStreamTypeFixed && m_numberOfDataBlocksToWrite > 0)) {
m_writeData.set_data(std::format("Data #{}", ++m_numberOfDataBlocksWritten));
m_numberOfDataBlocksToWrite--;
StartWrite(&m_writeData);
} else {
StartWritesDone();
Expand Down
12 changes: 10 additions & 2 deletions tests/unit/TestNetRemoteDataStreamingReactors.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ public:
* @brief Construct a new DataStreamReaderWriter object with the client stub and specified number of data blocks to write.
*
* @param client The data streaming client stub.
* @param numberOfDataBlocksToWrite The number of data blocks to write to the server.
* @param dataStreamProperties The properties associated with the type of data streaming.
*/
explicit DataStreamReaderWriter(Microsoft::Net::Remote::Service::NetRemoteDataStreaming::Stub* client, uint32_t numberOfDataBlocksToWrite);
explicit DataStreamReaderWriter(Microsoft::Net::Remote::Service::NetRemoteDataStreaming::Stub* client, Microsoft::Net::Remote::DataStream::DataStreamProperties dataStreamProperties);

/**
* @brief Callback that is executed when a read operation is completed.
Expand Down Expand Up @@ -190,6 +190,12 @@ public:
grpc::Status
Await(uint32_t* numberOfDataBlocksReceived, Microsoft::Net::Remote::DataStream::DataStreamOperationStatus* operationStatus, std::span<uint32_t> lostDataBlockSequenceNumbers);

/**
* @brief Stops writing data to the client. Should only be called with DataStreamTypeContinuous.
*/
void
StopWrites();

private:
/**
* @brief Facilitate the next write operation.
Expand All @@ -203,6 +209,7 @@ private:
grpc::ClientContext m_clientContext{};
Microsoft::Net::Remote::DataStream::DataStreamDownloadData m_readData{};
Microsoft::Net::Remote::DataStream::DataStreamUploadData m_writeData{};
Microsoft::Net::Remote::DataStream::DataStreamProperties m_dataStreamProperties{};
uint32_t m_numberOfDataBlocksToWrite{};
uint32_t m_numberOfDataBlocksWritten{};
uint32_t m_numberOfDataBlocksReceived{};
Expand All @@ -211,6 +218,7 @@ private:
std::condition_variable m_operationsDone{};
bool m_done{ false };
std::vector<uint32_t> m_lostDataBlockSequenceNumbers{};
std::atomic<bool> m_writesStopped{};
};

} // namespace Microsoft::Net::Remote::Test
Expand Down
40 changes: 37 additions & 3 deletions tests/unit/TestNetRemoteDataStreamingServiceClient.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,45 @@ TEST_CASE("DataStreamBidirectional API", "[basic][rpc][client][remote][stream]")
auto channel = grpc::CreateChannel(RemoteServiceAddressHttp, grpc::InsecureChannelCredentials());
auto client = NetRemoteDataStreaming::NewStub(channel);

static constexpr auto numberOfDataBlocksToStream = 10;
static constexpr auto fixedNumberOfDataBlocksToStream = 10;

SECTION("Can be called with DataStreamTypeFixed")
{
DataStreamFixedTypeProperties fixedTypeProperties{};
fixedTypeProperties.set_numberofdatablockstostream(fixedNumberOfDataBlocksToStream);

DataStreamProperties properties{};
properties.set_type(DataStreamType::DataStreamTypeFixed);
*properties.mutable_fixed() = std::move(fixedTypeProperties);

DataStreamReaderWriter dataStreamReaderWriter{ client.get(), std::move(properties) };

uint32_t numberOfDataBlocksReceived{};
DataStreamOperationStatus operationStatus{};
std::vector<uint32_t> lostDataBlockSequenceNumbers{};
const grpc::Status status = dataStreamReaderWriter.Await(&numberOfDataBlocksReceived, &operationStatus, lostDataBlockSequenceNumbers);
REQUIRE(status.ok());
REQUIRE(numberOfDataBlocksReceived > 0);
REQUIRE(operationStatus.code() == DataStreamOperationStatusCodeSucceeded);
REQUIRE(lostDataBlockSequenceNumbers.empty());
}

SECTION("Can be called")
SECTION("Can be called with DataStreamTypeContinuous")
{
DataStreamReaderWriter dataStreamReaderWriter{ client.get(), numberOfDataBlocksToStream };
static constexpr auto StreamingDelayTime = 5s;

DataStreamContinuousTypeProperties continuousTypeProperties{};

DataStreamProperties properties{};
properties.set_type(DataStreamType::DataStreamTypeContinuous);
*properties.mutable_continuous() = std::move(continuousTypeProperties);

DataStreamReaderWriter dataStreamReaderWriter{ client.get(), std::move(properties) };

// Allow some time of continuous streaming, then stop writing data. This will prompt the
// server to stop writing too, eliminating the need to cancel the RPC.
std::this_thread::sleep_for(StreamingDelayTime);
dataStreamReaderWriter.StopWrites();

uint32_t numberOfDataBlocksReceived{};
DataStreamOperationStatus operationStatus{};
Expand Down

0 comments on commit 8395123

Please sign in to comment.