diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 61146a614..1068359bc 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -362,6 +362,7 @@ sample(const at::Tensor& rowptr, if (edge_time.has_value()) { TORCH_CHECK(edge_time.value().is_contiguous(), "Non-contiguous 'edge_time'"); + TORCH_CHECK(seed_time.has_value(), "Seed time needs to be specified"); } if (seed_time.has_value()) { TORCH_CHECK(seed_time.value().is_contiguous(), @@ -561,6 +562,7 @@ sample(const std::vector& node_types, const at::Tensor& time = kv.value(); TORCH_CHECK(time.is_contiguous(), "Non-contiguous 'edge_time'"); } + TORCH_CHECK(seed_time_dict.has_value(), "Seed time needs to be specified"); } if (seed_time_dict.has_value()) { for (const auto& kv : seed_time_dict.value()) { diff --git a/test/csrc/sampler/test_neighbor.cpp b/test/csrc/sampler/test_neighbor.cpp index 9f8b7bfff..7f501641c 100644 --- a/test/csrc/sampler/test_neighbor.cpp +++ b/test/csrc/sampler/test_neighbor.cpp @@ -189,34 +189,26 @@ TEST(EdgeLevelTemporalNeighborTest, BasicAssertions) { /*rowptr=*/rowptr, /*col=*/col, /*seed=*/at::arange(2, 4, options), - /*num_neighbors=*/{1, 2}, + /*num_neighbors=*/{2, 2}, /*node_time=*/c10::nullopt, /*edge_time=*/edge_time, - /*seed_time=*/c10::nullopt, + /*seed_time=*/at::arange(5, 7, options), /*edge_weight=*/c10::nullopt, /*csc=*/false, /*replace=*/false, /*directed=*/true, /*disjoint=*/true); - std::cout << rowptr << std::endl; - std::cout << col << std::endl; - std::cout << edge_time << std::endl; - std::cout << "==================" << std::endl; - std::cout << std::get<0>(out) << std::endl; - std::cout << std::get<1>(out) << std::endl; - std::cout << std::get<2>(out) << std::endl; - // Expect only the earlier neighbors or the same node to be sampled: - /* auto expected_row = at::tensor({0, 1, 2, 2, 3, 3}, options); */ - /* EXPECT_TRUE(at::equal(std::get<0>(out1), expected_row)); */ - /* auto expected_col = at::tensor({2, 3, 4, 0, 5, 1}, options); */ - /* EXPECT_TRUE(at::equal(std::get<1>(out1), expected_col)); */ - /* auto expected_nodes = */ - /* at::tensor({0, 2, 1, 3, 0, 1, 1, 2, 0, 0, 1, 1}, options); */ - /* EXPECT_TRUE(at::equal(std::get<2>(out1), expected_nodes.view({-1, 2}))); */ - /* auto expected_edges = at::tensor({4, 6, 2, 3, 4, 5}, options); */ - /* EXPECT_TRUE(at::equal(std::get<3>(out1).value(), expected_edges)); */ + auto expected_row = at::tensor({0, 0, 1, 2, 2, 4, 4}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_row)); + auto expected_col = at::tensor({2, 3, 4, 5, 0, 6, 1}, options); + EXPECT_TRUE(at::equal(std::get<1>(out), expected_col)); + auto expected_nodes = + at::tensor({0, 2, 1, 3, 0, 1, 0, 3, 1, 2, 0, 0, 1, 1}, options); + EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes.view({-1, 2}))); + auto expected_edges = at::tensor({4, 5, 6, 2, 3, 4, 5}, options); + EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); } TEST(HeteroNeighborTest, BasicAssertions) {