Skip to content

Commit

Permalink
[GraphBolt] pass tensor directly to probs or masks in sampling C++ API (
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying committed Jun 6, 2024
1 parent 4913a7b commit e366260
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 29 deletions.
12 changes: 6 additions & 6 deletions graphbolt/include/graphbolt/fused_csc_sampling_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* sampled edges, see arXiv:2210.13339.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
* @param probs_name An optional string specifying the name of an edge
* attribute. This attribute tensor should contain (unnormalized)
* @param probs_or_mask An optional edge attribute tensor for probabilities
* or masks. This attribute tensor should contain (unnormalized)
* probabilities corresponding to each neighboring edge of a node. It must be
* a 1D floating-point or boolean tensor, with the number of elements
* equalling the total number of edges.
Expand All @@ -339,7 +339,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, torch::optional<std::string> probs_name,
bool return_eids, torch::optional<torch::Tensor> probs_or_mask,
torch::optional<torch::Tensor> random_seed,
double seed2_contribution) const;

Expand All @@ -362,8 +362,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* sampled edges, see arXiv:2210.13339.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
* @param probs_name An optional string specifying the name of an edge
* attribute, following the same rules as in SampleNeighbors.
* @param probs_or_mask An optional edge attribute tensor for probabilities
* or masks, following the same rules as in SampleNeighbors.
* @param node_timestamp_attr_name An optional string specifying the name of
* the node attribute that contains the timestamp of nodes in the graph.
* @param edge_timestamp_attr_name An optional string specifying the name of
Expand All @@ -377,7 +377,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
const torch::Tensor& input_nodes,
const torch::Tensor& input_nodes_timestamp,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, torch::optional<std::string> probs_name,
bool return_eids, torch::optional<torch::Tensor> probs_or_mask,
torch::optional<std::string> node_timestamp_attr_name,
torch::optional<std::string> edge_timestamp_attr_name) const;

Expand Down
9 changes: 3 additions & 6 deletions graphbolt/src/fused_csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -805,11 +805,9 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, torch::optional<std::string> probs_name,
bool return_eids, torch::optional<torch::Tensor> probs_or_mask,
torch::optional<torch::Tensor> random_seed,
double seed2_contribution) const {
auto probs_or_mask = this->EdgeAttribute(probs_name);

// If seeds does not have a value, then we expect all arguments to be resident
// on the GPU. If seeds has a value, then we expect them to be accessible from
// GPU. This is required for the dispatch to work when CUDA is not available.
Expand Down Expand Up @@ -903,13 +901,12 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
const torch::Tensor& input_nodes,
const torch::Tensor& input_nodes_timestamp,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, torch::optional<std::string> probs_name,
bool return_eids, torch::optional<torch::Tensor> probs_or_mask,
torch::optional<std::string> node_timestamp_attr_name,
torch::optional<std::string> edge_timestamp_attr_name) const {
torch::optional<std::vector<int64_t>> seed_offsets = torch::nullopt;
// 1. Get probs_or_mask.
auto probs_or_mask = this->EdgeAttribute(probs_name);
if (probs_name.has_value()) {
if (probs_or_mask.has_value()) {
// Note probs will be passed as input for 'torch.multinomial' in deeper
// stack, which doesn't support 'torch.half' and 'torch.bool' data types. To
// avoid crashes, convert 'probs_or_mask' to 'float32' data type.
Expand Down
33 changes: 16 additions & 17 deletions python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def num_nodes(self) -> Union[int, Dict[str, int]]:
# Heterogeneous
else:
num_nodes_per_type = {
_type: (offset[_idx + 1] - offset[_idx])
_type: offset[_idx + 1] - offset[_idx]
for _type, _idx in self.node_type_to_id.items()
}

Expand Down Expand Up @@ -694,19 +694,20 @@ def sample_neighbors(
seeds, seed_offsets = self._convert_to_homogeneous_nodes(seeds)
elif seeds is None:
seed_offsets = self._indptr_node_type_offset_list
probs_or_mask = self.edge_attributes[probs_name] if probs_name else None
C_sampled_subgraph = self._sample_neighbors(
seeds,
seed_offsets,
fanouts,
replace=replace,
probs_name=probs_name,
probs_or_mask=probs_or_mask,
return_eids=return_eids,
)
return self._convert_to_sampled_subgraph(
C_sampled_subgraph, seed_offsets
)

def _check_sampler_arguments(self, nodes, fanouts, probs_name):
def _check_sampler_arguments(self, nodes, fanouts, probs_or_mask):
if nodes is not None:
assert nodes.dim() == 1, "Nodes should be 1-D tensor."
assert nodes.dtype == self.indices.dtype, (
Expand All @@ -731,11 +732,7 @@ def _check_sampler_arguments(self, nodes, fanouts, probs_name):
(fanouts >= 0) | (fanouts == -1)
), "Fanouts should consist of values that are either -1 or \
greater than or equal to 0."
if probs_name:
assert (
probs_name in self.edge_attributes
), f"Unknown edge attribute '{probs_name}'."
probs_or_mask = self.edge_attributes[probs_name]
if probs_or_mask is not None:
assert probs_or_mask.dim() == 1, "Probs should be 1-D tensor."
assert (
probs_or_mask.size(0) == self.total_num_edges
Expand All @@ -755,7 +752,7 @@ def _sample_neighbors(
seed_offsets: Optional[list],
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
probs_or_mask: Optional[torch.Tensor] = None,
return_eids: bool = False,
) -> torch.ScriptObject:
"""Sample neighboring edges of the given nodes and return the induced
Expand Down Expand Up @@ -789,8 +786,8 @@ def _sample_neighbors(
Boolean indicating whether the sample is performed with or
without replacement. If True, a value can be selected multiple
times. Otherwise, each value can be selected only once.
probs_name: str, optional
An optional string specifying the name of an edge attribute. This
probs_or_mask: torch.Tensor, optional
An optional edge attribute tensor for probabilities or masks. This
attribute tensor should contain (unnormalized) probabilities
corresponding to each neighboring edge of a node. It must be a 1D
floating-point or boolean tensor, with the number of elements
Expand All @@ -805,15 +802,15 @@ def _sample_neighbors(
The sampled C subgraph.
"""
# Ensure nodes is 1-D tensor.
self._check_sampler_arguments(seeds, fanouts, probs_name)
self._check_sampler_arguments(seeds, fanouts, probs_or_mask)
return self._c_csc_graph.sample_neighbors(
seeds,
seed_offsets,
fanouts.tolist(),
replace,
False, # is_labor
return_eids,
probs_name,
probs_or_mask,
None, # random_seed, labor parameter
0, # seed2_contribution, labor_parameter
)
Expand Down Expand Up @@ -943,15 +940,16 @@ def sample_layer_neighbors(
seeds, seed_offsets = self._convert_to_homogeneous_nodes(seeds)
elif seeds is None:
seed_offsets = self._indptr_node_type_offset_list
self._check_sampler_arguments(seeds, fanouts, probs_name)
probs_or_mask = self.edge_attributes[probs_name] if probs_name else None
self._check_sampler_arguments(seeds, fanouts, probs_or_mask)
C_sampled_subgraph = self._c_csc_graph.sample_neighbors(
seeds,
seed_offsets,
fanouts.tolist(),
replace,
True,
has_original_eids,
probs_name,
probs_or_mask,
random_seed,
seed2_contribution,
)
Expand Down Expand Up @@ -1025,7 +1023,8 @@ def temporal_sample_neighbors(
)

# Ensure nodes is 1-D tensor.
self._check_sampler_arguments(nodes, fanouts, probs_name)
probs_or_mask = self.edge_attributes[probs_name] if probs_name else None
self._check_sampler_arguments(nodes, fanouts, probs_or_mask)
has_original_eids = (
self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes
Expand All @@ -1037,7 +1036,7 @@ def temporal_sample_neighbors(
replace,
False,
has_original_eids,
probs_name,
probs_or_mask,
node_timestamp_attr_name,
edge_timestamp_attr_name,
)
Expand Down

0 comments on commit e366260

Please sign in to comment.