Skip to content

Commit 5324da2

Browse files
committed
Create test + fix.
1 parent e79726b commit 5324da2

File tree

4 files changed

+49
-5
lines changed

4 files changed

+49
-5
lines changed

BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ test_suite(
7777
"//test/cpp:test_replication",
7878
"//test/cpp:test_tensor",
7979
"//test/cpp:test_xla_sharding",
80+
"//torch_xla/csrc/runtime:runtime_test",
8081
"//torch_xla/csrc/runtime:pjrt_computation_client_test",
8182
# "//torch_xla/csrc/runtime:ifrt_computation_client_test",
8283
],

torch_xla/csrc/runtime/BUILD

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,3 +524,12 @@ ptxla_cc_test(
524524
# "@xla//xla/tools:hlo_module_loader",
525525
# ],
526526
# )
527+
528+
ptxla_cc_test(
529+
name = "runtime_test",
530+
srcs = ["runtime_test.cpp"],
531+
deps = [
532+
":runtime",
533+
"@tsl//tsl/platform:test_main",
534+
],
535+
)

torch_xla/csrc/runtime/runtime.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ namespace torch_xla::runtime {
1212

1313
std::atomic<bool> g_computation_client_initialized(false);
1414

15-
// Creates a new instance of a `ComputationClient` (e.g. `PjRtComputationClient`),
16-
// and initializes the computation client
15+
// Creates a new instance of a `ComputationClient` (e.g.
16+
// `PjRtComputationClient`), and initializes the computation client
1717
static absl::StatusOr<absl_nonnull std::unique_ptr<ComputationClient>>
1818
InitializeComputationClient() {
1919
if (sys_util::GetEnvBool("XLA_DUMP_FATAL_STACK", false)) {
@@ -22,9 +22,11 @@ InitializeComputationClient() {
2222

2323
std::unique_ptr<ComputationClient> client;
2424

25-
// Disable IFRT right now as it currently crashes.
25+
// TODO: enable IFRT once it's not crashing anymore.
26+
// Ref: https://github.com/pytorch/xla/pull/8267
27+
//
2628
// static bool use_ifrt = sys_util::GetEnvBool("XLA_USE_IFRT", false);
27-
static bool use_ifrt = false;
29+
static const bool use_ifrt = false;
2830
if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") {
2931
if (use_ifrt) {
3032
client = std::make_unique<IfrtComputationClient>();
@@ -40,7 +42,8 @@ InitializeComputationClient() {
4042
}
4143

4244
absl::StatusOr<ComputationClient*> GetComputationClient() {
43-
static absl::StatusOr<absl_nonnull std::unique_ptr<ComputationClient>> maybeClient = InitializeComputationClient();
45+
static absl::StatusOr<absl_nonnull std::unique_ptr<ComputationClient>>
46+
maybeClient = InitializeComputationClient();
4447

4548
if (!maybeClient.ok()) {
4649
return maybeClient.status();
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#include "torch_xla/csrc/runtime/runtime.h"
2+
3+
#include <gtest/gtest.h>
4+
5+
namespace torch_xla::runtime {
6+
7+
TEST(RuntimeTest, NullComputationClient) {
8+
auto client = GetComputationClientIfInitialized();
9+
EXPECT_EQ(client, nullptr);
10+
}
11+
12+
TEST(RuntimeTest, GetComputationClientSuccess) {
13+
ComputationClient* client;
14+
15+
client = GetComputationClientIfInitialized();
16+
EXPECT_EQ(client, nullptr);
17+
18+
// Initialize the ComputationClient.
19+
// Check all the APIs return the same valid ComputationClient.
20+
21+
client = GetComputationClientOrDie();
22+
EXPECT_NE(client, nullptr);
23+
24+
auto status = GetComputationClient();
25+
EXPECT_TRUE(status.ok());
26+
EXPECT_EQ(client, status.value());
27+
28+
EXPECT_EQ(client, GetComputationClientIfInitialized());
29+
}
30+
31+
} // namespace torch_xla::runtime

0 commit comments

Comments
 (0)