
change cast_various_to_minimum_dtype_gb
Ubuntu committed Sep 18, 2024
1 parent b29e5a2 commit 33c6ea8
Showing 3 changed files with 98 additions and 105 deletions.
70 changes: 43 additions & 27 deletions python/dgl/distributed/partition.py
@@ -1417,9 +1417,9 @@ def get_homogeneous(g, balance_ntypes):
             for name in g.edges[etype].data:
                 if name in [EID, "inner_edge"]:
                     continue
-                edge_feats[
-                    _etype_tuple_to_str(etype) + "/" + name
-                ] = F.gather_row(g.edges[etype].data[name], local_edges)
+                edge_feats[_etype_tuple_to_str(etype) + "/" + name] = (
+                    F.gather_row(g.edges[etype].data[name], local_edges)
+                )
     else:
         for ntype in g.ntypes:
             if len(g.ntypes) > 1:
@@ -1454,9 +1454,9 @@ def get_homogeneous(g, balance_ntypes):
             for name in g.edges[etype].data:
                 if name in [EID, "inner_edge"]:
                     continue
-                edge_feats[
-                    _etype_tuple_to_str(etype) + "/" + name
-                ] = F.gather_row(g.edges[etype].data[name], local_edges)
+                edge_feats[_etype_tuple_to_str(etype) + "/" + name] = (
+                    F.gather_row(g.edges[etype].data[name], local_edges)
+                )
     # delete `orig_id` from ndata/edata
     del part.ndata["orig_id"]
     del part.edata["orig_id"]
@@ -1502,9 +1502,9 @@ def get_homogeneous(g, balance_ntypes):
     for part_id, part in parts.items():
         part_dir = os.path.join(out_path, "part" + str(part_id))
         part_graph_file = os.path.join(part_dir, "graph.dgl")
-        part_metadata["part-{}".format(part_id)][
-            "part_graph"
-        ] = os.path.relpath(part_graph_file, out_path)
+        part_metadata["part-{}".format(part_id)]["part_graph"] = (
+            os.path.relpath(part_graph_file, out_path)
+        )
         # save DGLGraph
         _save_dgl_graphs(
             part_graph_file,
@@ -1600,8 +1600,6 @@ def _save_graph_gb(part_config, part_id, csc_graph):
 
 
 def cast_various_to_minimum_dtype_gb(
-    graph,
-    part_meta,
     num_parts,
     indptr,
     indices,
@@ -1610,25 +1608,43 @@
     etypes,
     ntypes,
     node_attributes,
     edge_attributes,
+    part_meta=None,
+    graph=None,
+    edge_count=None,
+    node_count=None,
+    tot_edge_count=None,
+    tot_node_count=None,
 ):
     """Cast various data to minimum dtype."""
+    if graph is not None:
+        assert part_meta is not None
+        tot_edge_count = graph.num_edges()
+        tot_node_count = graph.num_nodes()
+        node_count = part_meta["num_nodes"]
+        edge_count = part_meta["num_edges"]
+    else:
+        assert tot_edge_count is not None
+        assert tot_node_count is not None
+        assert edge_count is not None
+        assert node_count is not None
+
     # Cast 1: indptr.
-    indptr = _cast_to_minimum_dtype(graph.num_edges(), indptr)
+    indptr = _cast_to_minimum_dtype(tot_edge_count, indptr)
     # Cast 2: indices.
-    indices = _cast_to_minimum_dtype(graph.num_nodes(), indices)
+    indices = _cast_to_minimum_dtype(tot_node_count, indices)
     # Cast 3: type_per_edge.
     type_per_edge = _cast_to_minimum_dtype(
         len(etypes), type_per_edge, field=ETYPE
     )
     # Cast 4: node/edge_attributes.
     predicates = {
-        NID: part_meta["num_nodes"],
+        NID: node_count,
         "part_id": num_parts,
         NTYPE: len(ntypes),
-        EID: part_meta["num_edges"],
+        EID: edge_count,
         ETYPE: len(etypes),
-        DGL2GB_EID: part_meta["num_edges"],
-        GB_DST_ID: part_meta["num_nodes"],
+        DGL2GB_EID: edge_count,
+        GB_DST_ID: node_count,
     }
     for attributes in [node_attributes, edge_attributes]:
         for key in attributes:
@@ -1779,16 +1795,16 @@ def gb_convert_single_dgl_partition(
     )
 
     indptr, indices, type_per_edge = cast_various_to_minimum_dtype_gb(
-        graph,
-        part_meta,
-        num_parts,
-        indptr,
-        indices,
-        type_per_edge,
-        etypes,
-        ntypes,
-        node_attributes,
-        edge_attributes,
+        graph=graph,
+        part_meta=part_meta,
+        num_parts=num_parts,
+        indptr=indptr,
+        indices=indices,
+        type_per_edge=type_per_edge,
+        etypes=etypes,
+        ntypes=ntypes,
+        node_attributes=node_attributes,
+        edge_attributes=edge_attributes,
     )
 
     csc_graph = gb.fused_csc_sampling_graph(
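The refactored helper now supports two calling conventions: pass a partition DGLGraph plus its `part_meta` dict and let the helper derive the counts, or pass the partition-local and total node/edge counts explicitly (the `else` branch above asserts all four are present). A minimal sketch of the explicit-count path, assuming a DGL build that includes this commit; the tensors and counts are illustrative only, and the graph/part_meta path is shown as a comment because it needs a materialized partition:

# Sketch: calling the refactored helper (values are illustrative).
import torch

from dgl.distributed.partition import cast_various_to_minimum_dtype_gb

# A toy CSC structure: 2 nodes, 4 edges, a single edge type.
indptr = torch.tensor([0, 2, 4], dtype=torch.int64)
indices = torch.tensor([1, 0, 0, 1], dtype=torch.int64)
type_per_edge = torch.zeros(4, dtype=torch.int64)

# Path 1 (used in dgl.distributed.partition): hand over the partition
# graph and its metadata; counts are derived inside the helper, e.g.
#   cast_various_to_minimum_dtype_gb(..., graph=graph, part_meta=part_meta)

# Path 2 (used in tools/distpartitioning): no DGLGraph is materialized,
# so the caller supplies the counts directly.
indptr, indices, type_per_edge = cast_various_to_minimum_dtype_gb(
    num_parts=2,
    indptr=indptr,
    indices=indices,
    type_per_edge=type_per_edge,
    etypes={"_N:_E:_N": 0},   # only len() is used for the dtype bound
    ntypes={"_N": 0},
    node_attributes={},
    edge_attributes={},
    node_count=2,
    edge_count=4,
    tot_node_count=2,
    tot_edge_count=4,
)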
97 changes: 36 additions & 61 deletions tools/distpartitioning/convert_partition.py
@@ -5,17 +5,18 @@
 
 import constants
 import dgl
+import dgl.backend as F
 import dgl.graphbolt as gb
 import numpy as np
 import torch as th
 from dgl import EID, ETYPE, NID, NTYPE
-import dgl.backend as F
 
 from dgl.distributed.constants import DGL2GB_EID, GB_DST_ID
 from dgl.distributed.partition import (
     _cast_to_minimum_dtype,
     _etype_str_to_tuple,
     _etype_tuple_to_str,
+    cast_various_to_minimum_dtype_gb,
     RESERVED_FIELD_DTYPE,
 )
 from utils import get_idranges, memory_snapshot
@@ -262,7 +263,7 @@ def _create_edge_attr_gb(
     is_homo = _is_homogeneous(ntypes, etypes)
 
     edge_type_to_id = (
-        {gb.etype_tuple_to_str(('_N','_E','_N')) : 0}
+        {gb.etype_tuple_to_str(("_N", "_E", "_N")): 0}
         if is_homo
         else {
             gb.etype_tuple_to_str(etype): etid
@@ -320,47 +321,6 @@ def remove_attr_gb(
     return edata, ndata
 
 
-def cast_various_to_minimum_dtype_gb(
-    node_count,
-    edge_count,
-    num_parts,
-    indptr,
-    indices,
-    type_per_edge,
-    etypes,
-    ntypes,
-    node_attributes,
-    edge_attributes,
-):
-    """Cast various data to minimum dtype."""
-    # Cast 1: indptr.
-    indptr = _cast_to_minimum_dtype(edge_count, indptr)
-    # Cast 2: indices.
-    indices = _cast_to_minimum_dtype(node_count, indices)
-    # Cast 3: type_per_edge.
-    type_per_edge = _cast_to_minimum_dtype(
-        len(etypes), type_per_edge, field=ETYPE
-    )
-    # Cast 4: node/edge_attributes.
-    predicates = {
-        NID: node_count,
-        "part_id": num_parts,
-        NTYPE: len(ntypes),
-        EID: edge_count,
-        ETYPE: len(etypes),
-        DGL2GB_EID: edge_count,
-        GB_DST_ID: node_count,
-    }
-    for attributes in [node_attributes, edge_attributes]:
-        for key in attributes:
-            if key not in predicates:
-                continue
-            attributes[key] = _cast_to_minimum_dtype(
-                predicates[key], attributes[key], field=key
-            )
-    return indptr, indices, type_per_edge
-
-
 def _process_partition_gb(
     node_attr,
     edge_attr,
@@ -378,22 +338,26 @@ def _process_partition_gb(
             node_attr[k] = F.astype(node_attr[k], dtype)
         if k in edge_attr:
             edge_attr[k] = F.astype(edge_attr[k], dtype)
-    indptr,indices,edge_ids=_coo2csc(src_ids,dst_ids)
+
+    indptr, indices, edge_ids = _coo2csc(src_ids, dst_ids)
     if sort_etypes:
         split_size = th.diff(indptr)
         split_indices = th.split(type_per_edge, tuple(split_size), dim=0)
-        sorted_idxs=[]
+        sorted_idxs = []
         for split_indice in split_indices:
             sorted_idxs.append(split_indice.sort()[1])
 
         sorted_idx = th.cat(sorted_idxs, dim=0)
-        sorted_idx=th.repeat_interleave(indptr[:-1], split_size, dim=0)+sorted_idx
+        sorted_idx = (
+            th.repeat_interleave(indptr[:-1], split_size, dim=0) + sorted_idx
+        )
 
     return indptr, indices, edge_ids
 
 
 def create_graph_object(
+    tot_node_count,
+    tot_edge_count,
     node_count,
     edge_count,
     num_parts,
@@ -457,10 +421,14 @@ def create_graph_object(
     Parameters:
     -----------
-    node_count : int
+    tot_node_count : int
         the number of all nodes
-    edge_count : int
+    tot_edge_count : int
         the number of all edges
+    node_count : int
+        the number of nodes in partition
+    edge_count : int
+        the number of edges in partition
     graph_formats : str
         the format of graph
     num_parts : int
@@ -744,7 +712,12 @@ def create_graph_object(
 
     sort_etypes = len(etypes_map) > 1
     indptr, indices, csc_edge_ids = _process_partition_gb(
-        ndata, edata, type_per_edge, part_local_src_id, part_local_dst_id,sort_etypes
+        ndata,
+        edata,
+        type_per_edge,
+        part_local_src_id,
+        part_local_dst_id,
+        sort_etypes,
     )
     edge_attr, node_attr = remove_attr_gb(
         edge_attr=edata, node_attr=ndata, **kwargs
@@ -753,16 +726,18 @@ def create_graph_object(
         attr: edge_attr[attr][csc_edge_ids] for attr in edge_attr.keys()
     }
     cast_various_to_minimum_dtype_gb(
-        node_count,
-        edge_count,
-        num_parts,
-        indptr,
-        indices,
-        type_per_edge,
-        etypes,
-        ntypes,
-        node_attr,
-        edge_attr,
+        node_count=node_count,
+        edge_count=edge_count,
+        tot_node_count=tot_node_count,
+        tot_edge_count=tot_edge_count,
+        num_parts=num_parts,
+        indptr=indptr,
+        indices=indices,
+        type_per_edge=type_per_edge,
+        etypes=etypes,
+        ntypes=ntypes,
+        node_attributes=node_attr,
+        edge_attributes=edge_attr,
    )
     part_graph = gb.fused_csc_sampling_graph(
         csc_indptr=indptr,
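With the local duplicate deleted, both files now share one implementation, and every call site relies on the same narrowing contract of `_cast_to_minimum_dtype`: given an upper bound on the values a tensor can hold, cast it to the smallest integer dtype whose range covers that bound. A hypothetical stand-in sketching that rule; this is not DGL's actual implementation (which may restrict the candidate dtypes or handle fields differently), only an illustration of the contract the call sites assume:

# Hypothetical stand-in for _cast_to_minimum_dtype; illustrates the
# narrowing rule only, not DGL's exact implementation.
import torch

def cast_to_minimum_dtype(max_val: int, data: torch.Tensor) -> torch.Tensor:
    # Walk the signed integer dtypes from narrowest to widest and cast
    # to the first one whose range covers max_val.
    for dtype in (torch.int8, torch.int16, torch.int32, torch.int64):
        if max_val <= torch.iinfo(dtype).max:
            return data.to(dtype)
    return data  # already needs 64 bits; leave unchanged

ids = torch.arange(1000, dtype=torch.int64)
print(cast_to_minimum_dtype(999, ids).dtype)      # torch.int16
print(cast_to_minimum_dtype(10**10, ids).dtype)   # torch.int64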
36 changes: 19 additions & 17 deletions tools/distpartitioning/data_shuffle.py
@@ -285,21 +285,21 @@ def exchange_edge_data(rank, world_size, num_parts, edge_data, id_lookup):
         local_etype_ids.append(rcvd_edge_data[:, 3])
         local_eids.append(rcvd_edge_data[:, 4])
 
-    edge_data[
-        constants.GLOBAL_SRC_ID + "/" + str(local_part_id)
-    ] = np.concatenate(local_src_ids)
-    edge_data[
-        constants.GLOBAL_DST_ID + "/" + str(local_part_id)
-    ] = np.concatenate(local_dst_ids)
-    edge_data[
-        constants.GLOBAL_TYPE_EID + "/" + str(local_part_id)
-    ] = np.concatenate(local_type_eids)
-    edge_data[
-        constants.ETYPE_ID + "/" + str(local_part_id)
-    ] = np.concatenate(local_etype_ids)
-    edge_data[
-        constants.GLOBAL_EID + "/" + str(local_part_id)
-    ] = np.concatenate(local_eids)
+    edge_data[constants.GLOBAL_SRC_ID + "/" + str(local_part_id)] = (
+        np.concatenate(local_src_ids)
+    )
+    edge_data[constants.GLOBAL_DST_ID + "/" + str(local_part_id)] = (
+        np.concatenate(local_dst_ids)
+    )
+    edge_data[constants.GLOBAL_TYPE_EID + "/" + str(local_part_id)] = (
+        np.concatenate(local_type_eids)
+    )
+    edge_data[constants.ETYPE_ID + "/" + str(local_part_id)] = (
+        np.concatenate(local_etype_ids)
+    )
+    edge_data[constants.GLOBAL_EID + "/" + str(local_part_id)] = (
+        np.concatenate(local_eids)
+    )
 
     # Check if the data was exchanged correctly
     local_edge_count = 0
@@ -1121,7 +1121,6 @@ def gen_dist_partitions(rank, world_size, params):
     )
     id_map = dgl.distributed.id_map.IdMap(global_nid_ranges)
     id_lookup.set_idMap(id_map)
-
     # read input graph files and augment these datastructures with
     # appropriate information (global_nid and owner process) for node and edge data
     (
@@ -1315,6 +1314,8 @@ def prepare_local_data(src_data, local_part_id):
         )
         local_node_data = prepare_local_data(node_data, local_part_id)
         local_edge_data = prepare_local_data(edge_data, local_part_id)
+        tot_node_count = sum(schema_map["num_nodes_per_type"])
+        tot_edge_count = sum(schema_map["num_edges_per_type"])
         (
             graph_obj,
             ntypes_map_val,
@@ -1324,9 +1325,10 @@ def prepare_local_data(src_data, local_part_id):
             orig_nids,
             orig_eids,
         ) = create_graph_object(
+            tot_node_count,
+            tot_edge_count,
             node_count,
             edge_count,
-            graph_formats,
             params.num_parts,
             schema_map,
             rank + local_part_id * world_size,
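Since no DGLGraph exists in this pipeline, the shared helper is reached through the explicit-count path, and the new whole-graph totals are plain sums over the schema's per-type counts. A minimal sketch with illustrative values; in the real pipeline `schema_map` is parsed from the partitioning metadata:

# Illustrative schema_map; the real one comes from the metadata JSON.
schema_map = {
    "num_nodes_per_type": [100, 50],
    "num_edges_per_type": [400],
}

tot_node_count = sum(schema_map["num_nodes_per_type"])  # 150: whole graph
tot_edge_count = sum(schema_map["num_edges_per_type"])  # 400: whole graph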
