Skip to content

Commit

Permalink
Add getter.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 17, 2024
1 parent 238bc58 commit 0c057a1
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 6 deletions.
6 changes: 5 additions & 1 deletion plugin/federated/federated_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

#include <chrono> // for seconds
#include <cstdint> // for int32_t
#include <memory> // for shared_ptr
#include <memory> // for shared_ptr, dynamic_pointer_cast
#include <string> // for string

#include "../../src/collective/comm.h" // for HostComm
Expand Down Expand Up @@ -65,6 +65,10 @@ class FederatedComm : public HostComm {
return Success();
}
[[nodiscard]] bool IsFederated() const override { return true; }
[[nodiscard]] bool IsEncrypted() const override {
auto mock_ptr = std::dynamic_pointer_cast<FederatedPluginMock>(plugin_);
return !mock_ptr;
}
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }

[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
Expand Down
1 change: 1 addition & 0 deletions src/collective/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class Comm : public std::enable_shared_from_this<Comm> {
return channels_.at(rank);
}
[[nodiscard]] virtual bool IsFederated() const = 0;
[[nodiscard]] virtual bool IsEncrypted() const { return false; }
[[nodiscard]] virtual Result LogTracker(std::string msg) const = 0;

[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
Expand Down
12 changes: 8 additions & 4 deletions src/collective/comm_group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,20 @@ void Init(Json const& config) { GlobalCommGroupInit(config); }

void Finalize() { GlobalCommGroupFinalize(); }

std::int32_t GetRank() noexcept { return GlobalCommGroup()->Rank(); }
[[nodiscard]] std::int32_t GetRank() noexcept { return GlobalCommGroup()->Rank(); }

std::int32_t GetWorldSize() noexcept { return GlobalCommGroup()->World(); }
[[nodiscard]] std::int32_t GetWorldSize() noexcept { return GlobalCommGroup()->World(); }

bool IsDistributed() noexcept { return GlobalCommGroup()->IsDistributed(); }
[[nodiscard]] bool IsDistributed() noexcept { return GlobalCommGroup()->IsDistributed(); }

[[nodiscard]] bool IsFederated() {
[[nodiscard]] bool IsFederated() noexcept {
return GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).IsFederated();
}

[[nodiscard]] bool IsEncrypted() noexcept {
return IsFederated() && GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).IsEncrypted();
}

void Print(std::string const& message) {
auto rc = GlobalCommGroup()->Ctx(nullptr, DeviceOrd::CPU()).LogTracker(message);
SafeColl(rc);
Expand Down
4 changes: 3 additions & 1 deletion src/collective/communicator-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ void Finalize();
*
* @return True if the communicator is federated.
*/
[[nodiscard]] bool IsFederated();
[[nodiscard]] bool IsFederated() noexcept;

[[nodiscard]] bool IsEncrypted() noexcept;

/**
* @brief Print the message to the communicator.
Expand Down

0 comments on commit 0c057a1

Please sign in to comment.