diff --git a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h index 2f145970ec7c..2ba51d84a693 100644 --- a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h +++ b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h @@ -323,8 +323,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { * sampled edges, see arXiv:2210.13339. * @param return_eids Boolean indicating whether edge IDs need to be returned, * typically used when edge features are required. - * @param probs_name An optional string specifying the name of an edge - * attribute. This attribute tensor should contain (unnormalized) + * @param probs_or_mask An optional edge attribute tensor for probabilities + * or masks. This attribute tensor should contain (unnormalized) * probabilities corresponding to each neighboring edge of a node. It must be * a 1D floating-point or boolean tensor, with the number of elements * equalling the total number of edges. @@ -339,7 +339,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { torch::optional seeds, torch::optional> seed_offsets, const std::vector& fanouts, bool replace, bool layer, - bool return_eids, torch::optional probs_name, + bool return_eids, torch::optional probs_or_mask, torch::optional random_seed, double seed2_contribution) const; @@ -362,8 +362,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { * sampled edges, see arXiv:2210.13339. * @param return_eids Boolean indicating whether edge IDs need to be returned, * typically used when edge features are required. - * @param probs_name An optional string specifying the name of an edge - * attribute, following the same rules as in SampleNeighbors. + * @param probs_or_mask An optional edge attribute tensor for probabilities + * or masks, following the same rules as in SampleNeighbors. 
* @param node_timestamp_attr_name An optional string specifying the name of * the node attribute that contains the timestamp of nodes in the graph. * @param edge_timestamp_attr_name An optional string specifying the name of @@ -377,7 +377,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { const torch::Tensor& input_nodes, const torch::Tensor& input_nodes_timestamp, const std::vector& fanouts, bool replace, bool layer, - bool return_eids, torch::optional probs_name, + bool return_eids, torch::optional probs_or_mask, torch::optional node_timestamp_attr_name, torch::optional edge_timestamp_attr_name) const; diff --git a/graphbolt/src/fused_csc_sampling_graph.cc b/graphbolt/src/fused_csc_sampling_graph.cc index e2000b410458..a36404632103 100644 --- a/graphbolt/src/fused_csc_sampling_graph.cc +++ b/graphbolt/src/fused_csc_sampling_graph.cc @@ -805,11 +805,9 @@ c10::intrusive_ptr FusedCSCSamplingGraph::SampleNeighbors( torch::optional seeds, torch::optional> seed_offsets, const std::vector& fanouts, bool replace, bool layer, - bool return_eids, torch::optional probs_name, + bool return_eids, torch::optional probs_or_mask, torch::optional random_seed, double seed2_contribution) const { - auto probs_or_mask = this->EdgeAttribute(probs_name); - // If seeds does not have a value, then we expect all arguments to be resident // on the GPU. If seeds has a value, then we expect them to be accessible from // GPU. This is required for the dispatch to work when CUDA is not available. @@ -903,13 +901,12 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( const torch::Tensor& input_nodes, const torch::Tensor& input_nodes_timestamp, const std::vector& fanouts, bool replace, bool layer, - bool return_eids, torch::optional probs_name, + bool return_eids, torch::optional probs_or_mask, torch::optional node_timestamp_attr_name, torch::optional edge_timestamp_attr_name) const { torch::optional> seed_offsets = torch::nullopt; // 1. Get probs_or_mask. 
- auto probs_or_mask = this->EdgeAttribute(probs_name); - if (probs_name.has_value()) { + if (probs_or_mask.has_value()) { // Note probs will be passed as input for 'torch.multinomial' in deeper // stack, which doesn't support 'torch.half' and 'torch.bool' data types. To // avoid crashes, convert 'probs_or_mask' to 'float32' data type. diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index 97502a59ca79..16f2a145287f 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -144,7 +144,7 @@ def num_nodes(self) -> Union[int, Dict[str, int]]: # Heterogenous else: num_nodes_per_type = { - _type: (offset[_idx + 1] - offset[_idx]) + _type: offset[_idx + 1] - offset[_idx] for _type, _idx in self.node_type_to_id.items() } @@ -694,19 +694,20 @@ def sample_neighbors( seeds, seed_offsets = self._convert_to_homogeneous_nodes(seeds) elif seeds is None: seed_offsets = self._indptr_node_type_offset_list + probs_or_mask = self.edge_attributes[probs_name] if probs_name else None C_sampled_subgraph = self._sample_neighbors( seeds, seed_offsets, fanouts, replace=replace, - probs_name=probs_name, + probs_or_mask=probs_or_mask, return_eids=return_eids, ) return self._convert_to_sampled_subgraph( C_sampled_subgraph, seed_offsets ) - def _check_sampler_arguments(self, nodes, fanouts, probs_name): + def _check_sampler_arguments(self, nodes, fanouts, probs_or_mask): if nodes is not None: assert nodes.dim() == 1, "Nodes should be 1-D tensor." assert nodes.dtype == self.indices.dtype, ( @@ -731,11 +732,7 @@ def _check_sampler_arguments(self, nodes, fanouts, probs_name): (fanouts >= 0) | (fanouts == -1) ), "Fanouts should consist of values that are either -1 or \ greater than or equal to 0." - if probs_name: - assert ( - probs_name in self.edge_attributes - ), f"Unknown edge attribute '{probs_name}'." 
- probs_or_mask = self.edge_attributes[probs_name] + if probs_or_mask is not None: assert probs_or_mask.dim() == 1, "Probs should be 1-D tensor." assert ( probs_or_mask.size(0) == self.total_num_edges @@ -755,7 +752,7 @@ def _sample_neighbors( seed_offsets: Optional[list], fanouts: torch.Tensor, replace: bool = False, - probs_name: Optional[str] = None, + probs_or_mask: Optional[torch.Tensor] = None, return_eids: bool = False, ) -> torch.ScriptObject: """Sample neighboring edges of the given nodes and return the induced @@ -789,8 +786,8 @@ def _sample_neighbors( Boolean indicating whether the sample is preformed with or without replacement. If True, a value can be selected multiple times. Otherwise, each value can be selected only once. - probs_name: str, optional - An optional string specifying the name of an edge attribute. This + probs_or_mask: torch.Tensor, optional + An optional tensor of edge attribute for probability or masks. This attribute tensor should contain (unnormalized) probabilities corresponding to each neighboring edge of a node. It must be a 1D floating-point or boolean tensor, with the number of elements @@ -805,7 +802,7 @@ def _sample_neighbors( The sampled C subgraph. """ # Ensure nodes is 1-D tensor. 
- self._check_sampler_arguments(seeds, fanouts, probs_name) + self._check_sampler_arguments(seeds, fanouts, probs_or_mask) return self._c_csc_graph.sample_neighbors( seeds, seed_offsets, @@ -813,7 +810,7 @@ def _sample_neighbors( replace, False, # is_labor return_eids, - probs_name, + probs_or_mask, None, # random_seed, labor parameter 0, # seed2_contribution, labor_parameter ) @@ -943,7 +940,8 @@ def sample_layer_neighbors( seeds, seed_offsets = self._convert_to_homogeneous_nodes(seeds) elif seeds is None: seed_offsets = self._indptr_node_type_offset_list - self._check_sampler_arguments(seeds, fanouts, probs_name) + probs_or_mask = self.edge_attributes[probs_name] if probs_name else None + self._check_sampler_arguments(seeds, fanouts, probs_or_mask) C_sampled_subgraph = self._c_csc_graph.sample_neighbors( seeds, seed_offsets, @@ -951,7 +949,7 @@ def sample_layer_neighbors( replace, True, has_original_eids, - probs_name, + probs_or_mask, random_seed, seed2_contribution, ) @@ -1025,7 +1023,8 @@ def temporal_sample_neighbors( ) # Ensure nodes is 1-D tensor. - self._check_sampler_arguments(nodes, fanouts, probs_name) + probs_or_mask = self.edge_attributes[probs_name] if probs_name else None + self._check_sampler_arguments(nodes, fanouts, probs_or_mask) has_original_eids = ( self.edge_attributes is not None and ORIGINAL_EDGE_ID in self.edge_attributes @@ -1037,7 +1036,7 @@ def temporal_sample_neighbors( replace, False, has_original_eids, - probs_name, + probs_or_mask, node_timestamp_attr_name, edge_timestamp_attr_name, )