Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Nov 14, 2023
1 parent d955c89 commit 73bb99c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 19 deletions.
2 changes: 2 additions & 0 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ sample(const at::Tensor& rowptr,
if (edge_time.has_value()) {
TORCH_CHECK(edge_time.value().is_contiguous(),
"Non-contiguous 'edge_time'");
TORCH_CHECK(seed_time.has_value(), "Seed time needs to be specified");
}
if (seed_time.has_value()) {
TORCH_CHECK(seed_time.value().is_contiguous(),
Expand Down Expand Up @@ -561,6 +562,7 @@ sample(const std::vector<node_type>& node_types,
const at::Tensor& time = kv.value();
TORCH_CHECK(time.is_contiguous(), "Non-contiguous 'edge_time'");

Check warning on line 563 in pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp

View check run for this annotation

Codecov / codecov/patch

pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp#L561-L563

Added lines #L561 - L563 were not covered by tests
}
TORCH_CHECK(seed_time_dict.has_value(), "Seed time needs to be specified");

Check warning on line 565 in pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp

View check run for this annotation

Codecov / codecov/patch

pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp#L565

Added line #L565 was not covered by tests
}
if (seed_time_dict.has_value()) {
for (const auto& kv : seed_time_dict.value()) {
Expand Down
30 changes: 11 additions & 19 deletions test/csrc/sampler/test_neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,34 +189,26 @@ TEST(EdgeLevelTemporalNeighborTest, BasicAssertions) {
/*rowptr=*/rowptr,
/*col=*/col,
/*seed=*/at::arange(2, 4, options),
/*num_neighbors=*/{1, 2},
/*num_neighbors=*/{2, 2},
/*node_time=*/c10::nullopt,
/*edge_time=*/edge_time,
/*seed_time=*/c10::nullopt,
/*seed_time=*/at::arange(5, 7, options),
/*edge_weight=*/c10::nullopt,
/*csc=*/false,
/*replace=*/false,
/*directed=*/true,
/*disjoint=*/true);

std::cout << rowptr << std::endl;
std::cout << col << std::endl;
std::cout << edge_time << std::endl;
std::cout << "==================" << std::endl;
std::cout << std::get<0>(out) << std::endl;
std::cout << std::get<1>(out) << std::endl;
std::cout << std::get<2>(out) << std::endl;

// Expect only the earlier neighbors or the same node to be sampled:
/* auto expected_row = at::tensor({0, 1, 2, 2, 3, 3}, options); */
/* EXPECT_TRUE(at::equal(std::get<0>(out1), expected_row)); */
/* auto expected_col = at::tensor({2, 3, 4, 0, 5, 1}, options); */
/* EXPECT_TRUE(at::equal(std::get<1>(out1), expected_col)); */
/* auto expected_nodes = */
/* at::tensor({0, 2, 1, 3, 0, 1, 1, 2, 0, 0, 1, 1}, options); */
/* EXPECT_TRUE(at::equal(std::get<2>(out1), expected_nodes.view({-1, 2}))); */
/* auto expected_edges = at::tensor({4, 6, 2, 3, 4, 5}, options); */
/* EXPECT_TRUE(at::equal(std::get<3>(out1).value(), expected_edges)); */
auto expected_row = at::tensor({0, 0, 1, 2, 2, 4, 4}, options);
EXPECT_TRUE(at::equal(std::get<0>(out), expected_row));
auto expected_col = at::tensor({2, 3, 4, 5, 0, 6, 1}, options);
EXPECT_TRUE(at::equal(std::get<1>(out), expected_col));
auto expected_nodes =
at::tensor({0, 2, 1, 3, 0, 1, 0, 3, 1, 2, 0, 0, 1, 1}, options);
EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes.view({-1, 2})));
auto expected_edges = at::tensor({4, 5, 6, 2, 3, 4, 5}, options);
EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges));
}

TEST(HeteroNeighborTest, BasicAssertions) {
Expand Down

0 comments on commit 73bb99c

Please sign in to comment.