Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Nov 14, 2023
1 parent 99478a6 commit d955c89
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 3 deletions.
4 changes: 2 additions & 2 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ sample(const at::Tensor& rowptr,
"Temporal sampling needs to create disjoint subgraphs");
TORCH_CHECK(!edge_time.has_value() || disjoint,
"Temporal sampling needs to create disjoint subgraphs");
TORCH_CHECK(node_time.has_value() && edge_time.has_value(),
TORCH_CHECK(!(node_time.has_value() && edge_time.has_value()),
"Only one of node-level or edge-level sampling is supported ");

TORCH_CHECK(rowptr.is_contiguous(), "Non-contiguous 'rowptr'");
Expand Down Expand Up @@ -535,7 +535,7 @@ sample(const std::vector<node_type>& node_types,
"Node temporal sampling needs to create disjoint subgraphs");
TORCH_CHECK(!edge_time_dict.has_value() || disjoint,
"Edge temporal sampling needs to create disjoint subgraphs");
TORCH_CHECK(node_time_dict.has_value() && edge_time_dict.has_value(),
TORCH_CHECK(!(node_time_dict.has_value() && edge_time_dict.has_value()),
"Only one of node-level or edge-level sampling is supported ");

for (const auto& kv : rowptr_dict) {
Expand Down
46 changes: 45 additions & 1 deletion test/csrc/sampler/test_neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ TEST(DisjointNeighborTest, BasicAssertions) {
EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges));
}

TEST(TemporalNeighborTest, BasicAssertions) {
TEST(NodeLevelTemporalNeighborTest, BasicAssertions) {
auto options = at::TensorOptions().dtype(at::kLong);

auto graph = cycle_graph(/*num_nodes=*/6, options);
Expand Down Expand Up @@ -175,6 +175,50 @@ TEST(TemporalNeighborTest, BasicAssertions) {
EXPECT_TRUE(at::equal(std::get<3>(out1).value(), std::get<3>(out2).value()));
}

TEST(EdgeLevelTemporalNeighborTest, BasicAssertions) {
auto options = at::TensorOptions().dtype(at::kLong);

auto graph = cycle_graph(/*num_nodes=*/6, options);
auto rowptr = std::get<0>(graph);
auto col = std::get<1>(graph);

// Time is equal to edge ID:
auto edge_time = at::arange(col.numel(), options);

auto out = pyg::sampler::neighbor_sample(
/*rowptr=*/rowptr,
/*col=*/col,
/*seed=*/at::arange(2, 4, options),
/*num_neighbors=*/{1, 2},
/*node_time=*/c10::nullopt,
/*edge_time=*/edge_time,
/*seed_time=*/c10::nullopt,
/*edge_weight=*/c10::nullopt,
/*csc=*/false,
/*replace=*/false,
/*directed=*/true,
/*disjoint=*/true);

std::cout << rowptr << std::endl;
std::cout << col << std::endl;
std::cout << edge_time << std::endl;
std::cout << "==================" << std::endl;
std::cout << std::get<0>(out) << std::endl;
std::cout << std::get<1>(out) << std::endl;
std::cout << std::get<2>(out) << std::endl;

// Expect only the earlier neighbors or the same node to be sampled:
/* auto expected_row = at::tensor({0, 1, 2, 2, 3, 3}, options); */
/* EXPECT_TRUE(at::equal(std::get<0>(out1), expected_row)); */
/* auto expected_col = at::tensor({2, 3, 4, 0, 5, 1}, options); */
/* EXPECT_TRUE(at::equal(std::get<1>(out1), expected_col)); */
/* auto expected_nodes = */
/* at::tensor({0, 2, 1, 3, 0, 1, 1, 2, 0, 0, 1, 1}, options); */
/* EXPECT_TRUE(at::equal(std::get<2>(out1), expected_nodes.view({-1, 2}))); */
/* auto expected_edges = at::tensor({4, 6, 2, 3, 4, 5}, options); */
/* EXPECT_TRUE(at::equal(std::get<3>(out1).value(), expected_edges)); */
}

TEST(HeteroNeighborTest, BasicAssertions) {
auto options = at::TensorOptions().dtype(at::kLong);

Expand Down

0 comments on commit d955c89

Please sign in to comment.