Skip to content

Commit

Permalink
Add HalfClose to OakSessionChannel.
Browse files Browse the repository at this point in the history
ACKNOWLEDGE_FAILING_COPYBARA_IMPORT=cl/726948285 can be used for fixing tests while importing.

Change-Id: I45efccd25071a103b35a0b700e55a0e6b806518f
  • Loading branch information
Zhumazhenis Dairabay committed Mar 6, 2025
1 parent d889956 commit 8817480
Show file tree
Hide file tree
Showing 9 changed files with 109 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cc/client/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,12 @@ cc_test(
"//cc/ffi:rust_bytes",
"//cc/oak_session:client_session",
"//cc/oak_session:server_session",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:status_matchers",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/synchronization",
"@com_google_googletest//:gtest_main",
],
)
29 changes: 29 additions & 0 deletions cc/client/session_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

#include <memory>

#include "absl/base/thread_annotations.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/synchronization/mutex.h"
#include "cc/ffi/rust_bytes.h"
#include "cc/oak_session/client_session.h"
#include "cc/oak_session/server_session.h"
Expand All @@ -31,6 +33,7 @@ namespace {

using ::absl_testing::IsOk;
using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::testing::Eq;
using ::testing::Ne;
using ::testing::Optional;
Expand All @@ -40,6 +43,10 @@ class TestTransport : public OakSessionClient::Channel::Transport {
TestTransport(std::unique_ptr<session::ServerSession> server_session)
: server_session_(std::move(server_session)) {}
absl::Status Send(session::v1::SessionRequest&& request) override {
absl::MutexLock lock(&mtx_);
if (half_closed_) {
return absl::InternalError("Already half-closed.");
}
return server_session_->PutIncomingMessage(request);
}
absl::StatusOr<session::v1::SessionResponse> Receive() override {
Expand All @@ -54,8 +61,16 @@ class TestTransport : public OakSessionClient::Channel::Transport {
return **msg;
}

void HalfClose() override {
absl::MutexLock lock(&mtx_);
// TODO: zhumazhenis - half-close in server_session_ too if supports.
half_closed_ = true;
}

private:
std::unique_ptr<session::ServerSession> server_session_;
absl::Mutex mtx_;
bool half_closed_ ABSL_GUARDED_BY(mtx_) = false;
};

session::SessionConfig* TestSessionConfig() {
Expand Down Expand Up @@ -104,6 +119,20 @@ TEST(OakSessionClientTest, CreatedSessionCanSend) {
EXPECT_THAT(test_send_read_back, IsOkAndHolds(Optional(Eq(test_send_msg))));
}

TEST(OakSessionClientTest, HalfClosedSessionFailsToSend) {
auto server_session = session::ServerSession::Create(TestSessionConfig());
ASSERT_THAT(server_session, IsOk());
auto channel = OakSessionClient(TestSessionConfig)
.NewChannel(std::make_unique<TestTransport>(
std::move(*server_session)));

(*channel)->HalfClose();

std::string test_send_msg = "Testing Send";
EXPECT_THAT((*channel)->Send(test_send_msg),
StatusIs(absl::StatusCode::kInternal));
}

TEST(OakSessionClientTest, CreatedSessionCanReceive) {
auto server_session = session::ServerSession::Create(TestSessionConfig());
// Hold a pointer for testing behavior below.
Expand Down
14 changes: 14 additions & 0 deletions cc/oak_session/channel/oak_session_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,29 @@ class OakSessionChannel {
class Transport {
public:
virtual ~Transport() = default;
// Any subsequent calls to Send() will fail with INTERNAL error after
// the transport is half-closed.
virtual absl::Status Send(SendMessage&& message) = 0;

// Implementations should block until a new message is available to return.
// Blocking semantics, deadlines, etc should be defined by the particular
// implementation.
virtual absl::StatusOr<ReceiveMessage> Receive() = 0;

// Half closes the transport. Thread safe, idempotent, so safe to call
// multiple times. Expected that "this end" will no longer send messages to
// the "other end" after "this end" half-closed. Any subsequent calls to
// Send() by "this end" will fail with INTERNAL error after this point.
virtual void HalfClose() = 0;
};

OakSessionChannel(std::unique_ptr<Session> session,
std::unique_ptr<Transport> transport)
: session_(std::move(session)), transport_(std::move(transport)) {}

// Encrypt and send a message back to the other party.
// Any subsequent calls to Send() will fail with INTERNAL error after
// the channel is half-closed.
absl::Status Send(absl::string_view unencrypted_message) {
absl::Status write_result = session_->Write(unencrypted_message);
if (!write_result.ok()) {
Expand Down Expand Up @@ -123,6 +133,10 @@ class OakSessionChannel {
return std::string(**decrypted_message);
}

// Half closes the channel. Similar behavior to the Transport::HalfClose()
// method.
void HalfClose() { transport_->HalfClose(); }

// Create a new OakSessionChannel instance with the provided session and
// transport.
//
Expand Down
2 changes: 2 additions & 0 deletions cc/server/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ cc_test(
"//cc/ffi:rust_bytes",
"//cc/oak_session:client_session",
"//cc/oak_session:server_session",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:status_matchers",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/synchronization",
"@com_google_googletest//:gtest_main",
],
)
27 changes: 27 additions & 0 deletions cc/server/session_server_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
#include <string>
#include <utility>

#include "absl/base/thread_annotations.h"
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "cc/ffi/rust_bytes.h"
#include "cc/oak_session/client_session.h"
#include "cc/oak_session/server_session.h"
Expand All @@ -35,6 +37,7 @@ namespace {

using ::absl_testing::IsOk;
using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::testing::Eq;
using ::testing::Ne;
using ::testing::Optional;
Expand All @@ -44,6 +47,10 @@ class TestTransport : public OakSessionServer::Channel::Transport {
TestTransport(std::unique_ptr<session::ClientSession> client_session)
: client_session_(std::move(client_session)) {}
absl::Status Send(session::v1::SessionResponse&& request) override {
absl::MutexLock lock(&mtx_);
if (half_closed_) {
return absl::InternalError("Already half-closed.");
}
return client_session_->PutIncomingMessage(request);
}
absl::StatusOr<session::v1::SessionRequest> Receive() override {
Expand All @@ -57,9 +64,15 @@ class TestTransport : public OakSessionServer::Channel::Transport {
}
return **msg;
}
void HalfClose() override {
absl::MutexLock lock(&mtx_);
half_closed_ = true;
}

private:
std::unique_ptr<session::ClientSession> client_session_;
absl::Mutex mtx_;
bool half_closed_ ABSL_GUARDED_BY(mtx_) = false;
};

session::SessionConfig* TestSessionConfig() {
Expand Down Expand Up @@ -109,6 +122,20 @@ TEST(OakSessionServerTest, CreatedSessionCanSend) {
EXPECT_THAT(test_send_read_back, IsOkAndHolds(Optional(Eq(test_send_msg))));
}

TEST(OakSessionServerTest, HalfClosedSessionFailsToSend) {
auto client_session = session::ClientSession::Create(TestSessionConfig());
ASSERT_THAT(client_session, IsOk());
auto channel = OakSessionServer(TestSessionConfig)
.NewChannel(std::make_unique<TestTransport>(
std::move(*client_session)));

(*channel)->HalfClose();

std::string test_send_msg = "Testing Send";
EXPECT_THAT((*channel)->Send(test_send_msg),
StatusIs(absl::StatusCode::kInternal));
}

TEST(OakSessionServerTest, CreatedSessionCanReceive) {
auto client_session = session::ClientSession::Create(TestSessionConfig());
// Hold a pointer for testing behavior below.
Expand Down
14 changes: 14 additions & 0 deletions cc/transport/grpc_sync_session_client_transport.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "cc/oak_session/client_session.h"
#include "grpcpp/support/sync_stream.h"
#include "proto/session/session.pb.h"
Expand All @@ -26,6 +27,10 @@ namespace oak::transport {

absl::Status GrpcSyncSessionClientTransport::Send(
session::v1::SessionRequest&& message) {
absl::MutexLock lock(&mtx_);
if (half_closed_) {
return absl::InternalError("Already half-closed.");
}
if (!stream_->Write(message)) {
return absl::AbortedError("Failed to write outgoing message.");
}
Expand All @@ -41,4 +46,13 @@ GrpcSyncSessionClientTransport::Receive() {
return response;
}

void GrpcSyncSessionClientTransport::HalfClose() {
absl::MutexLock lock(&mtx_);
if (half_closed_) {
return;
}
stream_->WritesDone();
half_closed_ = true;
}

}; // namespace oak::transport
6 changes: 6 additions & 0 deletions cc/transport/grpc_sync_session_client_transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
#ifndef CC_TRANSPORT_GRPC_SYNC_CLIENT_SESSION_TRANSPORT_H_
#define CC_TRANSPORT_GRPC_SYNC_CLIENT_SESSION_TRANSPORT_H_

#include "absl/base/thread_annotations.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "cc/client/session_client.h"
#include "cc/oak_session/channel/oak_session_channel.h"
#include "cc/oak_session/client_session.h"
Expand All @@ -40,11 +42,15 @@ class GrpcSyncSessionClientTransport

absl::Status Send(session::v1::SessionRequest&& message) override;
absl::StatusOr<session::v1::SessionResponse> Receive() override;
void HalfClose() override;

private:
std::unique_ptr<grpc::ClientReaderWriterInterface<
session::v1::SessionRequest, session::v1::SessionResponse>>
stream_;

absl::Mutex mtx_;
bool half_closed_ ABSL_GUARDED_BY(mtx_) = false;
};

} // namespace oak::transport
Expand Down
10 changes: 10 additions & 0 deletions cc/transport/grpc_sync_session_server_transport.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "cc/oak_session/server_session.h"
#include "proto/session/session.pb.h"

namespace oak::transport {

absl::Status GrpcSyncSessionServerTransport::Send(
session::v1::SessionResponse&& message) {
absl::MutexLock lock(&mtx_);
if (half_closed_) {
return absl::InternalError("Already half-closed.");
}
if (!stream_->Write(message)) {
return absl::AbortedError("Failed to write outgoing message.");
}
Expand All @@ -40,4 +45,9 @@ GrpcSyncSessionServerTransport::Receive() {
return request;
}

void GrpcSyncSessionServerTransport::HalfClose() {
absl::MutexLock lock(&mtx_);
half_closed_ = true;
}

} // namespace oak::transport
5 changes: 5 additions & 0 deletions cc/transport/grpc_sync_session_server_transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
#ifndef CC_TRANSPORT_GRPC_SYNC_SESSION_SERVER_TRANSPORT_H_
#define CC_TRANSPORT_GRPC_SYNC_SESSION_SERVER_TRANSPORT_H_

#include "absl/base/thread_annotations.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "cc/oak_session/channel/oak_session_channel.h"
#include "cc/oak_session/server_session.h"
#include "cc/server/session_server.h"
Expand All @@ -39,10 +41,13 @@ class GrpcSyncSessionServerTransport

absl::Status Send(session::v1::SessionResponse&& message) override;
absl::StatusOr<session::v1::SessionRequest> Receive() override;
void HalfClose() override;

private:
grpc::ServerReaderWriterInterface<session::v1::SessionResponse,
session::v1::SessionRequest>* stream_;
absl::Mutex mtx_;
bool half_closed_ ABSL_GUARDED_BY(mtx_) = false;
};

} // namespace oak::transport
Expand Down

0 comments on commit 8817480

Please sign in to comment.