From 8ac2c50466b291411cc584aa65dd2183f5c9ad36 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Mon, 19 Apr 2021 17:20:44 +0100 Subject: [PATCH 01/61] Projector: DistIR function -> per-rank function --- dist_ir/executor/__init__.py | 1 + dist_ir/executor/rank_projector.py | 105 +++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+) create mode 100644 dist_ir/executor/rank_projector.py diff --git a/dist_ir/executor/__init__.py b/dist_ir/executor/__init__.py index 65509f53..8fa99963 100644 --- a/dist_ir/executor/__init__.py +++ b/dist_ir/executor/__init__.py @@ -2,3 +2,4 @@ from .sequential_executor import SequentialExecutor from .type_inference import infer_types from .absint import AbstractInterpreter, AbstractState +from .rank_projector import project diff --git a/dist_ir/executor/rank_projector.py b/dist_ir/executor/rank_projector.py new file mode 100644 index 00000000..110d8c4c --- /dev/null +++ b/dist_ir/executor/rank_projector.py @@ -0,0 +1,105 @@ +from collections import defaultdict +from dist_ir.executor.type_inference import TypePropRegister +from typing import Any, Dict, Sequence + +from ..ir import Function, FunctionMaker, Device, Op +from ..ir.type import Type, Tensor +from .absint import AbstractState, AbstractInterpreter + + +class ProjectorState(AbstractState): + def __init__(self, function: Function, inputs: Sequence[Any]): + AbstractState.__init__(self, function, inputs) + self.per_rank_fns: Dict[Device, FunctionMaker] = defaultdict(FunctionMaker) + + +def _get_input_devices(op: Op): + return list(set(x.type.device for x in op.inputs)) + + +# TODO should projectors just get the per_rank_fns dict instead of full state? + + +def _identity_projector(op: Op, state: ProjectorState): + """Projects op unchanged to its device's per-rank program. + The inputs of op must all be on a single device. + """ + devices = _get_input_devices(op) + assert len(devices) == 1 and devices[0] is not None + + state.per_rank_fns[devices[0]].ops.append(op) + # state.per_rank_fns[d].add_op(op.op_type, name=op.name, inputs=op.inputs, ) + + +def _mpi_allgather_projector(op: Op, state: ProjectorState): + assert len(op.inputs) == len(op.outputs) + for in_v, out_v in zip(op.inputs, op.outputs): + assert in_v.type.device == out_v.type.device + d = in_v.type.device + + new_op = Op( + "MPIAllgather", + inputs=(in_v,), + output_values=(out_v,), + attributes=op.attributes, + ) + state.per_rank_fns[d].ops.append(new_op) + + +ProjectorRegister = { + ("MatMul", (Tensor, Tensor)): _identity_projector, + ("MPIAllgather", (Tensor,) * 2): _mpi_allgather_projector, + ("MPIAllgather", (Tensor,) * 4): _mpi_allgather_projector, + ("MPIAllgather", (Tensor,) * 8): _mpi_allgather_projector, + ("MPIAllgather", (Tensor,) * 16): _mpi_allgather_projector, +} + + +def _create_semantics(type_prop_register, projector_register): + """Creates a semantics for AbstractInterpreter by combining a register of + projector functions and the type propagation register. 
+ """ + + def convert_impl(type_prop_fn, projector): + def semantics(op: Op, state: AbstractState): + # Find the op's inputs in state's environment + inputs = tuple(state.env[v] for v in op.inputs) + # Run the type propagation function + outputs = type_prop_fn(op, *inputs) + + # Write outputs to state's environment + if not isinstance(outputs, tuple): + outputs = (outputs,) + for x, val in zip(op.outputs, outputs): + state.env[x] = val + + # Project op and add to appropriate per-rank function + projector(op, state) + + return semantics + + signatures = set(projector_register.keys()).intersection(type_prop_register.keys()) + + return { + f: convert_impl(type_prop_register[f], projector_register[f]) + for f in signatures + } + + +Projector = AbstractInterpreter( + AbstractState=ProjectorState, + semantics=_create_semantics(TypePropRegister, ProjectorRegister), +) + + +def project(fn: Function, input_types: Sequence[Type]): + """Project fn to a sequence of per-rank functions.""" + state = ProjectorState(fn, input_types) + + # Project fn's inputs to each per-rank fn: + for v in fn.inputs: + state.per_rank_fns[v.type.device].inputs.append(v) + + state = Projector.interpret(fn, input_types, state=state) + + return {d: state.per_rank_fns[d].finalize() for d in state.per_rank_fns} From 193dc848de988fbb17605304c7609be92eacac9c Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Mon, 19 Apr 2021 17:22:54 +0100 Subject: [PATCH 02/61] Fix bug: constructing variadic ops when specifying output_values --- dist_ir/ir/op.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/dist_ir/ir/op.py b/dist_ir/ir/op.py index bd1cae25..5c6f21c3 100644 --- a/dist_ir/ir/op.py +++ b/dist_ir/ir/op.py @@ -44,6 +44,13 @@ def __post_init__(self, output_names, output_types, output_values): f"Op {self.name} ({self.op_type}) has {len(self.inputs)} inputs; " f"{num_input_types} expected" ) + + if output_values is not None: + object.__setattr__( + self, "outputs", output_values + ) # Can't assign to frozen field + else: + # Create the correct number of output values with appropriate types # Number of outputs is given by OpRegister if OpRegister[self.op_type].variadic_outputs: if output_names is None: @@ -54,14 +61,6 @@ def __post_init__(self, output_names, output_types, output_values): num_outputs = len(output_names) else: num_outputs = OpRegister[self.op_type].num_outputs - - if output_values is not None: - object.__setattr__( - self, "outputs", output_values - ) # Can't assign to frozen field - else: - # Create the correct number of output values with appropriate types - # if self.outputs is None: if output_names is None: output_names = [f"{self.name}_out_{i}" for i in range(num_outputs)] elif len(output_names) != num_outputs: From c5bfbd9295db9339eab9b283a65d85b5bea4a90d Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Mon, 19 Apr 2021 17:23:12 +0100 Subject: [PATCH 03/61] Fix sequential executor docstring --- dist_ir/executor/sequential_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dist_ir/executor/sequential_executor.py b/dist_ir/executor/sequential_executor.py index 9193daf5..5bd209a7 100644 --- a/dist_ir/executor/sequential_executor.py +++ b/dist_ir/executor/sequential_executor.py @@ -47,7 +47,7 @@ def compute(self, function: Function, inputs: Sequence[Any]) -> Dict[Value, Any] inputs: A sequence of input data represented in the specified backend. Returns: - A map from output value to output data. + A tuple of outputs. 
""" state = self.interpreter.interpret(function, inputs) return tuple(state.env[v] for v in function.outputs) From d2f9e0a84064309408cf9fe1d22fa41c28776fc7 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 25 Apr 2021 15:08:34 +0100 Subject: [PATCH 04/61] A distributed PyTorch backend --- dist_ir/backend/__init__.py | 1 + dist_ir/backend/torch.py | 98 +++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+) create mode 100644 dist_ir/backend/__init__.py create mode 100644 dist_ir/backend/torch.py diff --git a/dist_ir/backend/__init__.py b/dist_ir/backend/__init__.py new file mode 100644 index 00000000..abef05ae --- /dev/null +++ b/dist_ir/backend/__init__.py @@ -0,0 +1 @@ +from . import torch diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py new file mode 100644 index 00000000..04dda474 --- /dev/null +++ b/dist_ir/backend/torch.py @@ -0,0 +1,98 @@ +import os +from tempfile import mkdtemp +from typing import Any, Tuple + +import torch +import torch.distributed as dist +from torch import fx +from torch.multiprocessing import Process + +from ..ir import Function + + +# TODO at op creation time, enforce MPIAllgather ops attributes +def _allgather(x_i, world_size=None, dim=0): + xs = [torch.zeros_like(x_i) for _ in range(world_size)] + dist.all_gather(xs, x_i) + x = torch.cat(xs, dim=dim) + return x + + +_op_to_torch = { + "MatMul": torch.matmul, + "MPIAllgather": _allgather, +} + + +def function_to_module(fn: Function) -> torch.nn.Module: + g = fx.Graph() + value_map = {} + + # TODO need to check that fn has unique value names + + # Convert inputs + for v in fn.inputs: + value_map[v] = g.placeholder(v.name) + + # Convert ops + for op in fn.ops: + inputs = tuple(value_map[v] for v in op.inputs) + assert len(op.outputs) == 1, "TODO how to handle multiple outputs in fx" + kwargs = None if op.attributes is None else {**op.attributes} + output = g.call_function(_op_to_torch[op.op_type], inputs, kwargs) + value_map[op.outputs[0]] = output + + # Convert outputs + for v in fn.outputs: + g.output(value_map[v]) + + return fx.GraphModule({}, g) + + +def _init_process(rank, size, module, io_dir, backend): + """ Initialize the distributed environment. 
""" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29500" + dist.init_process_group(backend, rank=rank, world_size=size) + + per_rank_inputs = torch.load(os.path.join(io_dir, f"in.{rank}.pt")) + + # TODO time the next line only + res = module(*per_rank_inputs) + + torch.save(res, os.path.join(io_dir, f"out.{rank}.pt")) + + +def run_multiprocesses( + per_rank_modules: Tuple[torch.nn.Module], + per_rank_inputs: Tuple[Any], + backend="gloo", +): + assert len(per_rank_modules) == len(per_rank_inputs) + world_size = len(per_rank_modules) + + io_dir = mkdtemp() + print("run_multiprocess: saving I/O to:", io_dir) + # Save inputs for each per-rank function: + # TODO lowered pytorch file numbers devices 0...num_devices-1 + for d, inps in enumerate(per_rank_inputs): + torch.save(inps, os.path.join(io_dir, f"in.{d}.pt")) + + processes = [] + for rank, per_rank_module in enumerate(per_rank_modules): + p = Process( + target=_init_process, + args=(rank, world_size, per_rank_module, io_dir, backend), + ) + p.start() + processes.append(p) + + for p in processes: + p.join() + + # Load outputs: + per_rank_outputs = [ + torch.load(os.path.join(io_dir, f"out.{d}.pt")) for d in range(world_size) + ] + + return per_rank_outputs \ No newline at end of file From d9cc787fd788351356831d5194408b3fa8404d9f Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 25 Apr 2021 15:08:45 +0100 Subject: [PATCH 05/61] Upgrade PyTorch version --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index fccdd866..be6b197a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ frozendict >= 1.2 numpy >= 1.19 onnx >= 1.7.0 -torch >= 1.6.0 +torch >= 1.8.0 prettyprinter >= 0.18.0 From a2a4cc22ead2ca89b82a60b959497a6151f0d061 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 25 Apr 2021 15:18:20 +0100 Subject: [PATCH 06/61] Add test: one-weird-trick (matmul version) --- test/test_pytorch_backend.py | 152 +++++++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 test/test_pytorch_backend.py diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py new file mode 100644 index 00000000..7513cd40 --- /dev/null +++ b/test/test_pytorch_backend.py @@ -0,0 +1,152 @@ +import numpy as np +import torch + +from dist_ir.backend.torch import function_to_module, run_multiprocesses +from dist_ir.executor import SequentialExecutor +from dist_ir.executor.rank_projector import project +from dist_ir.executor.type_inference import infer_types +from dist_ir.ir import Device, FunctionMaker, cpprint, Value +from dist_ir.ir.type import Float, Tensor + + +def create_owt_model(num_devices, num_layers): + assert num_layers % 2 == 0 + + fn = FunctionMaker() + + # Inputs + weights = {} + xs = {} + for l in range(num_layers): + for d in range(1, num_devices + 1): + weights[l, d] = fn.add_input_value(f"w{l}_{d}", None) + for d in range(1, num_devices + 1): + xs[d] = fn.add_input_value(f"x_{d}", None) + + # Data parallel conv blocks: (using MatMuls for now) + hs = [] + for d in range(1, num_devices + 1): + h = xs[d] + for l in range(num_layers // 2): + h = fn.add_op( + "MatMul", inputs=[h, weights[l, d]], output_names=[f"h{l}_{d}"] + ) + hs.append(h) + + # Allgather the activations + as_names = [f"hh{num_layers//2-1}_{d}" for d in range(1, num_devices + 1)] + hs = fn.add_op( + "MPIAllgather", + inputs=hs, + output_names=as_names, + attributes={"dim": 0, "world_size": num_devices}, + ) + + # Model 
parallel fully-connected layers: (again, MatMuls for now) + hs = hs + for l in range(num_layers // 2, num_layers): + h_is = [] + for d in range(1, num_devices + 1): + h_is.append( + fn.add_op( + "MatMul", + inputs=[hs[d - 1], weights[l, d]], + output_names=[f"h{l}_{d}"], + ) + ) + if l == num_layers - 1: + hs = h_is + else: + out_names = [f"hh{l}_{d}" for d in range(1, num_devices + 1)] + hs = fn.add_op( + "MPIAllgather", + inputs=h_is, + output_names=out_names, + attributes={"dim": 1, "world_size": num_devices}, + ) + + fn.set_outputs(hs) + return fn.finalize() + + +def test_owt(num_devices, num_layers): + fn = create_owt_model(num_devices, num_layers) + + devices = [Device(0, "cpu")] + for d in range(1, num_devices + 1): + devices.append(Device(d, "gpu")) + + batch_size = 8 + hidden_dim = 4 # using this for input/output dim also + + input_vals = [] + for l in range(num_layers): + for d in range(1, num_devices + 1): + if l < num_layers // 2: + shape = (hidden_dim, hidden_dim) + else: + shape = (hidden_dim, hidden_dim // num_devices) + # w{l}_{d}: + input_vals.append(Value("", Tensor(Float(), shape, devices[d]))) + for d in range(1, num_devices + 1): + # x_{d}: + shape = (batch_size // num_devices, hidden_dim) + input_vals.append(Value("", Tensor(Float(), shape, devices[d]))) + + # Test type inference: + fn = infer_types(fn, input_vals) + cpprint(fn) + assert all( + v.type.shape == (batch_size, hidden_dim // num_devices) for v in fn.outputs + ) + + # Test with sequential executor: + np.random.seed(0) + weights = [np.random.randn(hidden_dim, hidden_dim) for l in range(num_layers)] + x = np.random.randn(batch_size, hidden_dim) + + # Split inputs for distributed function + input_arrays = [] + for l in range(num_layers): + if l < num_layers // 2: + for d in range(1, num_devices + 1): + input_arrays.append(weights[l]) + else: + input_arrays += np.split(weights[l], num_devices, axis=1) + input_arrays += np.split(x, num_devices) + ex = SequentialExecutor("numpy") + output_arrays = ex.compute(fn, input_arrays) + + # Expected results + y = x + for l in range(num_layers): + y = np.matmul(y, weights[l]) + ys = np.split(y, num_devices, axis=1) + assert all(np.allclose(y, o) for y, o in zip(ys, output_arrays)) + + # Per-rank projection: + proj = project(fn, tuple(v.type for v in input_vals)) + for d, f_d in proj.items(): + print() + print(d) + cpprint(f_d) + + # Make inputs for each per-rank function: + per_rank_inputs = [[] for _ in range(num_devices)] + for v, a in zip(fn.inputs, input_arrays): + per_rank_inputs[v.type.device.device_id - 1].append(torch.tensor(a)) + + # Translate per-rank functions into torch GraphModules: + per_rank_modules = [function_to_module(f_d) for d, f_d in proj.items()] + for d, gm in enumerate(per_rank_modules): + print(f"\n{d}\n{gm.graph}") + + # Run per-rank modules using PyTorch backend: + per_rank_outputs = run_multiprocesses(per_rank_modules, per_rank_inputs) + + # Check outputs: + assert all(np.allclose(y, o) for y, o in zip(per_rank_outputs, output_arrays)) + + +if __name__ == "__main__": + test_owt(2, 4) \ No newline at end of file From 7ab770d36b539c70796cd52c52aaefc288f0b3d2 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 25 Apr 2021 15:24:18 +0100 Subject: [PATCH 07/61] Refactor run_multiprocess --- dist_ir/backend/torch.py | 44 +++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 04dda474..632d0bc5 100644 --- a/dist_ir/backend/torch.py +++ 
b/dist_ir/backend/torch.py @@ -1,5 +1,5 @@ import os -from tempfile import mkdtemp +from tempfile import TemporaryDirectory from typing import Any, Tuple import torch @@ -49,20 +49,6 @@ def function_to_module(fn: Function) -> torch.nn.Module: return fx.GraphModule({}, g) -def _init_process(rank, size, module, io_dir, backend): - """ Initialize the distributed environment. """ - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = "29500" - dist.init_process_group(backend, rank=rank, world_size=size) - - per_rank_inputs = torch.load(os.path.join(io_dir, f"in.{rank}.pt")) - - # TODO time the next line only - res = module(*per_rank_inputs) - - torch.save(res, os.path.join(io_dir, f"out.{rank}.pt")) - - def run_multiprocesses( per_rank_modules: Tuple[torch.nn.Module], per_rank_inputs: Tuple[Any], @@ -71,19 +57,30 @@ def run_multiprocesses( assert len(per_rank_modules) == len(per_rank_inputs) world_size = len(per_rank_modules) - io_dir = mkdtemp() - print("run_multiprocess: saving I/O to:", io_dir) + io_dir = TemporaryDirectory() + # print("run_multiprocess: saving I/O to:", io_dir.name) + + def run_process(rank, module): + """ Initialize the distributed environment. """ + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29500" + dist.init_process_group(backend, rank=rank, world_size=world_size) + + per_rank_inputs = torch.load(os.path.join(io_dir.name, f"in.{rank}.pt")) + + # TODO time the next line only + res = module(*per_rank_inputs) + + torch.save(res, os.path.join(io_dir.name, f"out.{rank}.pt")) + # Save inputs for each per-rank function: # TODO lowered pytorch file numbers devices 0...num_devices-1 for d, inps in enumerate(per_rank_inputs): - torch.save(inps, os.path.join(io_dir, f"in.{d}.pt")) + torch.save(inps, os.path.join(io_dir.name, f"in.{d}.pt")) processes = [] for rank, per_rank_module in enumerate(per_rank_modules): - p = Process( - target=_init_process, - args=(rank, world_size, per_rank_module, io_dir, backend), - ) + p = Process(target=run_process, args=(rank, per_rank_module)) p.start() processes.append(p) @@ -92,7 +89,8 @@ def run_multiprocesses( # Load outputs: per_rank_outputs = [ - torch.load(os.path.join(io_dir, f"out.{d}.pt")) for d in range(world_size) + torch.load(os.path.join(io_dir.name, f"out.{d}.pt")) for d in range(world_size) ] + io_dir.cleanup() return per_rank_outputs \ No newline at end of file From 4e37ce3ca6e1cbc51f82fa044d923d5b7ca75f4e Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 25 Apr 2021 15:32:50 +0100 Subject: [PATCH 08/61] Make run_multiprocess take Functions not nn.Modules --- dist_ir/backend/torch.py | 13 ++++++++----- test/test_pytorch_backend.py | 11 +++-------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 632d0bc5..10fba052 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -50,12 +50,15 @@ def function_to_module(fn: Function) -> torch.nn.Module: def run_multiprocesses( - per_rank_modules: Tuple[torch.nn.Module], - per_rank_inputs: Tuple[Any], - backend="gloo", + per_rank_functions: Tuple[Function], per_rank_inputs: Tuple[Any], backend="gloo" ): - assert len(per_rank_modules) == len(per_rank_inputs) - world_size = len(per_rank_modules) + assert len(per_rank_functions) == len(per_rank_inputs) + world_size = len(per_rank_functions) + + # Convert per-rank DistIR functions to torch.nn.Modules: + per_rank_modules = list(map(function_to_module, per_rank_functions)) + for d, gm in 
enumerate(per_rank_modules): + print(f"{d}\n{gm.graph}\n") io_dir = TemporaryDirectory() # print("run_multiprocess: saving I/O to:", io_dir.name) diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index 7513cd40..2b880425 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -125,8 +125,8 @@ def test_owt(num_devices, num_layers): assert all(np.allclose(y, o) for y, o in zip(ys, output_arrays)) # Per-rank projection: - proj = project(fn, tuple(v.type for v in input_vals)) - for d, f_d in proj.items(): + per_rank_fns = project(fn, tuple(v.type for v in input_vals)) + for d, f_d in per_rank_fns.items(): print() print(d) cpprint(f_d) @@ -136,13 +136,8 @@ def test_owt(num_devices, num_layers): for v, a in zip(fn.inputs, input_arrays): per_rank_inputs[v.type.device.device_id - 1].append(torch.tensor(a)) - # Translate per-rank functions into torch GraphModules: - per_rank_modules = [function_to_module(f_d) for d, f_d in proj.items()] - for d, gm in enumerate(per_rank_modules): - print(f"\n{d}\n{gm.graph}") - # Run per-rank modules using PyTorch backend: - per_rank_outputs = run_multiprocesses(per_rank_modules, per_rank_inputs) + per_rank_outputs = run_multiprocesses(per_rank_fns.values(), per_rank_inputs) # Check outputs: assert all(np.allclose(y, o) for y, o in zip(per_rank_outputs, output_arrays)) From 957681138f70b9a0b78e613fd127764d833cd1f8 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 25 Apr 2021 15:38:20 +0100 Subject: [PATCH 09/61] End-of-file newlines --- dist_ir/backend/torch.py | 2 +- test/test_pytorch_backend.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 10fba052..b7cc1c54 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -96,4 +96,4 @@ def run_process(rank, module): ] io_dir.cleanup() - return per_rank_outputs \ No newline at end of file + return per_rank_outputs diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index 2b880425..d94c452c 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -144,4 +144,4 @@ def test_owt(num_devices, num_layers): if __name__ == "__main__": - test_owt(2, 4) \ No newline at end of file + test_owt(2, 4) From c4ca1a0a95f9d308c20d8dedf3336bacbf9568ae Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Tue, 27 Apr 2021 16:58:56 +0100 Subject: [PATCH 10/61] Black --- dist_ir/backend/torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index b7cc1c54..bff00064 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -64,7 +64,7 @@ def run_multiprocesses( # print("run_multiprocess: saving I/O to:", io_dir.name) def run_process(rank, module): - """ Initialize the distributed environment. 
""" + """Initialize the distributed environment.""" os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "29500" dist.init_process_group(backend, rank=rank, world_size=world_size) From a09566518000c7f984fa09079b9b506e93568198 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Tue, 27 Apr 2021 17:09:29 +0100 Subject: [PATCH 11/61] Parametrize pytest test_owt --- test/test_pytorch_backend.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index d94c452c..d09e0813 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import torch from dist_ir.backend.torch import function_to_module, run_multiprocesses @@ -69,6 +70,7 @@ def create_owt_model(num_devices, num_layers): return fn.finalize() +@pytest.mark.parametrize(["num_devices", "num_layers"], [(2, 4)]) def test_owt(num_devices, num_layers): fn = create_owt_model(num_devices, num_layers) From 3279ca1c8980616e4163b25e76a65c180d787a02 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Wed, 28 Apr 2021 00:34:58 +0100 Subject: [PATCH 12/61] Revert "Fix bug: constructing variadic ops when specifying output_values" This reverts commit 193dc848de988fbb17605304c7609be92eacac9c. --- dist_ir/ir/op.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/dist_ir/ir/op.py b/dist_ir/ir/op.py index 5c6f21c3..bd1cae25 100644 --- a/dist_ir/ir/op.py +++ b/dist_ir/ir/op.py @@ -44,13 +44,6 @@ def __post_init__(self, output_names, output_types, output_values): f"Op {self.name} ({self.op_type}) has {len(self.inputs)} inputs; " f"{num_input_types} expected" ) - - if output_values is not None: - object.__setattr__( - self, "outputs", output_values - ) # Can't assign to frozen field - else: - # Create the correct number of output values with appropriate types # Number of outputs is given by OpRegister if OpRegister[self.op_type].variadic_outputs: if output_names is None: @@ -61,6 +54,14 @@ def __post_init__(self, output_names, output_types, output_values): num_outputs = len(output_names) else: num_outputs = OpRegister[self.op_type].num_outputs + + if output_values is not None: + object.__setattr__( + self, "outputs", output_values + ) # Can't assign to frozen field + else: + # Create the correct number of output values with appropriate types + # if self.outputs is None: if output_names is None: output_names = [f"{self.name}_out_{i}" for i in range(num_outputs)] elif len(output_names) != num_outputs: From c84c61d8dd64712ee11f2b406dfd5d5a9e39ead1 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Wed, 28 Apr 2021 00:36:18 +0100 Subject: [PATCH 13/61] Fix Op constructor: better handling of variadic ops/pre-created outputs --- dist_ir/ir/op.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/dist_ir/ir/op.py b/dist_ir/ir/op.py index bd1cae25..3fd2e57d 100644 --- a/dist_ir/ir/op.py +++ b/dist_ir/ir/op.py @@ -23,6 +23,20 @@ class Op: output_values: InitVar[Tuple[Value]] = None def __post_init__(self, output_names, output_types, output_values): + # Check output_{names,types,values} have same length + given_num_outputs = set( + len(x) for x in [output_names, output_types, output_values] if x is not None + ) + if len(given_num_outputs) == 0: + given_num_outputs = None + elif len(given_num_outputs) == 1: + given_num_outputs = list(given_num_outputs)[0] + else: + raise ValueError( + "output_{names,types,values} must have same length. 
Got:\n" + f"{output_names}\n{output_types}\n{output_values}" + ) + if self.op_type == "Pmap": # Handle pmap specially assert len(self.subfunctions) == 1 @@ -46,22 +60,18 @@ def __post_init__(self, output_names, output_types, output_values): ) # Number of outputs is given by OpRegister if OpRegister[self.op_type].variadic_outputs: - if output_names is None: + if given_num_outputs is None: raise ValueError( f"Op {self.name} ({self.op_type}) has variadic " - f"outputs, so output names must be specified" + "outputs, so one of output_{names,values} must be specified" ) - num_outputs = len(output_names) + num_outputs = given_num_outputs else: num_outputs = OpRegister[self.op_type].num_outputs + assert given_num_outputs is None or num_outputs == given_num_outputs - if output_values is not None: - object.__setattr__( - self, "outputs", output_values - ) # Can't assign to frozen field - else: + if output_values is None: # Create the correct number of output values with appropriate types - # if self.outputs is None: if output_names is None: output_names = [f"{self.name}_out_{i}" for i in range(num_outputs)] elif len(output_names) != num_outputs: From bb604f6604da325b267603030d9007e50beb755d Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Wed, 28 Apr 2021 17:07:28 +0100 Subject: [PATCH 14/61] Add support for Relu --- dist_ir/backend/torch.py | 1 + dist_ir/executor/rank_projector.py | 1 + 2 files changed, 2 insertions(+) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index bff00064..78d5d41c 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -20,6 +20,7 @@ def _allgather(x_i, world_size=None, dim=0): _op_to_torch = { "MatMul": torch.matmul, + "Relu": torch.relu, "MPIAllgather": _allgather, } diff --git a/dist_ir/executor/rank_projector.py b/dist_ir/executor/rank_projector.py index 110d8c4c..a45b61d4 100644 --- a/dist_ir/executor/rank_projector.py +++ b/dist_ir/executor/rank_projector.py @@ -48,6 +48,7 @@ def _mpi_allgather_projector(op: Op, state: ProjectorState): ProjectorRegister = { ("MatMul", (Tensor, Tensor)): _identity_projector, + ("Relu", (Tensor,)): _identity_projector, ("MPIAllgather", (Tensor,) * 2): _mpi_allgather_projector, ("MPIAllgather", (Tensor,) * 4): _mpi_allgather_projector, ("MPIAllgather", (Tensor,) * 8): _mpi_allgather_projector, From 8d8bbda609ceed47bfc9edefe6d4844827f07bf4 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Thu, 29 Apr 2021 09:36:09 +0100 Subject: [PATCH 15/61] Add DP test --- examples/mlp.py | 69 ++++++++++++++++++++++++++++++++++++ test/test_pytorch_backend.py | 59 ++++++++++++++++++++++++++++-- 2 files changed, 126 insertions(+), 2 deletions(-) diff --git a/examples/mlp.py b/examples/mlp.py index 8962bd5b..7e79643f 100644 --- a/examples/mlp.py +++ b/examples/mlp.py @@ -54,3 +54,72 @@ def mlp(batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, device output_names=[f"dy{i}", f"dw{chr(ord('A')+i)}"], ) return function.finalize() + + +def mlp_inference( + batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, device +): + function = FunctionMaker(name="mlp") + weights = [] + for i in range(num_hidden_layers - 1): + w = function.add_input_value( + f"w{chr(ord('A')+i)}", + Tensor(dtype=Float(), shape=(input_dim, hidden_dim), device=device), + ) + weights.append(w) + w = function.add_input_value( + f"w{chr(ord('A')+i+1)}", + Tensor(dtype=Float(), shape=(hidden_dim, output_dim), device=device), + ) + weights.append(w) + x = function.add_input_value( + "x", + Tensor(dtype=Float(), 
shape=(batch_size, input_dim), device=device), + ) + + a = x + for i, weight in enumerate(weights): + y = function.add_op("MatMul", inputs=[a, weight], output_names=[f"y{i}"]) + a = function.add_op("Relu", inputs=[y], output_names=[f"a{i}"]) + + return function.finalize() + + +def mlp_inference_dp( + batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, devices +): + num_devices = len(devices) + assert batch_size % num_devices == 0 + function = FunctionMaker(name="mlp") + weights = {} + x = {} + for d in devices: + for i in range(num_hidden_layers - 1): + weights[i, d] = function.add_input_value( + f"w{chr(ord('A')+i)}_{d.device_id}", + Tensor(dtype=Float(), shape=(input_dim, hidden_dim), device=d), + ) + weights[num_hidden_layers - 1, d] = function.add_input_value( + f"w{chr(ord('A')+i+1)}_{d.device_id}", + Tensor(dtype=Float(), shape=(hidden_dim, output_dim), device=d), + ) + x[d] = function.add_input_value( + f"x_{d.device_id}", + Tensor( + dtype=Float(), shape=(batch_size // num_devices, input_dim), device=d + ), + ) + + a = x + for i in range(num_hidden_layers): + for d in devices: + y = function.add_op( + "MatMul", + inputs=[a[d], weights[i, d]], + output_names=[f"y{i}_{d.device_id}"], + ) + a[d] = function.add_op( + "Relu", inputs=[y], output_names=[f"a{i}_{d.device_id}"] + ) + + return function.finalize() diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index d09e0813..cd3a3baf 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -2,12 +2,14 @@ import pytest import torch -from dist_ir.backend.torch import function_to_module, run_multiprocesses +from dist_ir.backend.torch import run_multiprocesses from dist_ir.executor import SequentialExecutor from dist_ir.executor.rank_projector import project from dist_ir.executor.type_inference import infer_types from dist_ir.ir import Device, FunctionMaker, cpprint, Value from dist_ir.ir.type import Float, Tensor +from dist_ir.transforms import mlp_dhp_transform +from examples.mlp import mlp_inference, mlp_inference_dp def create_owt_model(num_devices, num_layers): @@ -145,5 +147,58 @@ def test_owt(num_devices, num_layers): assert all(np.allclose(y, o) for y, o in zip(per_rank_outputs, output_arrays)) +# TODO get DHP transform to work on mlp_inference and try running on backend +# def test_mlp_grid_search(): +# devices = [Device(d, "gpu") for d in range(3)] +# +# f = mlp_inference(4, 6, 6, 6, 4, devices[0]) +# f = infer_types(f, f.inputs) +# +# f_dist = mlp_dhp_transform(f, 2, 1, 1, devices, 1) + + +def test_dp_mlp(): + num_devices = 2 + num_layers = 4 + batch_size = 4 + hidden_dim = 6 # Also input/output dim for simplicity + devices = [Device(d, "gpu") for d in range(num_devices + 1)] + + fn = mlp_inference_dp( + batch_size, hidden_dim, hidden_dim, hidden_dim, num_layers, devices[1:] + ) + fn = infer_types(fn, fn.inputs) + cpprint(fn) + + def convert_inputs_dp(weights, x): + xs = torch.split(x, num_devices) + + def new_inputs(): + for d in range(num_devices): + yield from weights + yield xs[d] + + return list(new_inputs()) + + # Make random input/expected data: + weights = [torch.randn(hidden_dim, hidden_dim) for _ in range(num_layers)] + x = torch.randn(batch_size, hidden_dim) + y = x + for l in range(num_layers): + y = torch.matmul(y, weights[l]) + y = torch.relu(y) + + # Project and run on backend: + per_rank_fns = project(fn, tuple(v.type for v in fn.inputs)) + per_rank_inputs = [[] for _ in range(num_devices)] + for v, a in zip(fn.inputs, convert_inputs_dp(weights, x)): + 
per_rank_inputs[v.type.device.device_id - 1].append(a) + per_rank_outputs = run_multiprocesses(per_rank_fns.values(), per_rank_inputs) + + # Check outputs: + assert torch.allclose(y, torch.cat(per_rank_outputs, 0)) + + if __name__ == "__main__": - test_owt(2, 4) + # test_owt(2, 4) + test_dp_mlp() From 42751dad0d6cd99340ec7be4a4018e9c87ba1a6d Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Fri, 30 Apr 2021 10:50:20 +0000 Subject: [PATCH 16/61] Backend: run and time on GPU --- dist_ir/backend/torch.py | 75 ++++++++++++++++++++++++------------ test/test_pytorch_backend.py | 9 ++++- 2 files changed, 57 insertions(+), 27 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 78d5d41c..cbd3d9ff 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -1,7 +1,9 @@ +from functools import partial import os from tempfile import TemporaryDirectory from typing import Any, Tuple +import numpy as np import torch import torch.distributed as dist from torch import fx @@ -50,8 +52,47 @@ def function_to_module(fn: Function) -> torch.nn.Module: return fx.GraphModule({}, g) +def run_process( + backend, world_size, io_dir, num_warmup_steps, num_repetitions, rank, module +): + """The Python function on rank `rank` that runs module `module`.""" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29500" + dist.init_process_group(backend, rank=rank, world_size=world_size) + + per_rank_inputs = torch.load(os.path.join(io_dir.name, f"in.{rank}.pt")) + + # Move module and inputs to GPU (TODO gpu flag) + module.to(rank) + for t in per_rank_inputs: + t.to(rank) + + # Time a bunch of executions, then execute once for output values + events = [torch.cuda.Event(enable_timing=True)] + events[0].record() + for _ in range(num_warmup_steps + num_repetitions): + res = module(*per_rank_inputs) + if world_size > 1: + torch.distributed.barrier() + events.append(torch.cuda.Event(enable_timing=True)) + events[-1].record() + + torch.save(res, os.path.join(io_dir.name, f"out.{rank}.pt")) + runtimes = [ + events[i].elapsed_time(events[i + 1]) / 1e3 for i in range(len(events) - 1) + ] + + torch.cuda.synchronize() + dist.destroy_process_group() + return runtimes[num_warmup_steps:] + + def run_multiprocesses( - per_rank_functions: Tuple[Function], per_rank_inputs: Tuple[Any], backend="gloo" + per_rank_functions: Tuple[Function], + per_rank_inputs: Tuple[Any], + backend="gloo", + num_repetitions=100, + num_warmup=10, ): assert len(per_rank_functions) == len(per_rank_inputs) world_size = len(per_rank_functions) @@ -61,35 +102,19 @@ def run_multiprocesses( for d, gm in enumerate(per_rank_modules): print(f"{d}\n{gm.graph}\n") + # Save inputs for each per-rank function: io_dir = TemporaryDirectory() # print("run_multiprocess: saving I/O to:", io_dir.name) - - def run_process(rank, module): - """Initialize the distributed environment.""" - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = "29500" - dist.init_process_group(backend, rank=rank, world_size=world_size) - - per_rank_inputs = torch.load(os.path.join(io_dir.name, f"in.{rank}.pt")) - - # TODO time the next line only - res = module(*per_rank_inputs) - - torch.save(res, os.path.join(io_dir.name, f"out.{rank}.pt")) - - # Save inputs for each per-rank function: # TODO lowered pytorch file numbers devices 0...num_devices-1 for d, inps in enumerate(per_rank_inputs): torch.save(inps, os.path.join(io_dir.name, f"in.{d}.pt")) - processes = [] - for rank, per_rank_module in enumerate(per_rank_modules): - 
p = Process(target=run_process, args=(rank, per_rank_module)) - p.start() - processes.append(p) - - for p in processes: - p.join() + global run_process + per_rank_runner = partial( + run_process, backend, world_size, io_dir, num_warmup, num_repetitions + ) + with torch.multiprocessing.Pool(world_size) as p: + runtimes = p.starmap(per_rank_runner, enumerate(per_rank_modules)) # Load outputs: per_rank_outputs = [ @@ -97,4 +122,4 @@ def run_process(rank, module): ] io_dir.cleanup() - return per_rank_outputs + return per_rank_outputs, runtimes diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index cd3a3baf..4e5378c1 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -193,12 +193,17 @@ def new_inputs(): per_rank_inputs = [[] for _ in range(num_devices)] for v, a in zip(fn.inputs, convert_inputs_dp(weights, x)): per_rank_inputs[v.type.device.device_id - 1].append(a) - per_rank_outputs = run_multiprocesses(per_rank_fns.values(), per_rank_inputs) + per_rank_outputs, runtimes = run_multiprocesses( + per_rank_fns.values(), per_rank_inputs + ) # Check outputs: assert torch.allclose(y, torch.cat(per_rank_outputs, 0)) + return runtimes + if __name__ == "__main__": # test_owt(2, 4) - test_dp_mlp() + # test_dp_mlp() + pass From f73d33b8f47deb100f6da411f64e1ce13d048d66 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Fri, 30 Apr 2021 10:51:03 +0000 Subject: [PATCH 17/61] Refactor grid_search for interactive use --- examples/grid_search.py | 112 ++++++++++++++++++++++++++-------------- 1 file changed, 72 insertions(+), 40 deletions(-) diff --git a/examples/grid_search.py b/examples/grid_search.py index 276df088..db94acc3 100644 --- a/examples/grid_search.py +++ b/examples/grid_search.py @@ -1,6 +1,7 @@ import argparse from collections import defaultdict, OrderedDict import csv +from itertools import product import logging import numpy as np import time @@ -15,7 +16,7 @@ from dist_ir.executor.cost_model import CostModel from dist_ir.ir.type import Bool, Float, Int64, Tensor from dist_ir.transforms import ( - parallel_transform_3d, + mlp_dhp_transform, PipeDreamScheduler, ) from mlp import mlp @@ -78,7 +79,7 @@ def run_experiment(config): world_size = dp_degree * hp_degree * pp_degree add_devices_to_topology(topology, world_size) - transformed_function = parallel_transform_3d( + transformed_function = mlp_dhp_transform( function, dp_degree, hp_degree, @@ -101,43 +102,68 @@ def run_experiment(config): return throughput -def grid_search(): - input_dim = 8192 - hidden_dim = input_dim - output_dim = input_dim - all_cluster_sizes = [1, 2, 4, 8, 16, 32] - all_num_hidden_layers = [64] - all_batch_sizes = [8192] - configs = [] - for num_hidden_layers in all_num_hidden_layers: - for batch_size in all_batch_sizes: - for i, cluster_size in enumerate(all_cluster_sizes): - all_degrees = get_all_degrees(cluster_size) - for (dp_degree, hp_degree, pp_degree) in all_degrees: - if num_hidden_layers % pp_degree != 0: - continue - dp_batch_size = batch_size // dp_degree - if pp_degree == 1: - all_num_microbatches = [1] - else: - all_num_microbatches = [ - int(2 ** k) - for k in range(1, int(np.floor(np.log2(dp_batch_size) / 2))) - ] - for num_microbatches in all_num_microbatches: - if pp_degree == 1: - num_microbatches == 1 - configs.append( - ( - batch_size, - input_dim, - num_hidden_layers, - dp_degree, - hp_degree, - pp_degree, - num_microbatches, - ) - ) +def mlp_dist( + batch_size, + input_dim, + num_hidden_layers, + dp_degree, + hp_degree, + pp_degree, + 
num_microbatches, + devices, +): + function = mlp(batch_size, input_dim, input_dim, input_dim, num_hidden_layers, None) + function = infer_types(function, function.inputs) + world_size = dp_degree * hp_degree * pp_degree + + transformed_function = mlp_dhp_transform( + function, + dp_degree, + hp_degree, + pp_degree, + devices, + num_microbatches, + ) + transformed_function = infer_types( + transformed_function, transformed_function.inputs + ) + return transformed_function + + +def gen_configurations(hidden_dims, cluster_sizes, all_num_layers, all_batch_sizes): + for hidden_dim, num_hidden_layers, batch_size, cluster_size in product( + hidden_dims, all_num_layers, all_batch_sizes, cluster_sizes + ): + all_degrees = get_all_degrees(cluster_size) + for (dp_degree, hp_degree, pp_degree) in all_degrees: + if num_hidden_layers % pp_degree != 0: + continue + dp_batch_size = batch_size // dp_degree + if pp_degree == 1: + all_num_microbatches = [1] + else: + all_num_microbatches = [ + int(2 ** k) + for k in range(1, int(np.floor(np.log2(dp_batch_size) / 2))) + ] + for num_microbatches in all_num_microbatches: + if pp_degree == 1: + num_microbatches == 1 + yield ( + batch_size, + hidden_dim, + num_hidden_layers, + dp_degree, + hp_degree, + pp_degree, + num_microbatches, + ) + + +def grid_search(hidden_dims, cluster_sizes, all_num_layers, all_batch_sizes): + configs = list( + gen_configurations(hidden_dims, cluster_sizes, all_num_layers, all_batch_sizes) + ) with Pool() as p: results = p.map(run_experiment, configs) @@ -174,4 +200,10 @@ def grid_search(): if __name__ == "__main__": - grid_search() + # grid_search( + # hidden_dims=[8192], + # cluster_sizes=[1, 2, 4, 8, 16, 32], + # all_num_layers=[64], + # all_batch_sizes=[8192], + # ) + pass From 20bae75a58388d5c1f5a87b21b73501e4e713c3c Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 2 May 2021 14:54:13 +0100 Subject: [PATCH 18/61] Timing code for CPUs --- dist_ir/backend/torch.py | 42 ++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index cbd3d9ff..0545ad50 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -1,6 +1,7 @@ from functools import partial import os from tempfile import TemporaryDirectory +from time import perf_counter from typing import Any, Tuple import numpy as np @@ -53,34 +54,47 @@ def function_to_module(fn: Function) -> torch.nn.Module: def run_process( - backend, world_size, io_dir, num_warmup_steps, num_repetitions, rank, module + use_gpu, world_size, io_dir, num_warmup_steps, num_repetitions, rank, module ): """The Python function on rank `rank` that runs module `module`.""" os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "29500" + backend = "nccl" if use_gpu else "gloo" dist.init_process_group(backend, rank=rank, world_size=world_size) per_rank_inputs = torch.load(os.path.join(io_dir.name, f"in.{rank}.pt")) - # Move module and inputs to GPU (TODO gpu flag) - module.to(rank) - for t in per_rank_inputs: - t.to(rank) + if use_gpu: + # Move module and inputs to GPU + module.to(rank) + for t in per_rank_inputs: + t.to(rank) + + events = [] + + def add_event(): + if use_gpu: + events.append(torch.cuda.Event(enable_timing=True)) + events[-1].record() + else: + events.append(perf_counter()) # Time a bunch of executions, then execute once for output values - events = [torch.cuda.Event(enable_timing=True)] - events[0].record() + add_event() for _ in range(num_warmup_steps + 
num_repetitions): res = module(*per_rank_inputs) if world_size > 1: torch.distributed.barrier() - events.append(torch.cuda.Event(enable_timing=True)) - events[-1].record() + add_event() torch.save(res, os.path.join(io_dir.name, f"out.{rank}.pt")) - runtimes = [ - events[i].elapsed_time(events[i + 1]) / 1e3 for i in range(len(events) - 1) - ] + + if use_gpu: + runtimes = [ + events[i].elapsed_time(events[i + 1]) / 1e3 for i in range(len(events) - 1) + ] + else: + runtimes = [events[i + 1] - events[i] for i in range(len(events) - 1)] torch.cuda.synchronize() dist.destroy_process_group() @@ -90,7 +104,7 @@ def run_process( def run_multiprocesses( per_rank_functions: Tuple[Function], per_rank_inputs: Tuple[Any], - backend="gloo", + use_gpu=False, num_repetitions=100, num_warmup=10, ): @@ -111,7 +125,7 @@ def run_multiprocesses( global run_process per_rank_runner = partial( - run_process, backend, world_size, io_dir, num_warmup, num_repetitions + run_process, use_gpu, world_size, io_dir, num_warmup, num_repetitions ) with torch.multiprocessing.Pool(world_size) as p: runtimes = p.starmap(per_rank_runner, enumerate(per_rank_modules)) From 8e8513db5e97c372aa64fa9fccb2d9597816b20b Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 2 May 2021 14:54:37 +0100 Subject: [PATCH 19/61] Handle ops with multiple outputs --- dist_ir/backend/torch.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 0545ad50..ea92d048 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -41,14 +41,16 @@ def function_to_module(fn: Function) -> torch.nn.Module: # Convert ops for op in fn.ops: inputs = tuple(value_map[v] for v in op.inputs) - assert len(op.outputs) == 1, "TODO how to handle multiple outputs in fx" kwargs = None if op.attributes is None else {**op.attributes} output = g.call_function(_op_to_torch[op.op_type], inputs, kwargs) - value_map[op.outputs[0]] = output + if len(op.outputs) > 1: + for i, v in enumerate(op.outputs): + value_map[v] = g.call_function(getitem, (output, i)) + elif len(op.outputs) == 1: + value_map[op.outputs[0]] = output # Convert outputs - for v in fn.outputs: - g.output(value_map[v]) + g.output(tuple(value_map[v] for v in fn.outputs)) return fx.GraphModule({}, g) From 6f3ae5938e82cd1e75458ddf3050e66bd8c3d261 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 2 May 2021 14:56:30 +0100 Subject: [PATCH 20/61] Add support for all MLP training ops --- dist_ir/backend/torch.py | 48 +++++++++++++++++++++++++++++- dist_ir/executor/rank_projector.py | 28 ++++++++++++++++- dist_ir/ir/op_register.py | 3 ++ 3 files changed, 77 insertions(+), 2 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index ea92d048..84117e72 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -1,4 +1,5 @@ from functools import partial +from operator import getitem import os from tempfile import TemporaryDirectory from time import perf_counter @@ -8,7 +9,6 @@ import torch import torch.distributed as dist from torch import fx -from torch.multiprocessing import Process from ..ir import Function @@ -21,9 +21,55 @@ def _allgather(x_i, world_size=None, dim=0): return x +def _identity(x): + return x + + +def _loss(x, y, N=None): + return torch.square(x - y) / N + + +def _loss_grad(x, y, N=None): + return 2 * (x - y) / N + + +def _matmul_grad(x, y, dz): + return (torch.matmul(dz, y.T), torch.matmul(x.T, dz)) + + +@torch.fx.wrap +def _recv(shape=None, device=None): + x = 
torch.zeros(shape) + # TODO pytorch rank = device_id - 1 + dist.recv(x, device - 1) + + +def _relu_grad(x, dy): + # TODO: fix + dx = torch.zeros(dy.shape) + dx[dy > 0] = 1 + return dx + + +@torch.fx.wrap +def _send(x, device=None): + print("_send input type", type(x)) + # TODO pytorch rank = device_id - 1 + dist.send(x, device - 1) + + _op_to_torch = { + "Add": torch.add, + "Concat": torch.cat, # TODO dim attribute? + "Identity": _identity, + "Loss": _loss, + "LossGrad": _loss_grad, "MatMul": torch.matmul, + "MatMulGrad": _matmul_grad, + "RecvP2P": _recv, "Relu": torch.relu, + "ReluGrad": _relu_grad, + "SendP2P": _send, "MPIAllgather": _allgather, } diff --git a/dist_ir/executor/rank_projector.py b/dist_ir/executor/rank_projector.py index a45b61d4..7717470b 100644 --- a/dist_ir/executor/rank_projector.py +++ b/dist_ir/executor/rank_projector.py @@ -7,6 +7,9 @@ from .absint import AbstractState, AbstractInterpreter +# TODO merge this with torch backend -- it breaks semantics to have P2P send/recv + + class ProjectorState(AbstractState): def __init__(self, function: Function, inputs: Sequence[Any]): AbstractState.__init__(self, function, inputs) @@ -46,13 +49,36 @@ def _mpi_allgather_projector(op: Op, state: ProjectorState): state.per_rank_fns[d].ops.append(new_op) +def _send_projector(op: Op, state: ProjectorState): + from_d = op.inputs[0].type.device + to_d = op.attributes["device"] + state.per_rank_fns[from_d].ops.append( + Op("SendP2P", inputs=op.inputs, attributes={"device": to_d.device_id}) + ) + state.per_rank_fns[to_d].ops.append( + Op( + "RecvP2P", + output_values=(op.outputs[0],), + attributes={"shape": op.inputs[0].type.shape, "device": from_d.device_id}, + ) + ) + + ProjectorRegister = { + ("Add", (Tensor, Tensor)): _identity_projector, + ("Concat", (Tensor, Tensor)): _identity_projector, + ("Identity", (Tensor,)): _identity_projector, + ("Loss", (Tensor, Tensor)): _identity_projector, + ("LossGrad", (Tensor, Tensor)): _identity_projector, ("MatMul", (Tensor, Tensor)): _identity_projector, - ("Relu", (Tensor,)): _identity_projector, + ("MatMulGrad", (Tensor, Tensor, Tensor)): _identity_projector, ("MPIAllgather", (Tensor,) * 2): _mpi_allgather_projector, ("MPIAllgather", (Tensor,) * 4): _mpi_allgather_projector, ("MPIAllgather", (Tensor,) * 8): _mpi_allgather_projector, ("MPIAllgather", (Tensor,) * 16): _mpi_allgather_projector, + ("Relu", (Tensor,)): _identity_projector, + ("ReluGrad", (Tensor, Tensor)): _identity_projector, + ("Send", (Tensor,)): _send_projector, } diff --git a/dist_ir/ir/op_register.py b/dist_ir/ir/op_register.py index 66096c58..99227f4f 100644 --- a/dist_ir/ir/op_register.py +++ b/dist_ir/ir/op_register.py @@ -57,6 +57,7 @@ class OpRegisterEntry: "MPIScatterToTupleType": OpRegisterEntry(num_inputs=1, num_outputs=1), "Mul": OpRegisterEntry(num_inputs=2, num_outputs=1), "Opt": OpRegisterEntry(num_inputs=2, num_outputs=1), + "RecvP2P": OpRegisterEntry(num_inputs=0, num_outputs=1), "ReduceAllL2": OpRegisterEntry(variadic_inputs=True, num_outputs=1), "ReduceSum": OpRegisterEntry(num_inputs=1, num_outputs=1), "ReduceSumTraining": OpRegisterEntry(num_inputs=2, num_outputs=1), @@ -64,7 +65,9 @@ class OpRegisterEntry: "ReluGrad": OpRegisterEntry(num_inputs=2, num_outputs=1), "Reshape": OpRegisterEntry(num_inputs=2, num_outputs=1), "Select": OpRegisterEntry(num_inputs=1, num_outputs=1), + # TODO call the combined one SendRecv? 
"Send": OpRegisterEntry(num_inputs=1, num_outputs=1), + "SendP2P": OpRegisterEntry(num_inputs=1, num_outputs=0), "SGDOptimizer": OpRegisterEntry(num_inputs=3, num_outputs=2), "Shape": OpRegisterEntry(num_inputs=1, num_outputs=1), # TODO allow optional inputs for things like slice From 49d385c24201dbad355dcd606255440700926b29 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 2 May 2021 14:56:45 +0100 Subject: [PATCH 21/61] Type inference: fix function name bug --- dist_ir/executor/type_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dist_ir/executor/type_inference.py b/dist_ir/executor/type_inference.py index 3b0f29c9..c6393d20 100644 --- a/dist_ir/executor/type_inference.py +++ b/dist_ir/executor/type_inference.py @@ -576,7 +576,7 @@ def semantics(op: Op, state: AbstractState): def _type_function(function: Function, type_map: Dict[Value, Type]) -> Function: """Create a typed version of function, using the types given in type map.""" - new_function = FunctionMaker() + new_function = FunctionMaker(name=function.name) # A Map from function's values to new_function's (typed) values: value_map: Dict[Value, Value] = {} From 2cce7c38ef26c9324008678e7122a6275cbd10a5 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 2 May 2021 14:56:57 +0100 Subject: [PATCH 22/61] Prettyprint: support for printing FunctionMakers --- dist_ir/ir/prettyprint.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/dist_ir/ir/prettyprint.py b/dist_ir/ir/prettyprint.py index 94dd671e..fe5053e9 100644 --- a/dist_ir/ir/prettyprint.py +++ b/dist_ir/ir/prettyprint.py @@ -47,7 +47,7 @@ ) from prettyprinter.utils import intersperse -from .function import Function +from .function import Function, FunctionMaker from .value import Value from .type import Type, Int32, Int64, Float, Tensor, TupleType from .device import Device @@ -120,6 +120,18 @@ def _(function: Function, ctx): ) +@register_pretty(FunctionMaker) +def _(function: FunctionMaker, ctx): + ops = _pprint_function_body(function, ctx) + return concat( + [ + annotate(Token.KEYWORD_CONSTANT, "function* "), + pretty_call(ctx, pp_fnname(function.name), *function.inputs), + nest(ctx.indent, concat([COLON, HARDLINE, interline(*ops)])), + ] + ) + + @register_pretty(Op) def _(op: Op, ctx): results = concat(_join(*(pretty_dispatch(r, ctx) for r in op.outputs))) From 91f5ecf12cdc0ae79ef4c605fe248c41dcc852ef Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 2 May 2021 21:08:39 +0100 Subject: [PATCH 23/61] Fix backend op implementations --- dist_ir/backend/torch.py | 23 ++++++++++++++++------- dist_ir/executor/rank_projector.py | 22 ++++++++++++++-------- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 84117e72..7bd7dd47 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -13,14 +13,24 @@ from ..ir import Function -# TODO at op creation time, enforce MPIAllgather ops attributes -def _allgather(x_i, world_size=None, dim=0): +# TODO kwargs of these functions are required, enforce this somewhere +def _allgather(x_i, dim=0): + world_size = dist.get_world_size() xs = [torch.zeros_like(x_i) for _ in range(world_size)] dist.all_gather(xs, x_i) x = torch.cat(xs, dim=dim) return x +def _allreduce(x): + dist.all_reduce(x) + return x + + +def _concat2(x, y, dim=None): + return torch.cat((x, y), dim=dim) + + def _identity(x): return x @@ -37,11 +47,11 @@ def _matmul_grad(x, y, dz): return (torch.matmul(dz, y.T), 
torch.matmul(x.T, dz)) -@torch.fx.wrap def _recv(shape=None, device=None): x = torch.zeros(shape) # TODO pytorch rank = device_id - 1 dist.recv(x, device - 1) + return x def _relu_grad(x, dy): @@ -51,16 +61,14 @@ def _relu_grad(x, dy): return dx -@torch.fx.wrap def _send(x, device=None): - print("_send input type", type(x)) # TODO pytorch rank = device_id - 1 dist.send(x, device - 1) _op_to_torch = { "Add": torch.add, - "Concat": torch.cat, # TODO dim attribute? + "Concat": _concat2, "Identity": _identity, "Loss": _loss, "LossGrad": _loss_grad, @@ -71,6 +79,7 @@ def _send(x, device=None): "ReluGrad": _relu_grad, "SendP2P": _send, "MPIAllgather": _allgather, + "MPIAllreduce": _allreduce, } @@ -141,10 +150,10 @@ def add_event(): runtimes = [ events[i].elapsed_time(events[i + 1]) / 1e3 for i in range(len(events) - 1) ] + torch.cuda.synchronize() else: runtimes = [events[i + 1] - events[i] for i in range(len(events) - 1)] - torch.cuda.synchronize() dist.destroy_process_group() return runtimes[num_warmup_steps:] diff --git a/dist_ir/executor/rank_projector.py b/dist_ir/executor/rank_projector.py index 7717470b..97f45551 100644 --- a/dist_ir/executor/rank_projector.py +++ b/dist_ir/executor/rank_projector.py @@ -1,8 +1,8 @@ from collections import defaultdict from dist_ir.executor.type_inference import TypePropRegister -from typing import Any, Dict, Sequence +from typing import Any, Dict, Sequence, Tuple -from ..ir import Function, FunctionMaker, Device, Op +from ..ir import Function, FunctionMaker, Device, Op, Value from ..ir.type import Type, Tensor from .absint import AbstractState, AbstractInterpreter @@ -34,14 +34,16 @@ def _identity_projector(op: Op, state: ProjectorState): # state.per_rank_fns[d].add_op(op.op_type, name=op.name, inputs=op.inputs, ) -def _mpi_allgather_projector(op: Op, state: ProjectorState): +def _collective_projector(op: Op, state: ProjectorState): + """Projects a collective op over D devices that has D inputs and D outputs, + one on each device.""" assert len(op.inputs) == len(op.outputs) for in_v, out_v in zip(op.inputs, op.outputs): assert in_v.type.device == out_v.type.device d = in_v.type.device new_op = Op( - "MPIAllgather", + op.op_type, inputs=(in_v,), output_values=(out_v,), attributes=op.attributes, @@ -72,10 +74,14 @@ def _send_projector(op: Op, state: ProjectorState): ("LossGrad", (Tensor, Tensor)): _identity_projector, ("MatMul", (Tensor, Tensor)): _identity_projector, ("MatMulGrad", (Tensor, Tensor, Tensor)): _identity_projector, - ("MPIAllgather", (Tensor,) * 2): _mpi_allgather_projector, - ("MPIAllgather", (Tensor,) * 4): _mpi_allgather_projector, - ("MPIAllgather", (Tensor,) * 8): _mpi_allgather_projector, - ("MPIAllgather", (Tensor,) * 16): _mpi_allgather_projector, + ("MPIAllgather", (Tensor,) * 2): _collective_projector, + ("MPIAllgather", (Tensor,) * 4): _collective_projector, + ("MPIAllgather", (Tensor,) * 8): _collective_projector, + ("MPIAllgather", (Tensor,) * 16): _collective_projector, + ("MPIAllreduce", (Tensor,) * 2): _collective_projector, + ("MPIAllreduce", (Tensor,) * 4): _collective_projector, + ("MPIAllreduce", (Tensor,) * 8): _collective_projector, + ("MPIAllreduce", (Tensor,) * 16): _collective_projector, ("Relu", (Tensor,)): _identity_projector, ("ReluGrad", (Tensor, Tensor)): _identity_projector, ("Send", (Tensor,)): _send_projector, From 172a5db51275294f9d7bb9f87f4dc484cc9a952a Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 2 May 2021 21:13:44 +0100 Subject: [PATCH 24/61] Convert per-rank fns to Modules inside each 
thread --- dist_ir/backend/torch.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 7bd7dd47..e0816d10 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -111,7 +111,7 @@ def function_to_module(fn: Function) -> torch.nn.Module: def run_process( - use_gpu, world_size, io_dir, num_warmup_steps, num_repetitions, rank, module + use_gpu, world_size, io_dir, num_warmup_steps, num_repetitions, rank, fn ): """The Python function on rank `rank` that runs module `module`.""" os.environ["MASTER_ADDR"] = "127.0.0.1" @@ -121,6 +121,9 @@ def run_process( per_rank_inputs = torch.load(os.path.join(io_dir.name, f"in.{rank}.pt")) + # Convert per-rank DistIR function to torch.nn.Module: + module = function_to_module(fn) + if use_gpu: # Move module and inputs to GPU module.to(rank) @@ -168,11 +171,6 @@ def run_multiprocesses( assert len(per_rank_functions) == len(per_rank_inputs) world_size = len(per_rank_functions) - # Convert per-rank DistIR functions to torch.nn.Modules: - per_rank_modules = list(map(function_to_module, per_rank_functions)) - for d, gm in enumerate(per_rank_modules): - print(f"{d}\n{gm.graph}\n") - # Save inputs for each per-rank function: io_dir = TemporaryDirectory() # print("run_multiprocess: saving I/O to:", io_dir.name) @@ -185,7 +183,7 @@ def run_multiprocesses( run_process, use_gpu, world_size, io_dir, num_warmup, num_repetitions ) with torch.multiprocessing.Pool(world_size) as p: - runtimes = p.starmap(per_rank_runner, enumerate(per_rank_modules)) + runtimes = p.starmap(per_rank_runner, enumerate(per_rank_functions)) # Load outputs: per_rank_outputs = [ From 2ed3cf7ca15f4939d7513af4e30023288b95bb8e Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 2 May 2021 21:24:02 +0100 Subject: [PATCH 25/61] Per-rank projector: remove types --- dist_ir/executor/rank_projector.py | 32 ++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/dist_ir/executor/rank_projector.py b/dist_ir/executor/rank_projector.py index 97f45551..b7161478 100644 --- a/dist_ir/executor/rank_projector.py +++ b/dist_ir/executor/rank_projector.py @@ -125,7 +125,9 @@ def semantics(op: Op, state: AbstractState): ) -def project(fn: Function, input_types: Sequence[Type]): +def project( + fn: Function, input_types: Sequence[Type], num_devices: int +) -> Tuple[Function]: """Project fn to a sequence of per-rank functions.""" state = ProjectorState(fn, input_types) @@ -135,4 +137,30 @@ def project(fn: Function, input_types: Sequence[Type]): state = Projector.interpret(fn, input_types, state=state) - return {d: state.per_rank_fns[d].finalize() for d in state.per_rank_fns} + # Erase all types in per_rank_fns: + # TODO do this during projection? 
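+ # (Each per-rank fn below is rebuilt as an untyped FunctionMaker: inputs and op outputs are recreated as Value(v.name, None), and the maker is named f"{fn.name}_{d.device_id - 1}" before finalize().)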
+ result_fns = [Function(fn.name, (), (), ()) for _ in range(num_devices)] + for d, per_rank_fn in state.per_rank_fns.items(): + value_map = {} + new_fn = FunctionMaker(name=f"{fn.name}_{d.device_id-1}") + for v in per_rank_fn.inputs: + value_map[v] = new_fn.add_input_value(v.name, None) + for op in per_rank_fn.ops: + new_inputs = tuple(value_map[v] for v in op.inputs) + for v in op.outputs: + value_map[v] = Value(v.name, None) + new_outputs = tuple(value_map[v] for v in op.outputs) + new_fn.ops.append( + Op( + op.op_type, + name=op.name, + inputs=new_inputs, + attributes=op.attributes, + subfunctions=op.subfunctions, + output_values=new_outputs, + ) + ) + new_fn.set_outputs(tuple(value_map[v] for v in per_rank_fn.outputs)) + result_fns[d.device_id - 1] = new_fn.finalize() + + return result_fns From 7eee9f91c502587c3a57879ee8cb9d51f2e96034 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 2 May 2021 21:25:50 +0100 Subject: [PATCH 26/61] Add some more tests --- test/test_pytorch_backend.py | 138 ++++++++++++++++++++++++++--------- 1 file changed, 105 insertions(+), 33 deletions(-) diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index 4e5378c1..d95a17c1 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -4,12 +4,15 @@ from dist_ir.backend.torch import run_multiprocesses from dist_ir.executor import SequentialExecutor +from dist_ir.executor.cost_model import CostModel from dist_ir.executor.rank_projector import project +from dist_ir.executor.simulator import Simulator from dist_ir.executor.type_inference import infer_types from dist_ir.ir import Device, FunctionMaker, cpprint, Value from dist_ir.ir.type import Float, Tensor -from dist_ir.transforms import mlp_dhp_transform -from examples.mlp import mlp_inference, mlp_inference_dp +from dist_ir.ir.topology import Topology +from examples.grid_search import add_devices_to_topology, gen_configurations, mlp_dist +from examples.mlp import mlp, mlp_inference_dp def create_owt_model(num_devices, num_layers): @@ -42,7 +45,7 @@ def create_owt_model(num_devices, num_layers): "MPIAllgather", inputs=hs, output_names=as_names, - attributes={"dim": 0, "world_size": num_devices}, + attributes={"dim": 0}, ) # Model parallel fully-connected layers: (again, MatMuls for now) @@ -65,7 +68,7 @@ def create_owt_model(num_devices, num_layers): "MPIAllgather", inputs=h_is, output_names=out_names, - attributes={"dim": 1, "world_size": num_devices}, + attributes={"dim": 1}, ) fn.set_outputs(hs) @@ -128,33 +131,104 @@ def test_owt(num_devices, num_layers): ys = np.split(y, num_devices, axis=1) assert all(np.allclose(y, o) for y, o in zip(ys, output_arrays)) - # Per-rank projection: - per_rank_fns = project(fn, tuple(v.type for v in input_vals)) - for d, f_d in per_rank_fns.items(): - print() - print(d) - cpprint(f_d) + # Run per-rank modules using PyTorch backend: + per_rank_outputs, _ = run_pytorch( + num_devices, fn, [torch.tensor(a) for a in input_arrays] + ) + + # Check outputs: + assert all(np.allclose(y[0], o) for y, o in zip(per_rank_outputs, output_arrays)) + + +def test_mlp_grid_search(): + batch_size = 64 + hidden_dim = 64 + num_layers = 2 + world_size = 2 + + topology = Topology() + d0 = topology.add_device("gpu") + add_devices_to_topology(topology, world_size) + simulator = Simulator(CostModel(topology)) + seq_executor = SequentialExecutor("numpy") + + seq_mlp = mlp(batch_size, hidden_dim, hidden_dim, hidden_dim, num_layers, d0) + seq_mlp = infer_types(seq_mlp, seq_mlp.inputs) + configs = list( + 
gen_configurations([hidden_dim], [world_size], [num_layers], [batch_size]) + ) + dist_mlp_fns = [ + mlp_dist(seq_mlp, d, h, p, m, topology) for (_, _, _, d, h, p, m) in configs + ] + print(len(dist_mlp_fns)) + + # Create random input data + input_data = tuple( + np.random.randn(*v.type.shape).astype(np.float32) for v in seq_mlp.inputs + ) + + for init_fn, fn in dist_mlp_fns: + # Simulate + simulation = simulator.interpret(fn, (v.type for v in fn.inputs)) + simulated_time = max([simulation.timestamps[d] for d in simulation.timestamps]) + + # Reference-execute init_fn to get inputs for fn + dist_input_data = seq_executor.compute(init_fn, input_data) + dist_input_data = tuple(torch.tensor(t) for t in dist_input_data) + assert all( + t.shape == v.type.shape for (t, v) in zip(dist_input_data, fn.inputs) + ) + + # Measure actual execution time + # TODO check outputs match? + _, runtimes = run_pytorch(world_size, fn, dist_input_data) + actual_time = max(np.median(times) for times in runtimes) - # Make inputs for each per-rank function: + print(fn.name, simulated_time, actual_time) + + +def run_pytorch(num_devices, fn, inputs): + """Project `fn` and run on `inputs` using PyTorch backend.""" + # TODO add to backend.torch? + # TODO check that fn uses devices [0...num_devices) + per_rank_fns = project(fn, tuple(v.type for v in fn.inputs), num_devices) per_rank_inputs = [[] for _ in range(num_devices)] - for v, a in zip(fn.inputs, input_arrays): - per_rank_inputs[v.type.device.device_id - 1].append(torch.tensor(a)) + for v, a in zip(fn.inputs, inputs): + per_rank_inputs[v.type.device.device_id - 1].append(a) + return run_multiprocesses(per_rank_fns, per_rank_inputs) - # Run per-rank modules using PyTorch backend: - per_rank_outputs = run_multiprocesses(per_rank_fns.values(), per_rank_inputs) - # Check outputs: - assert all(np.allclose(y, o) for y, o in zip(per_rank_outputs, output_arrays)) +def test_empty_device(): + d1 = Device(1, "gpu") + d2 = Device(2, "gpu") + fn = FunctionMaker() + x = fn.add_input_value("x", Tensor(Float(), (4, 4), d1)) + y = fn.add_op("MatMul", inputs=(x, x)) + fn.set_outputs((y,)) + fn = fn.finalize() + cpprint(fn) + x = torch.randn(4, 4) + inputs = (x,) + outputs, _ = run_pytorch(2, fn, inputs) + print(outputs) + assert torch.allclose(torch.matmul(x, x), outputs[0][0]) -# TODO get DHP transform to work on mlp_inference and try running on backend -# def test_mlp_grid_search(): -# devices = [Device(d, "gpu") for d in range(3)] -# -# f = mlp_inference(4, 6, 6, 6, 4, devices[0]) -# f = infer_types(f, f.inputs) -# -# f_dist = mlp_dhp_transform(f, 2, 1, 1, devices, 1) + +def test_send_recv(): + d1 = Device(1, "gpu") + d2 = Device(2, "gpu") + fn = FunctionMaker() + x = fn.add_input_value("x", Tensor(Float(), (4, 4), d1)) + y = fn.add_op("Send", inputs=(x,), attributes={"device": d2}) + fn.set_outputs((x, y)) + fn = fn.finalize() + cpprint(fn) + + x = torch.randn(4, 4) + inputs = (x,) + outputs, _ = run_pytorch(2, fn, inputs) + assert torch.allclose(x, outputs[1][0]) def test_dp_mlp(): @@ -189,16 +263,12 @@ def new_inputs(): y = torch.relu(y) # Project and run on backend: - per_rank_fns = project(fn, tuple(v.type for v in fn.inputs)) - per_rank_inputs = [[] for _ in range(num_devices)] - for v, a in zip(fn.inputs, convert_inputs_dp(weights, x)): - per_rank_inputs[v.type.device.device_id - 1].append(a) - per_rank_outputs, runtimes = run_multiprocesses( - per_rank_fns.values(), per_rank_inputs + per_rank_outputs, runtimes = run_pytorch( + num_devices, fn, convert_inputs_dp(weights, x) 
) # Check outputs: - assert torch.allclose(y, torch.cat(per_rank_outputs, 0)) + assert torch.allclose(y, torch.cat([o[0] for o in per_rank_outputs], 0)) return runtimes @@ -206,4 +276,6 @@ def new_inputs(): if __name__ == "__main__": # test_owt(2, 4) # test_dp_mlp() - pass + # test_send_recv() + # test_empty_device() + test_mlp_grid_search() From 78da37df405592301ba675c6c7e93bbe64cf66f1 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 2 May 2021 21:33:17 +0100 Subject: [PATCH 27/61] Default number of repetitions = 1 --- dist_ir/backend/torch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index e0816d10..a6e31a0c 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -165,8 +165,8 @@ def run_multiprocesses( per_rank_functions: Tuple[Function], per_rank_inputs: Tuple[Any], use_gpu=False, - num_repetitions=100, - num_warmup=10, + num_repetitions=1, + num_warmup=0, ): assert len(per_rank_functions) == len(per_rank_inputs) world_size = len(per_rank_functions) From 56e05dc71407f6f863e5e90404b90c118ad136db Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 2 May 2021 21:34:02 +0100 Subject: [PATCH 28/61] DHP transform: return separate init_fn and transformed fn --- dist_ir/transforms/mlp_dhp_transform.py | 31 ++++++++++++++++++------- examples/grid_search.py | 27 ++++++++------------- test/test_mlp_dhp_transform.py | 11 +++++---- 3 files changed, 38 insertions(+), 31 deletions(-) diff --git a/dist_ir/transforms/mlp_dhp_transform.py b/dist_ir/transforms/mlp_dhp_transform.py index 85a82b68..ffe8df53 100644 --- a/dist_ir/transforms/mlp_dhp_transform.py +++ b/dist_ir/transforms/mlp_dhp_transform.py @@ -309,7 +309,8 @@ def mlp_dhp_transform( function, dp_degree, hp_degree, pp_degree, devices, num_microbatches ): """Automatically distributes an MLP function using D/H/P hybrid parallelism.""" - transformed_function = FunctionMaker(name=function.name) + fn_name = f"{function.name}_{dp_degree}_{hp_degree}_{pp_degree}_{num_microbatches}" + transformed_function = FunctionMaker(name=fn_name) device_tree = _get_device_tree(dp_degree, hp_degree, pp_degree, devices) device_tree_root = tuple(device_tree.keys())[0] dp_devices = tuple(sorted(device_tree[device_tree_root].keys())) @@ -324,22 +325,28 @@ def mlp_dhp_transform( ) ) - # Add inputs to the transformed function. + # An init function that moves weights/inputs to correct devices + init_function = FunctionMaker(name=fn_name + "_init") transformed_inputs = {} for inp in function.inputs: - v = transformed_function.add_input_value(inp.name, inp.type) + v = init_function.add_input_value(inp.name, inp.type) transformed_inputs[inp] = v # Partition inputs across each parallelism dimension. 
- dp_inputs = _partition_inputs_dp(transformed_function, device_tree) - hp_inputs = _partition_inputs_hp(transformed_function, device_tree, dp_inputs) + dp_inputs = _partition_inputs_dp(init_function, device_tree) + hp_inputs = _partition_inputs_hp(init_function, device_tree, dp_inputs) pp_inputs = _partition_inputs_pp( - transformed_function, + init_function, device_tree, dp_inputs, hp_inputs, num_microbatches, ) + init_function = init_function.finalize() + + # Inputs of transformed_function are outputs of init_function + for v in init_function.outputs: + transformed_function.inputs.append(v) dp_outputs = defaultdict(list) for i, dp_device in enumerate(device_tree[device_tree_root]): @@ -670,7 +677,9 @@ def mlp_dhp_transform( hp_group, transformed_function, output_names=[ - f"{output.name}_dp_all_hp_{hp_device_group_str}_pp_all" + # TODO how to get device? + f"{output.name}_dp_all_hp_{hp_device_group_str}_pp_all_{j}" + for j in range(len(hp_group)) ], ) else: @@ -679,6 +688,10 @@ def mlp_dhp_transform( hp_group, transformed_function, dim=0, - output_names=[f"{output.name}_dp_all_hp_all_pp_all"], + output_names=[ + f"{output.name}_dp_all_hp_all_pp_all_{j}" + for j in range(len(hp_group)) + ], ) - return transformed_function.finalize() + # TODO transformed_function should output loss/grads? + return init_function, transformed_function.finalize() diff --git a/examples/grid_search.py b/examples/grid_search.py index db94acc3..34130280 100644 --- a/examples/grid_search.py +++ b/examples/grid_search.py @@ -19,7 +19,7 @@ mlp_dhp_transform, PipeDreamScheduler, ) -from mlp import mlp +from .mlp import mlp DGX_BANDWIDTH_GBPS = 200 @@ -31,7 +31,6 @@ def add_devices_to_topology(topology, num_devices): for i in range(0, len(devices)): for j in range(i + 1, len(devices)): topology.set_bandwidth(devices[i], devices[j], DGX_BANDWIDTH_GBPS) - return topology def get_all_degrees(n): @@ -103,31 +102,25 @@ def run_experiment(config): def mlp_dist( - batch_size, - input_dim, - num_hidden_layers, + mlp_fn, dp_degree, hp_degree, pp_degree, num_microbatches, - devices, + topology, ): - function = mlp(batch_size, input_dim, input_dim, input_dim, num_hidden_layers, None) - function = infer_types(function, function.inputs) - world_size = dp_degree * hp_degree * pp_degree - - transformed_function = mlp_dhp_transform( - function, + init_function, transformed_function = mlp_dhp_transform( + mlp_fn, dp_degree, hp_degree, pp_degree, - devices, + topology.devices, num_microbatches, ) - transformed_function = infer_types( - transformed_function, transformed_function.inputs - ) - return transformed_function + init_function = infer_types(init_function, init_function.inputs) + # init_function.outputs = transformed_function.inputs, so get types from there: + transformed_function = infer_types(transformed_function, init_function.outputs) + return init_function, transformed_function def gen_configurations(hidden_dims, cluster_sizes, all_num_layers, all_batch_sizes): diff --git a/test/test_mlp_dhp_transform.py b/test/test_mlp_dhp_transform.py index a3de44a5..7b583992 100644 --- a/test/test_mlp_dhp_transform.py +++ b/test/test_mlp_dhp_transform.py @@ -138,7 +138,7 @@ def _test_helper( world_size = dp_degree * hp_degree * pp_degree add_devices_to_topology(topology, world_size) - transformed_function = mlp_dhp_transform( + init_function, transformed_function = mlp_dhp_transform( function, dp_degree, hp_degree, @@ -146,14 +146,15 @@ def _test_helper( topology.devices, num_microbatches, ) - transformed_function = infer_types( - 
transformed_function, transformed_function.inputs - ) + init_function = infer_types(init_function, init_function.inputs) + # init_function.outputs = transformed_function.inputs, so get types from there: + transformed_function = infer_types(transformed_function, init_function.outputs) input_data = [np.random.normal(size=inp.type.shape) for inp in function.inputs] ex = SequentialExecutor("numpy") outputs = ex.compute(function, input_data) - transformed_outputs = ex.compute(transformed_function, input_data) + dist_input_data = ex.compute(init_function, input_data) + transformed_outputs = ex.compute(transformed_function, dist_input_data) if hp_degree > 1: _verify_hp( From 78bc1e9b0827a366d941eec41f60bcac487a6716 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 2 May 2021 21:53:56 +0100 Subject: [PATCH 29/61] Run pytest with root dir included in PYTHONPATH --- .github/workflows/tests.yml | 2 +- test/test_mlp_dhp_transform.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 45a6f3ca..c838a8e9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -52,4 +52,4 @@ jobs: run: python setup.py install - name: Test with pytest - run: pytest + run: python -m pytest diff --git a/test/test_mlp_dhp_transform.py b/test/test_mlp_dhp_transform.py index 7b583992..ad03bbe9 100644 --- a/test/test_mlp_dhp_transform.py +++ b/test/test_mlp_dhp_transform.py @@ -124,7 +124,7 @@ def _verify_hp(function, transformed_function, outputs, transformed_outputs, dp= def _test_helper( batch_size=BATCH_SIZE, - num_hidden_layers=8, + num_hidden_layers=2, input_dim=INPUT_DIM, dp_degree=1, hp_degree=1, @@ -149,6 +149,8 @@ def _test_helper( init_function = infer_types(init_function, init_function.inputs) # init_function.outputs = transformed_function.inputs, so get types from there: transformed_function = infer_types(transformed_function, init_function.outputs) + cpprint(function) + cpprint(transformed_function) input_data = [np.random.normal(size=inp.type.shape) for inp in function.inputs] ex = SequentialExecutor("numpy") @@ -190,3 +192,7 @@ def test_hp_pp(): def test_dp_hp_pp(): _test_helper(dp_degree=2, hp_degree=2, pp_degree=2, num_microbatches=2) + + +if __name__ == "__main__": + test_dp_only() From 8a288b68bb8bc715d7a457efe705830c63f80e49 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 2 May 2021 21:58:54 +0100 Subject: [PATCH 30/61] Grid search: remove unused imports --- examples/grid_search.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/examples/grid_search.py b/examples/grid_search.py index 34130280..ec82f231 100644 --- a/examples/grid_search.py +++ b/examples/grid_search.py @@ -1,24 +1,12 @@ -import argparse -from collections import defaultdict, OrderedDict import csv from itertools import product -import logging import numpy as np -import time -import matplotlib as mpl -import matplotlib.pyplot as plt from multiprocessing import Pool -import dist_ir -from dist_ir.importer import import_from_onnx, parse_tensor_from_file -from dist_ir.ir import FunctionMaker, cpprint, pformat, Device, Topology, Value -from dist_ir.executor import infer_types, SequentialExecutor, Simulator +from dist_ir.ir import Topology +from dist_ir.executor import infer_types, Simulator from dist_ir.executor.cost_model import CostModel -from dist_ir.ir.type import Bool, Float, Int64, Tensor -from dist_ir.transforms import ( - mlp_dhp_transform, - PipeDreamScheduler, -) +from 
dist_ir.transforms import mlp_dhp_transform from .mlp import mlp DGX_BANDWIDTH_GBPS = 200 From 4b2b3014012b615d0add9620e3a83a7c3e775db9 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 2 May 2021 22:03:08 +0100 Subject: [PATCH 31/61] Revert unintended changes --- test/test_mlp_dhp_transform.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/test_mlp_dhp_transform.py b/test/test_mlp_dhp_transform.py index ad03bbe9..4ff1a219 100644 --- a/test/test_mlp_dhp_transform.py +++ b/test/test_mlp_dhp_transform.py @@ -124,7 +124,7 @@ def _verify_hp(function, transformed_function, outputs, transformed_outputs, dp= def _test_helper( batch_size=BATCH_SIZE, - num_hidden_layers=2, + num_hidden_layers=8, input_dim=INPUT_DIM, dp_degree=1, hp_degree=1, @@ -149,8 +149,6 @@ def _test_helper( init_function = infer_types(init_function, init_function.inputs) # init_function.outputs = transformed_function.inputs, so get types from there: transformed_function = infer_types(transformed_function, init_function.outputs) - cpprint(function) - cpprint(transformed_function) input_data = [np.random.normal(size=inp.type.shape) for inp in function.inputs] ex = SequentialExecutor("numpy") From 75ec41a0953d6d03002e6e090b0964e5484cfe47 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Sun, 2 May 2021 22:07:08 +0100 Subject: [PATCH 32/61] Move run_pytorch to backend.torch --- dist_ir/backend/torch.py | 13 +++++++++++++ test/test_pytorch_backend.py | 16 +++------------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index a6e31a0c..c3ecca92 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -10,6 +10,7 @@ import torch.distributed as dist from torch import fx +from ..executor.rank_projector import project from ..ir import Function @@ -192,3 +193,15 @@ def run_multiprocesses( io_dir.cleanup() return per_rank_outputs, runtimes + + +def run_pytorch(num_devices, fn, inputs): + """Project `fn` and run on `inputs` over `num_devices` devices using the + PyTorch backend. + """ + # TODO check that fn uses devices [0...num_devices) + per_rank_fns = project(fn, tuple(v.type for v in fn.inputs), num_devices) + per_rank_inputs = [[] for _ in range(num_devices)] + for v, a in zip(fn.inputs, inputs): + per_rank_inputs[v.type.device.device_id - 1].append(a) + return run_multiprocesses(per_rank_fns, per_rank_inputs) diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index d95a17c1..d355b838 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -2,15 +2,16 @@ import pytest import torch -from dist_ir.backend.torch import run_multiprocesses +from dist_ir.backend.torch import run_pytorch from dist_ir.executor import SequentialExecutor from dist_ir.executor.cost_model import CostModel -from dist_ir.executor.rank_projector import project from dist_ir.executor.simulator import Simulator from dist_ir.executor.type_inference import infer_types from dist_ir.ir import Device, FunctionMaker, cpprint, Value from dist_ir.ir.type import Float, Tensor from dist_ir.ir.topology import Topology + +# TODO make examples submodule of dist_ir? 
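+# (examples/ is importable here only when the repository root is on sys.path, e.g. when the tests are run with python -m pytest from the root.)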
from examples.grid_search import add_devices_to_topology, gen_configurations, mlp_dist from examples.mlp import mlp, mlp_inference_dp @@ -187,17 +188,6 @@ def test_mlp_grid_search(): print(fn.name, simulated_time, actual_time) -def run_pytorch(num_devices, fn, inputs): - """Project `fn` and run on `inputs` using PyTorch backend.""" - # TODO add to backend.torch? - # TODO check that fn uses devices [0...num_devices) - per_rank_fns = project(fn, tuple(v.type for v in fn.inputs), num_devices) - per_rank_inputs = [[] for _ in range(num_devices)] - for v, a in zip(fn.inputs, inputs): - per_rank_inputs[v.type.device.device_id - 1].append(a) - return run_multiprocesses(per_rank_fns, per_rank_inputs) - - def test_empty_device(): d1 = Device(1, "gpu") d2 = Device(2, "gpu") From 6dbf57cdd4cde3d3f07a21c03286d139c4ff571a Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Wed, 5 May 2021 16:41:09 +0100 Subject: [PATCH 33/61] Interpret Function instead of creating fx.Graph --- dist_ir/backend/torch.py | 51 +++++++++++++++++++++++++++++++----- test/test_pytorch_backend.py | 2 ++ 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index c3ecca92..07e57753 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -1,11 +1,11 @@ from functools import partial +import logging from operator import getitem import os from tempfile import TemporaryDirectory from time import perf_counter -from typing import Any, Tuple +from typing import Any, List, Tuple -import numpy as np import torch import torch.distributed as dist from torch import fx @@ -111,6 +111,37 @@ def function_to_module(fn: Function) -> torch.nn.Module: return fx.GraphModule({}, g) +def run_function(rank, fn: Function, inputs: List[Any]): + value_map = {} + + # Add inputs to value_map + for v, x in zip(fn.inputs, inputs): + value_map[v] = x + assert len(fn.inputs) == len(inputs) + + # Run ops + for op in fn.ops: + first_output = ( + op.outputs[0].name + if op.outputs is not None and len(op.outputs) > 0 + else "None" + ) + logging.info(f"{rank}: {first_output} {op.op_type}") + inputs = tuple(value_map[v] for v in op.inputs) + kwargs = {} if op.attributes is None else {**op.attributes} + output = _op_to_torch[op.op_type](*inputs, **kwargs) + if len(op.outputs) > 1: + assert isinstance(output, tuple) + for i, v in enumerate(op.outputs): + value_map[v] = output[i] + elif len(op.outputs) == 1: + value_map[op.outputs[0]] = output + logging.info(f"{rank}: {first_output} {op.op_type}") + + # Return outputs + return tuple(value_map[v] for v in fn.outputs) + + def run_process( use_gpu, world_size, io_dir, num_warmup_steps, num_repetitions, rank, fn ): @@ -122,12 +153,13 @@ def run_process( per_rank_inputs = torch.load(os.path.join(io_dir.name, f"in.{rank}.pt")) - # Convert per-rank DistIR function to torch.nn.Module: - module = function_to_module(fn) + # # Convert per-rank DistIR function to torch.nn.Module: + # module = function_to_module(fn) if use_gpu: # Move module and inputs to GPU - module.to(rank) + # TODO how to move interpreted non-module code to GPU? 
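+ # (Not just the inputs below: ops that allocate fresh tensors, e.g. _recv and _relu_grad, would also need to place them on cuda(rank).)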
+ # module.to(rank) for t in per_rank_inputs: t.to(rank) @@ -143,7 +175,8 @@ def add_event(): # Time a bunch of executions, then execute once for output values add_event() for _ in range(num_warmup_steps + num_repetitions): - res = module(*per_rank_inputs) + # res = module(*per_rank_inputs) + res = run_function(rank, fn, per_rank_inputs) if world_size > 1: torch.distributed.barrier() add_event() @@ -199,8 +232,12 @@ def run_pytorch(num_devices, fn, inputs): """Project `fn` and run on `inputs` over `num_devices` devices using the PyTorch backend. """ - # TODO check that fn uses devices [0...num_devices) + # TODO check that fn uses devices [0...num_devices), + # or run through and find max device used per_rank_fns = project(fn, tuple(v.type for v in fn.inputs), num_devices) + # from ..ir import cpprint + # for per_rank_fn in per_rank_fns: + # cpprint(per_rank_fn) per_rank_inputs = [[] for _ in range(num_devices)] for v, a in zip(fn.inputs, inputs): per_rank_inputs[v.type.device.device_id - 1].append(a) diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index d355b838..0720ee9d 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -268,4 +268,6 @@ def new_inputs(): # test_dp_mlp() # test_send_recv() # test_empty_device() + # import logging + # logging.basicConfig(level=logging.INFO) test_mlp_grid_search() From 4a773836cdb6d350937e0e473f40d85467d27767 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Wed, 5 May 2021 16:59:54 +0100 Subject: [PATCH 34/61] Prettyprint attributes as Python-style kwargs --- dist_ir/ir/prettyprint.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/dist_ir/ir/prettyprint.py b/dist_ir/ir/prettyprint.py index fe5053e9..007adfee 100644 --- a/dist_ir/ir/prettyprint.py +++ b/dist_ir/ir/prettyprint.py @@ -134,8 +134,18 @@ def _(function: FunctionMaker, ctx): @register_pretty(Op) def _(op: Op, ctx): - results = concat(_join(*(pretty_dispatch(r, ctx) for r in op.outputs))) - args = concat(_join(*(v.name for v in op.inputs))) + attributes = () + if op.attributes is not None: + attributes = ( + concat((key, ASSIGN_OP, pretty_dispatch(value, ctx))) + for key, value in op.attributes.items() + ) + args = concat(_join(*(v.name for v in op.inputs), *attributes)) + + if len(op.outputs) == 0: + results = "_" + else: + results = concat(_join(*(pretty_dispatch(r, ctx) for r in op.outputs))) if op.op_type == "Pmap": lambda_body = _pprint_function_body(op.subfunctions[0], ctx) From e8735c36842ca5ae60c3de63267601a2f91348d9 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Wed, 5 May 2021 17:34:47 +0100 Subject: [PATCH 35/61] Use broadcast with pairwise groups for send/recv on GPUs --- dist_ir/backend/torch.py | 57 ++++++++++++++++++++++++++++-------- test/test_pytorch_backend.py | 1 + 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 07e57753..2ade33dc 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -1,4 +1,5 @@ from functools import partial +from itertools import combinations import logging from operator import getitem import os @@ -13,6 +14,22 @@ from ..executor.rank_projector import project from ..ir import Function +_use_gpu = False +_groups = None + + +def _init_p2p_groups(): + """Since torch.distributed's NCCL backed doesn't support P2P communication, + we create a group for each pair of ranks and use broadcasts to emulate P2P + send/recv. This method initializes the groups. 
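+ For example, a send from rank 0 to rank 2 becomes dist.broadcast(x, 0, group=_groups[0, 2]) on both ranks, with rank 2 treating the broadcast as its recv.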
+ """ + global _use_gpu, _groups + if _use_gpu: + world_size = dist.get_world_size() + _groups = {} + for i, j in combinations(range(world_size), 2): + _groups[i, j] = dist.new_group([i, j]) + # TODO kwargs of these functions are required, enforce this somewhere def _allgather(x_i, dim=0): @@ -51,7 +68,13 @@ def _matmul_grad(x, y, dz): def _recv(shape=None, device=None): x = torch.zeros(shape) # TODO pytorch rank = device_id - 1 - dist.recv(x, device - 1) + if _use_gpu: + src_rank = device - 1 + dst_rank = dist.get_rank() + group = _groups[tuple(sorted((src_rank, dst_rank)))] + dist.broadcast(x, src_rank, group=group) + else: + dist.recv(x, device - 1) return x @@ -64,7 +87,13 @@ def _relu_grad(x, dy): def _send(x, device=None): # TODO pytorch rank = device_id - 1 - dist.send(x, device - 1) + if _use_gpu: + src_rank = dist.get_rank() + dst_rank = device - 1 + group = _groups[tuple(sorted((src_rank, dst_rank)))] + dist.broadcast(x, src_rank, group=group) + else: + dist.send(x, device - 1) _op_to_torch = { @@ -142,21 +171,20 @@ def run_function(rank, fn: Function, inputs: List[Any]): return tuple(value_map[v] for v in fn.outputs) -def run_process( - use_gpu, world_size, io_dir, num_warmup_steps, num_repetitions, rank, fn -): +def run_process(world_size, io_dir, num_warmup_steps, num_repetitions, rank, fn): """The Python function on rank `rank` that runs module `module`.""" os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "29500" - backend = "nccl" if use_gpu else "gloo" + backend = "nccl" if _use_gpu else "gloo" dist.init_process_group(backend, rank=rank, world_size=world_size) + _init_p2p_groups() per_rank_inputs = torch.load(os.path.join(io_dir.name, f"in.{rank}.pt")) # # Convert per-rank DistIR function to torch.nn.Module: # module = function_to_module(fn) - if use_gpu: + if _use_gpu: # Move module and inputs to GPU # TODO how to move interpreted non-module code to GPU? # module.to(rank) @@ -166,7 +194,7 @@ def run_process( events = [] def add_event(): - if use_gpu: + if _use_gpu: events.append(torch.cuda.Event(enable_timing=True)) events[-1].record() else: @@ -183,7 +211,7 @@ def add_event(): torch.save(res, os.path.join(io_dir.name, f"out.{rank}.pt")) - if use_gpu: + if _use_gpu: runtimes = [ events[i].elapsed_time(events[i + 1]) / 1e3 for i in range(len(events) - 1) ] @@ -198,7 +226,6 @@ def add_event(): def run_multiprocesses( per_rank_functions: Tuple[Function], per_rank_inputs: Tuple[Any], - use_gpu=False, num_repetitions=1, num_warmup=0, ): @@ -214,7 +241,7 @@ def run_multiprocesses( global run_process per_rank_runner = partial( - run_process, use_gpu, world_size, io_dir, num_warmup, num_repetitions + run_process, world_size, io_dir, num_warmup, num_repetitions ) with torch.multiprocessing.Pool(world_size) as p: runtimes = p.starmap(per_rank_runner, enumerate(per_rank_functions)) @@ -228,17 +255,23 @@ def run_multiprocesses( return per_rank_outputs, runtimes -def run_pytorch(num_devices, fn, inputs): +def run_pytorch(num_devices, fn, inputs, use_gpu=False): """Project `fn` and run on `inputs` over `num_devices` devices using the PyTorch backend. 
""" # TODO check that fn uses devices [0...num_devices), # or run through and find max device used + + global _use_gpu + _use_gpu = use_gpu + per_rank_fns = project(fn, tuple(v.type for v in fn.inputs), num_devices) # from ..ir import cpprint # for per_rank_fn in per_rank_fns: # cpprint(per_rank_fn) + per_rank_inputs = [[] for _ in range(num_devices)] for v, a in zip(fn.inputs, inputs): per_rank_inputs[v.type.device.device_id - 1].append(a) + return run_multiprocesses(per_rank_fns, per_rank_inputs) diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index 0720ee9d..4b7f7af0 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -268,6 +268,7 @@ def new_inputs(): # test_dp_mlp() # test_send_recv() # test_empty_device() + # import logging # logging.basicConfig(level=logging.INFO) test_mlp_grid_search() From fd2e7b125fe56c09ae32da72df846a4ef49a0f92 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Wed, 5 May 2021 17:07:51 +0000 Subject: [PATCH 36/61] Move new tensors to GPU in each op, outputs back to CPU --- dist_ir/backend/torch.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 2ade33dc..4d4f5cd2 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -35,6 +35,9 @@ def _init_p2p_groups(): def _allgather(x_i, dim=0): world_size = dist.get_world_size() xs = [torch.zeros_like(x_i) for _ in range(world_size)] + if _use_gpu: + xs = [x.cuda(dist.get_rank()) for x in xs] + dist.all_gather(xs, x_i) x = torch.cat(xs, dim=dim) return x @@ -69,6 +72,7 @@ def _recv(shape=None, device=None): x = torch.zeros(shape) # TODO pytorch rank = device_id - 1 if _use_gpu: + x = x.cuda(dist.get_rank()) src_rank = device - 1 dst_rank = dist.get_rank() group = _groups[tuple(sorted((src_rank, dst_rank)))] @@ -81,6 +85,8 @@ def _recv(shape=None, device=None): def _relu_grad(x, dy): # TODO: fix dx = torch.zeros(dy.shape) + if _use_gpu: + dx = dx.cuda(dist.get_rank()) dx[dy > 0] = 1 return dx @@ -187,9 +193,8 @@ def run_process(world_size, io_dir, num_warmup_steps, num_repetitions, rank, fn) if _use_gpu: # Move module and inputs to GPU # TODO how to move interpreted non-module code to GPU? 
- # module.to(rank) - for t in per_rank_inputs: - t.to(rank) + # module = module.cuda(rank) + per_rank_inputs = [t.cuda(rank) for t in per_rank_inputs] events = [] @@ -209,6 +214,10 @@ def add_event(): torch.distributed.barrier() add_event() + if _use_gpu: + # Move outputs back to cpu + res = [t.cpu() for t in res] + torch.save(res, os.path.join(io_dir.name, f"out.{rank}.pt")) if _use_gpu: From 31031cc73546fb74e14e49080e6f93814d29ab43 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Thu, 6 May 2021 14:00:55 +0100 Subject: [PATCH 37/61] Add a mock multiprocess backend for debugging --- dist_ir/backend/torch.py | 79 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 72 insertions(+), 7 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 4d4f5cd2..c47c7c96 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -118,6 +118,40 @@ def _send(x, device=None): "MPIAllreduce": _allreduce, } +# Some mock communication ops that return zero tensors of appropriate shape +# to be used in the sequential runner for debugging + +_mock_world_size = None + + +def _mock_allgather(x_i, dim=0): + xs = [torch.zeros_like(x_i) for _ in range(_mock_world_size)] + x = torch.cat(xs, dim=dim) + return x + + +def _mock_allreduce(x): + return x + + +def _mock_recv(shape=None, device=None): + x = torch.zeros(shape) + return x + + +def _mock_send(x, device=None): + pass + + +_mock_comm_ops = { + "RecvP2P": _mock_recv, + "SendP2P": _mock_send, + "MPIAllgather": _mock_allgather, + "MPIAllreduce": _mock_allreduce, +} + +_mock_op_to_torch = {**_op_to_torch, **_mock_comm_ops} + def function_to_module(fn: Function) -> torch.nn.Module: g = fx.Graph() @@ -146,7 +180,8 @@ def function_to_module(fn: Function) -> torch.nn.Module: return fx.GraphModule({}, g) -def run_function(rank, fn: Function, inputs: List[Any]): +def run_function(rank, fn: Function, inputs: List[Any], debug_mock=False): + op_to_torch = _mock_op_to_torch if debug_mock else _op_to_torch value_map = {} # Add inputs to value_map @@ -164,7 +199,7 @@ def run_function(rank, fn: Function, inputs: List[Any]): logging.info(f"{rank}: {first_output} {op.op_type}") inputs = tuple(value_map[v] for v in op.inputs) kwargs = {} if op.attributes is None else {**op.attributes} - output = _op_to_torch[op.op_type](*inputs, **kwargs) + output = op_to_torch[op.op_type](*inputs, **kwargs) if len(op.outputs) > 1: assert isinstance(output, tuple) for i, v in enumerate(op.outputs): @@ -232,6 +267,29 @@ def add_event(): return runtimes[num_warmup_steps:] +def run_mock_multiprocess( + per_rank_functions: Tuple[Function], + per_rank_inputs: Tuple[Any], + num_repetitions=1, + num_warmup=0, +): + assert len(per_rank_functions) == len(per_rank_inputs) + global _mock_world_size + _mock_world_size = len(per_rank_functions) + + per_rank_outputs = [ + run_function(rank, fn, inputs, debug_mock=True) + for rank, fn, inputs in zip( + range(_mock_world_size), per_rank_functions, per_rank_inputs + ) + ] + mock_runtimes = [ + [0.0 for _ in range(num_warmup + num_repetitions)] + for _ in range(_mock_world_size) + ] + return (per_rank_outputs, mock_runtimes) + + def run_multiprocesses( per_rank_functions: Tuple[Function], per_rank_inputs: Tuple[Any], @@ -264,7 +322,7 @@ def run_multiprocesses( return per_rank_outputs, runtimes -def run_pytorch(num_devices, fn, inputs, use_gpu=False): +def run_pytorch(num_devices, fn, inputs, use_gpu=False, debug_mock=False): """Project `fn` and run on `inputs` over `num_devices` devices using the PyTorch backend. 
""" @@ -274,13 +332,20 @@ def run_pytorch(num_devices, fn, inputs, use_gpu=False): global _use_gpu _use_gpu = use_gpu - per_rank_fns = project(fn, tuple(v.type for v in fn.inputs), num_devices) # from ..ir import cpprint - # for per_rank_fn in per_rank_fns: - # cpprint(per_rank_fn) + # print(*(x.shape for x in inputs)) + # cpprint(fn) + + per_rank_fns = project(fn, tuple(v.type for v in fn.inputs), num_devices) per_rank_inputs = [[] for _ in range(num_devices)] for v, a in zip(fn.inputs, inputs): per_rank_inputs[v.type.device.device_id - 1].append(a) + # for xs, per_rank_fn in zip(per_rank_inputs, per_rank_fns): + # print(*(x.shape for x in xs)) + # cpprint(per_rank_fn) - return run_multiprocesses(per_rank_fns, per_rank_inputs) + if debug_mock: + return run_mock_multiprocess(per_rank_fns, per_rank_inputs) + else: + return run_multiprocesses(per_rank_fns, per_rank_inputs) From 5d99a62abdc805a22d581274385ca1b22d0566c3 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Thu, 6 May 2021 17:16:56 +0100 Subject: [PATCH 38/61] Use spawn start method for multiprocessing --- dist_ir/backend/torch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index c47c7c96..0653c267 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -310,7 +310,8 @@ def run_multiprocesses( per_rank_runner = partial( run_process, world_size, io_dir, num_warmup, num_repetitions ) - with torch.multiprocessing.Pool(world_size) as p: + ctx = torch.multiprocessing.get_context("spawn") + with ctx.Pool(world_size) as p: runtimes = p.starmap(per_rank_runner, enumerate(per_rank_functions)) # Load outputs: From 2c8852a3734eb8759c8a1998ba2b429034815320 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Thu, 6 May 2021 17:25:01 +0100 Subject: [PATCH 39/61] Fix MLP DHP tests --- test/test_pytorch_backend.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index 4b7f7af0..3be5bd62 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -142,9 +142,9 @@ def test_owt(num_devices, num_layers): def test_mlp_grid_search(): - batch_size = 64 - hidden_dim = 64 - num_layers = 2 + batch_size = 2 ** 10 + hidden_dim = batch_size + num_layers = 8 world_size = 2 topology = Topology() @@ -184,8 +184,11 @@ def test_mlp_grid_search(): # TODO check outputs match? _, runtimes = run_pytorch(world_size, fn, dist_input_data) actual_time = max(np.median(times) for times in runtimes) + # actual_time = 0.0 print(fn.name, simulated_time, actual_time) + print(*(v.type.shape for v in fn.inputs)) + print(len(dist_mlp_fns)) def test_empty_device(): @@ -269,6 +272,8 @@ def new_inputs(): # test_send_recv() # test_empty_device() - # import logging + import logging + import os + # logging.basicConfig(level=logging.INFO) test_mlp_grid_search() From 1d54fea5069bfd65f918f013b376560d1fccea4d Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Thu, 6 May 2021 17:28:16 +0100 Subject: [PATCH 40/61] Revert "Fix MLP DHP tests" This reverts commit 2c8852a3734eb8759c8a1998ba2b429034815320. 
--- test/test_pytorch_backend.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index 3be5bd62..4b7f7af0 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -142,9 +142,9 @@ def test_owt(num_devices, num_layers): def test_mlp_grid_search(): - batch_size = 2 ** 10 - hidden_dim = batch_size - num_layers = 8 + batch_size = 64 + hidden_dim = 64 + num_layers = 2 world_size = 2 topology = Topology() @@ -184,11 +184,8 @@ def test_mlp_grid_search(): # TODO check outputs match? _, runtimes = run_pytorch(world_size, fn, dist_input_data) actual_time = max(np.median(times) for times in runtimes) - # actual_time = 0.0 print(fn.name, simulated_time, actual_time) - print(*(v.type.shape for v in fn.inputs)) - print(len(dist_mlp_fns)) def test_empty_device(): @@ -272,8 +269,6 @@ def new_inputs(): # test_send_recv() # test_empty_device() - import logging - import os - + # import logging # logging.basicConfig(level=logging.INFO) test_mlp_grid_search() From 18eaa0893966cd8c12d6a6bd718168caf11cb22f Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Thu, 6 May 2021 17:28:54 +0100 Subject: [PATCH 41/61] Fix MLP DHP tests for real --- test/test_mlp_dhp_transform.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/test/test_mlp_dhp_transform.py b/test/test_mlp_dhp_transform.py index 4ff1a219..52149a89 100644 --- a/test/test_mlp_dhp_transform.py +++ b/test/test_mlp_dhp_transform.py @@ -88,11 +88,12 @@ def add_devices_to_topology(topology, num_devices): def _verify_no_hp(outputs, transformed_outputs, dp=False): - for output, transformed_output in zip(outputs, transformed_outputs): - if dp: - np.testing.assert_array_almost_equal(output, transformed_output[0]) + for i in range(len(outputs)): + if not dp: + j = i else: - np.testing.assert_array_almost_equal(output, transformed_output) + j = 2 * i + np.testing.assert_array_almost_equal(outputs[i], transformed_outputs[j]) def _verify_hp(function, transformed_function, outputs, transformed_outputs, dp=False): @@ -105,10 +106,7 @@ def _verify_hp(function, transformed_function, outputs, transformed_outputs, dp= match = re.search(f"(.*)_dp_(.*)_hp_(.*)_pp_(.*){device_suffix}", output.name) assert match is not None key = (match.group(1), match.group(2), match.group(4)) - if dp: - aggregated_outputs[key].append(v[0]) - else: - aggregated_outputs[key].append(v) + aggregated_outputs[key].append(v) for key in aggregated_outputs: output_name = key[0] if "dw" in output_name: From 0ce2fdb752dde7c14fb300c37e71136bd433e58e Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Fri, 7 May 2021 08:32:51 +0100 Subject: [PATCH 42/61] Add code to plot grid search results --- test/test_pytorch_backend.py | 134 ++++++++++++++++++++++++++++++++++- 1 file changed, 131 insertions(+), 3 deletions(-) diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index 4b7f7af0..1aa7e08e 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -1,3 +1,5 @@ +from collections import defaultdict +import csv import numpy as np import pytest import torch @@ -142,9 +144,9 @@ def test_owt(num_devices, num_layers): def test_mlp_grid_search(): - batch_size = 64 - hidden_dim = 64 - num_layers = 2 + batch_size = 2 ** 10 + hidden_dim = batch_size + num_layers = 8 world_size = 2 topology = Topology() @@ -168,6 +170,7 @@ def test_mlp_grid_search(): np.random.randn(*v.type.shape).astype(np.float32) for v in seq_mlp.inputs ) + 
results = [] for init_fn, fn in dist_mlp_fns: # Simulate simulation = simulator.interpret(fn, (v.type for v in fn.inputs)) @@ -186,6 +189,130 @@ def test_mlp_grid_search(): actual_time = max(np.median(times) for times in runtimes) print(fn.name, simulated_time, actual_time) + results.append( + ( + world_size, + num_layers, + batch_size, + hidden_dim, + simulated_time, + actual_time, + ) + ) + + print(len(dist_mlp_fns)) + + fieldnames = [ + "world_size", + "num_layers", + "batch_size", + "hidden_dim", + "simulated_time", + "actual_time", + ] + + with open("mlp_grid_search.csv", "w") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for ( + world_size, + num_layers, + batch_size, + hidden_dim, + simulated_time, + actual_time, + ) in results: + writer.writerow( + { + "world_size": world_size, + "num_layers": num_layers, + "batch_size": batch_size, + "hidden_dim": hidden_dim, + "simulated_time": simulated_time, + "actual_time": actual_time, + } + ) + + +def plot_mlp_grid_search_results(): + import matplotlib as mpl + import matplotlib.pyplot as plt + from scipy.interpolate import interp1d + from scipy.stats import pearsonr, spearmanr + + results = [] + with open("mlp_grid_search.csv", "r") as f: + reader = csv.DictReader(f) + for row in reader: + results.append( + ( + int(row["world_size"]), + int(row["num_layers"]), + int(row["batch_size"]), + int(row["hidden_dim"]), + float(row["simulated_time"]), + float(row["actual_time"]), + ) + ) + real_throughputs = defaultdict(list) + simulated_throughputs = defaultdict(list) + for world_size, _, batch_size, _, simulated_time, actual_time in results: + real_throughputs[world_size].append(batch_size / actual_time / 1000) + simulated_throughputs[world_size].append(batch_size / simulated_time / 1000) + plt.rcParams["font.size"] = 12 + all_simulated_throughputs = [] + all_real_throughputs = [] + lines = [] + labels = ["Ideal", "Best fit"] + for world_size in simulated_throughputs: + all_real_throughputs += real_throughputs[world_size] + for world_size in simulated_throughputs: + all_simulated_throughputs += simulated_throughputs[world_size] + all_simulated_throughputs = np.array(all_simulated_throughputs) + all_real_throughputs = np.array(all_real_throughputs) + r, p = pearsonr(all_simulated_throughputs, all_real_throughputs) + print(f"Pearson's correlation: {r} (p={p})") + r, p = spearmanr(all_simulated_throughputs, all_real_throughputs) + print(f"Spearman's correlation: {r} (p={p})") + x_new = np.linspace( + min(all_simulated_throughputs.min(), all_real_throughputs.min()), + max(all_simulated_throughputs.max(), all_real_throughputs.max()), + 500, + ) + lines.append( + plt.plot(x_new, x_new, color="black", linestyle="--", label="Ideal")[0] + ) + m, b = np.polyfit(all_simulated_throughputs, all_real_throughputs, 1) + f = interp1d( + all_simulated_throughputs, m * all_simulated_throughputs + b, kind="linear" + ) + x_new = np.linspace( + all_simulated_throughputs.min(), all_simulated_throughputs.max(), 500 + ) + y_smooth = f(x_new) + lines.append( + plt.plot(x_new, y_smooth, color="orange", linestyle="-.", label="Best fit")[0] + ) + colors = ["b", "orange", "g", "purple"] + markers = ["x", "o", "^"] + plt.scatter(all_simulated_throughputs, all_real_throughputs, marker="x") + plt.grid() + plt.xticks([0, 200, 400, 600, 800, 1000]) + plt.yticks([0, 200, 400, 600, 800, 1000]) + plt.xlabel("Simulated throughput\n(1000 samples / second)") + plt.ylabel("Real throughput\n(1000 samples / second)") + plt.gca().set_aspect("equal", 
adjustable="box") + leg = plt.figlegend(lines, labels, loc="upper center", ncol=2) + leg.get_frame().set_linewidth(0.0) + bb = leg.get_bbox_to_anchor().transformed(plt.gca().transAxes.inverted()) + yOffset = 0 + bb.y0 += yOffset + bb.y1 += yOffset + leg.set_bbox_to_anchor(bb, transform=plt.gca().transAxes) + plt.tight_layout() + plt.savefig( + "data_parallel_simulation_performance.pdf", dpi=600, bbox_inches="tight" + ) def test_empty_device(): @@ -272,3 +399,4 @@ def new_inputs(): # import logging # logging.basicConfig(level=logging.INFO) test_mlp_grid_search() + plot_mlp_grid_search_results() From 5588b879ec95b4e541a4317690f11c2452fec829 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Fri, 7 May 2021 16:33:08 +0000 Subject: [PATCH 43/61] Don't use globals while multiprocessing --- dist_ir/backend/torch.py | 127 +++++++++++++++++++++++++-------------- 1 file changed, 81 insertions(+), 46 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 0653c267..84401d92 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -5,7 +5,7 @@ import os from tempfile import TemporaryDirectory from time import perf_counter -from typing import Any, List, Tuple +from typing import Any, Dict, List, NamedTuple, Tuple import torch import torch.distributed as dist @@ -14,28 +14,34 @@ from ..executor.rank_projector import project from ..ir import Function -_use_gpu = False -_groups = None +DistributedContext = NamedTuple( + "DistributedContext", use_gpu=bool, groups=Dict[Tuple[int, int], Any] # Any->Group +) -def _init_p2p_groups(): + +def _init_distributed_context(use_gpu): """Since torch.distributed's NCCL backed doesn't support P2P communication, we create a group for each pair of ranks and use broadcasts to emulate P2P send/recv. This method initializes the groups. 
""" - global _use_gpu, _groups - if _use_gpu: + groups = {} + if use_gpu: world_size = dist.get_world_size() - _groups = {} for i, j in combinations(range(world_size), 2): - _groups[i, j] = dist.new_group([i, j]) + groups[i, j] = dist.new_group([i, j]) + return DistributedContext(use_gpu=use_gpu, groups=groups) + + +def _add(x, y, ctx=None): + return torch.add(x, y) # TODO kwargs of these functions are required, enforce this somewhere -def _allgather(x_i, dim=0): +def _allgather(x_i, dim=0, ctx=None): world_size = dist.get_world_size() xs = [torch.zeros_like(x_i) for _ in range(world_size)] - if _use_gpu: + if ctx.use_gpu: xs = [x.cuda(dist.get_rank()) for x in xs] dist.all_gather(xs, x_i) @@ -43,75 +49,82 @@ def _allgather(x_i, dim=0): return x -def _allreduce(x): +def _allreduce(x, ctx=None): dist.all_reduce(x) return x -def _concat2(x, y, dim=None): +def _concat2(x, y, dim=None, ctx=None): return torch.cat((x, y), dim=dim) -def _identity(x): +def _identity(x, ctx=None): return x -def _loss(x, y, N=None): +def _loss(x, y, N=None, ctx=None): return torch.square(x - y) / N -def _loss_grad(x, y, N=None): +def _loss_grad(x, y, N=None, ctx=None): return 2 * (x - y) / N -def _matmul_grad(x, y, dz): +def _matmul(x, y, ctx=None): + return torch.matmul(x, y) + + +def _matmul_grad(x, y, dz, ctx=None): return (torch.matmul(dz, y.T), torch.matmul(x.T, dz)) -def _recv(shape=None, device=None): +def _recv(shape=None, device=None, ctx=None): x = torch.zeros(shape) # TODO pytorch rank = device_id - 1 - if _use_gpu: + if ctx.use_gpu: x = x.cuda(dist.get_rank()) src_rank = device - 1 dst_rank = dist.get_rank() - group = _groups[tuple(sorted((src_rank, dst_rank)))] + group = ctx.groups[tuple(sorted((src_rank, dst_rank)))] dist.broadcast(x, src_rank, group=group) else: dist.recv(x, device - 1) return x -def _relu_grad(x, dy): - # TODO: fix - dx = torch.zeros(dy.shape) - if _use_gpu: - dx = dx.cuda(dist.get_rank()) - dx[dy > 0] = 1 +def _relu(x, ctx=None): + return torch.relu(x) + + +def _relu_grad(x, dy, ctx=None): + dx = dy.clone() + dx[x <= 0] = 0 return dx -def _send(x, device=None): +def _send(x, device=None, ctx=None): # TODO pytorch rank = device_id - 1 - if _use_gpu: + if ctx.use_gpu: src_rank = dist.get_rank() dst_rank = device - 1 - group = _groups[tuple(sorted((src_rank, dst_rank)))] + group = ctx.groups[tuple(sorted((src_rank, dst_rank)))] dist.broadcast(x, src_rank, group=group) else: dist.send(x, device - 1) + # Note: in a proper backend, might want to concatenate multiple tensors into + # a single buffer and call a single send op _op_to_torch = { - "Add": torch.add, + "Add": _add, "Concat": _concat2, "Identity": _identity, "Loss": _loss, "LossGrad": _loss_grad, - "MatMul": torch.matmul, + "MatMul": _matmul, "MatMulGrad": _matmul_grad, "RecvP2P": _recv, - "Relu": torch.relu, + "Relu": _relu, "ReluGrad": _relu_grad, "SendP2P": _send, "MPIAllgather": _allgather, @@ -180,7 +193,14 @@ def function_to_module(fn: Function) -> torch.nn.Module: return fx.GraphModule({}, g) -def run_function(rank, fn: Function, inputs: List[Any], debug_mock=False): +def run_function( + ctx: DistributedContext, + rank: int, + fn: Function, + inputs: List[Any], + debug_mock=False, +): + # TODO free values when no longer needed op_to_torch = _mock_op_to_torch if debug_mock else _op_to_torch value_map = {} @@ -199,6 +219,7 @@ def run_function(rank, fn: Function, inputs: List[Any], debug_mock=False): logging.info(f"{rank}: {first_output} {op.op_type}") inputs = tuple(value_map[v] for v in op.inputs) kwargs = {} if 
op.attributes is None else {**op.attributes} + kwargs["ctx"] = ctx output = op_to_torch[op.op_type](*inputs, **kwargs) if len(op.outputs) > 1: assert isinstance(output, tuple) @@ -212,20 +233,22 @@ def run_function(rank, fn: Function, inputs: List[Any], debug_mock=False): return tuple(value_map[v] for v in fn.outputs) -def run_process(world_size, io_dir, num_warmup_steps, num_repetitions, rank, fn): +def run_process( + use_gpu, world_size, io_dir, num_warmup_steps, num_repetitions, rank, fn +): """The Python function on rank `rank` that runs module `module`.""" os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "29500" - backend = "nccl" if _use_gpu else "gloo" + backend = "nccl" if use_gpu else "gloo" dist.init_process_group(backend, rank=rank, world_size=world_size) - _init_p2p_groups() + ctx = _init_distributed_context(use_gpu) per_rank_inputs = torch.load(os.path.join(io_dir.name, f"in.{rank}.pt")) # # Convert per-rank DistIR function to torch.nn.Module: # module = function_to_module(fn) - if _use_gpu: + if use_gpu: # Move module and inputs to GPU # TODO how to move interpreted non-module code to GPU? # module = module.cuda(rank) @@ -234,7 +257,7 @@ def run_process(world_size, io_dir, num_warmup_steps, num_repetitions, rank, fn) events = [] def add_event(): - if _use_gpu: + if use_gpu: events.append(torch.cuda.Event(enable_timing=True)) events[-1].record() else: @@ -244,18 +267,18 @@ def add_event(): add_event() for _ in range(num_warmup_steps + num_repetitions): # res = module(*per_rank_inputs) - res = run_function(rank, fn, per_rank_inputs) + res = run_function(ctx, rank, fn, per_rank_inputs) if world_size > 1: torch.distributed.barrier() add_event() - if _use_gpu: + if use_gpu: # Move outputs back to cpu res = [t.cpu() for t in res] torch.save(res, os.path.join(io_dir.name, f"out.{rank}.pt")) - if _use_gpu: + if use_gpu: runtimes = [ events[i].elapsed_time(events[i + 1]) / 1e3 for i in range(len(events) - 1) ] @@ -293,6 +316,7 @@ def run_mock_multiprocess( def run_multiprocesses( per_rank_functions: Tuple[Function], per_rank_inputs: Tuple[Any], + use_gpu=False, num_repetitions=1, num_warmup=0, ): @@ -308,7 +332,7 @@ def run_multiprocesses( global run_process per_rank_runner = partial( - run_process, world_size, io_dir, num_warmup, num_repetitions + run_process, use_gpu, world_size, io_dir, num_warmup, num_repetitions ) ctx = torch.multiprocessing.get_context("spawn") with ctx.Pool(world_size) as p: @@ -323,16 +347,21 @@ def run_multiprocesses( return per_rank_outputs, runtimes -def run_pytorch(num_devices, fn, inputs, use_gpu=False, debug_mock=False): +def run_pytorch( + num_devices, + fn, + inputs, + use_gpu=True, + num_repetitions=1, + num_warmup=0, + debug_mock=False, +): """Project `fn` and run on `inputs` over `num_devices` devices using the PyTorch backend. 
""" # TODO check that fn uses devices [0...num_devices), # or run through and find max device used - global _use_gpu - _use_gpu = use_gpu - # from ..ir import cpprint # print(*(x.shape for x in inputs)) # cpprint(fn) @@ -349,4 +378,10 @@ def run_pytorch(num_devices, fn, inputs, use_gpu=False, debug_mock=False): if debug_mock: return run_mock_multiprocess(per_rank_fns, per_rank_inputs) else: - return run_multiprocesses(per_rank_fns, per_rank_inputs) + return run_multiprocesses( + per_rank_fns, + per_rank_inputs, + use_gpu=use_gpu, + num_repetitions=num_repetitions, + num_warmup=num_warmup, + ) From 7d714d45978c734f4cff268c459c410c3afdc573 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Fri, 7 May 2021 16:42:19 +0000 Subject: [PATCH 44/61] Fix mock backend, use_gpu=False by default --- dist_ir/backend/torch.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 84401d92..2e793ecf 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -137,22 +137,22 @@ def _send(x, device=None, ctx=None): _mock_world_size = None -def _mock_allgather(x_i, dim=0): +def _mock_allgather(x_i, dim=0, ctx=None): xs = [torch.zeros_like(x_i) for _ in range(_mock_world_size)] x = torch.cat(xs, dim=dim) return x -def _mock_allreduce(x): +def _mock_allreduce(x, ctx=None): return x -def _mock_recv(shape=None, device=None): +def _mock_recv(shape=None, device=None, ctx=None): x = torch.zeros(shape) return x -def _mock_send(x, device=None): +def _mock_send(x, device=None, ctx=None): pass @@ -299,9 +299,10 @@ def run_mock_multiprocess( assert len(per_rank_functions) == len(per_rank_inputs) global _mock_world_size _mock_world_size = len(per_rank_functions) + ctx = DistributedContext(use_gpu=False, groups=None) per_rank_outputs = [ - run_function(rank, fn, inputs, debug_mock=True) + run_function(ctx, rank, fn, inputs, debug_mock=True) for rank, fn, inputs in zip( range(_mock_world_size), per_rank_functions, per_rank_inputs ) @@ -351,7 +352,7 @@ def run_pytorch( num_devices, fn, inputs, - use_gpu=True, + use_gpu=False, num_repetitions=1, num_warmup=0, debug_mock=False, From 86468420205173f8b38d1d9d60351e5f76fdbeb1 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Fri, 7 May 2021 17:18:00 +0000 Subject: [PATCH 45/61] Partial grid search on 4 devices --- test/test_pytorch_backend.py | 75 +++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 32 deletions(-) diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index 1aa7e08e..59da7817 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -144,37 +144,43 @@ def test_owt(num_devices, num_layers): def test_mlp_grid_search(): - batch_size = 2 ** 10 - hidden_dim = batch_size - num_layers = 8 - world_size = 2 - - topology = Topology() - d0 = topology.add_device("gpu") - add_devices_to_topology(topology, world_size) - simulator = Simulator(CostModel(topology)) - seq_executor = SequentialExecutor("numpy") - - seq_mlp = mlp(batch_size, hidden_dim, hidden_dim, hidden_dim, num_layers, d0) - seq_mlp = infer_types(seq_mlp, seq_mlp.inputs) - configs = list( - gen_configurations([hidden_dim], [world_size], [num_layers], [batch_size]) - ) - dist_mlp_fns = [ - mlp_dist(seq_mlp, d, h, p, m, topology) for (_, _, _, d, h, p, m) in configs - ] - print(len(dist_mlp_fns)) - - # Create random input data - input_data = tuple( - np.random.randn(*v.type.shape).astype(np.float32) for v in seq_mlp.inputs - ) + batch_sizes = [2 ** i for 
i in range(10, 15)] + hidden_dims = [2 ** i for i in range(8, 13)] + world_sizes = [2, 4] + all_num_layers = [8, 16, 32] results = [] - for init_fn, fn in dist_mlp_fns: + for (batch_size, hidden_dim, num_layers, d, h, p, m) in gen_configurations( + hidden_dims, world_sizes, all_num_layers, batch_sizes + ): + # TODO this is just for debugging, remove + batch_size = 1024 + hidden_dim = 256 + num_layers = 8 + d = 1 + h = p = m = 2 + world_size = d * h * p + # TODO reuse seq_mlp + topology = Topology() + d0 = topology.add_device("gpu") + add_devices_to_topology(topology, world_size) + simulator = Simulator(CostModel(topology)) + seq_executor = SequentialExecutor("numpy") + seq_mlp = mlp(batch_size, hidden_dim, hidden_dim, hidden_dim, num_layers, d0) + seq_mlp = infer_types(seq_mlp, seq_mlp.inputs) + + # Create random input data + input_data = tuple( + np.random.randn(*v.type.shape).astype(np.float32) for v in seq_mlp.inputs + ) + + init_fn, fn = mlp_dist(seq_mlp, d, h, p, m, topology) + print(fn.name) + # Simulate simulation = simulator.interpret(fn, (v.type for v in fn.inputs)) simulated_time = max([simulation.timestamps[d] for d in simulation.timestamps]) + print(simulated_time) # Reference-execute init_fn to get inputs for fn dist_input_data = seq_executor.compute(init_fn, input_data) @@ -185,7 +191,16 @@ def test_mlp_grid_search(): # Measure actual execution time # TODO check outputs match? - _, runtimes = run_pytorch(world_size, fn, dist_input_data) + # _, runtimes = run_pytorch(world_size, fn, dist_input_data) + _, runtimes = run_pytorch( + world_size, + fn, + dist_input_data, + use_gpu=False, + num_repetitions=1, # TODO use 100 + num_warmup=1, + ) + # TODO or median of max? actual_time = max(np.median(times) for times in runtimes) print(fn.name, simulated_time, actual_time) @@ -200,8 +215,6 @@ def test_mlp_grid_search(): ) ) - print(len(dist_mlp_fns)) - fieldnames = [ "world_size", "num_layers", @@ -396,7 +409,5 @@ def new_inputs(): # test_send_recv() # test_empty_device() - # import logging - # logging.basicConfig(level=logging.INFO) test_mlp_grid_search() - plot_mlp_grid_search_results() + # plot_mlp_grid_search_results() From 1db0736bdb772430e3daebdfaf874c9f27f0c058 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Mon, 10 May 2021 01:17:11 +0100 Subject: [PATCH 46/61] Support collectives between a subset of ranks --- dist_ir/backend/torch.py | 70 +++++++++++++----------------- dist_ir/executor/rank_projector.py | 18 ++++++-- test/test_pytorch_backend.py | 12 ++--- 3 files changed, 50 insertions(+), 50 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 2e793ecf..36c564e1 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -1,56 +1,46 @@ from functools import partial -from itertools import combinations import logging from operator import getitem import os from tempfile import TemporaryDirectory from time import perf_counter -from typing import Any, Dict, List, NamedTuple, Tuple +from typing import Any, Dict, Iterable, List, NamedTuple, Tuple import torch import torch.distributed as dist from torch import fx from ..executor.rank_projector import project -from ..ir import Function +from ..ir import Function, cpprint DistributedContext = NamedTuple( - "DistributedContext", use_gpu=bool, groups=Dict[Tuple[int, int], Any] # Any->Group + "DistributedContext", + use_gpu=bool, + groups=Dict[Tuple[int, int], Any], # Maps tuple of ranks to ProcessGroup + groups_list=Iterable[ + Tuple[int] + ], # to store group IDs until threads can 
create ProcessGroups ) -def _init_distributed_context(use_gpu): - """Since torch.distributed's NCCL backed doesn't support P2P communication, - we create a group for each pair of ranks and use broadcasts to emulate P2P - send/recv. This method initializes the groups. - """ - groups = {} - if use_gpu: - world_size = dist.get_world_size() - for i, j in combinations(range(world_size), 2): - groups[i, j] = dist.new_group([i, j]) - return DistributedContext(use_gpu=use_gpu, groups=groups) - - def _add(x, y, ctx=None): return torch.add(x, y) # TODO kwargs of these functions are required, enforce this somewhere -def _allgather(x_i, dim=0, ctx=None): - world_size = dist.get_world_size() - xs = [torch.zeros_like(x_i) for _ in range(world_size)] +def _allgather(x_i, dim=0, group=None, ctx=None): + xs = [torch.zeros_like(x_i) for _ in range(len(group))] if ctx.use_gpu: xs = [x.cuda(dist.get_rank()) for x in xs] - dist.all_gather(xs, x_i) + dist.all_gather(xs, x_i, group=ctx.groups[group]) x = torch.cat(xs, dim=dim) return x -def _allreduce(x, ctx=None): - dist.all_reduce(x) +def _allreduce(x, group=None, ctx=None): + dist.all_reduce(x, group=ctx.groups[group]) return x @@ -233,22 +223,23 @@ def run_function( return tuple(value_map[v] for v in fn.outputs) -def run_process( - use_gpu, world_size, io_dir, num_warmup_steps, num_repetitions, rank, fn -): +def run_process(ctx, world_size, io_dir, num_warmup_steps, num_repetitions, rank, fn): """The Python function on rank `rank` that runs module `module`.""" os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "29500" - backend = "nccl" if use_gpu else "gloo" + backend = "nccl" if ctx.use_gpu else "gloo" dist.init_process_group(backend, rank=rank, world_size=world_size) - ctx = _init_distributed_context(use_gpu) + # Create the process groups used by fn's communication ops + for group in ctx.groups_list: + ranks = tuple(d - 1 for d in group) # TODO fixme + ctx.groups[group] = dist.new_group(ranks) per_rank_inputs = torch.load(os.path.join(io_dir.name, f"in.{rank}.pt")) # # Convert per-rank DistIR function to torch.nn.Module: # module = function_to_module(fn) - if use_gpu: + if ctx.use_gpu: # Move module and inputs to GPU # TODO how to move interpreted non-module code to GPU? 
# module = module.cuda(rank) @@ -257,7 +248,7 @@ def run_process( events = [] def add_event(): - if use_gpu: + if ctx.use_gpu: events.append(torch.cuda.Event(enable_timing=True)) events[-1].record() else: @@ -272,13 +263,13 @@ def add_event(): torch.distributed.barrier() add_event() - if use_gpu: + if ctx.use_gpu: # Move outputs back to cpu res = [t.cpu() for t in res] torch.save(res, os.path.join(io_dir.name, f"out.{rank}.pt")) - if use_gpu: + if ctx.use_gpu: runtimes = [ events[i].elapsed_time(events[i + 1]) / 1e3 for i in range(len(events) - 1) ] @@ -315,15 +306,16 @@ def run_mock_multiprocess( def run_multiprocesses( + ctx, per_rank_functions: Tuple[Function], per_rank_inputs: Tuple[Any], - use_gpu=False, num_repetitions=1, num_warmup=0, ): assert len(per_rank_functions) == len(per_rank_inputs) world_size = len(per_rank_functions) + # TODO just pass tensors instead # Save inputs for each per-rank function: io_dir = TemporaryDirectory() # print("run_multiprocess: saving I/O to:", io_dir.name) @@ -333,10 +325,10 @@ def run_multiprocesses( global run_process per_rank_runner = partial( - run_process, use_gpu, world_size, io_dir, num_warmup, num_repetitions + run_process, ctx, world_size, io_dir, num_warmup, num_repetitions ) - ctx = torch.multiprocessing.get_context("spawn") - with ctx.Pool(world_size) as p: + mp = torch.multiprocessing.get_context("spawn") + with mp.Pool(world_size) as p: runtimes = p.starmap(per_rank_runner, enumerate(per_rank_functions)) # Load outputs: @@ -363,11 +355,11 @@ def run_pytorch( # TODO check that fn uses devices [0...num_devices), # or run through and find max device used - # from ..ir import cpprint # print(*(x.shape for x in inputs)) # cpprint(fn) - per_rank_fns = project(fn, tuple(v.type for v in fn.inputs), num_devices) + per_rank_fns, groups = project(fn, tuple(v.type for v in fn.inputs), num_devices) + ctx = DistributedContext(use_gpu=use_gpu, groups={}, groups_list=groups) per_rank_inputs = [[] for _ in range(num_devices)] for v, a in zip(fn.inputs, inputs): @@ -380,9 +372,9 @@ def run_pytorch( return run_mock_multiprocess(per_rank_fns, per_rank_inputs) else: return run_multiprocesses( + ctx, per_rank_fns, per_rank_inputs, - use_gpu=use_gpu, num_repetitions=num_repetitions, num_warmup=num_warmup, ) diff --git a/dist_ir/executor/rank_projector.py b/dist_ir/executor/rank_projector.py index b7161478..0ee92c7d 100644 --- a/dist_ir/executor/rank_projector.py +++ b/dist_ir/executor/rank_projector.py @@ -1,6 +1,6 @@ from collections import defaultdict from dist_ir.executor.type_inference import TypePropRegister -from typing import Any, Dict, Sequence, Tuple +from typing import Any, Dict, Sequence, Set, Tuple from ..ir import Function, FunctionMaker, Device, Op, Value from ..ir.type import Type, Tensor @@ -14,6 +14,7 @@ class ProjectorState(AbstractState): def __init__(self, function: Function, inputs: Sequence[Any]): AbstractState.__init__(self, function, inputs) self.per_rank_fns: Dict[Device, FunctionMaker] = defaultdict(FunctionMaker) + self.groups: Set[Tuple[int]] = set() def _get_input_devices(op: Op): @@ -38,6 +39,11 @@ def _collective_projector(op: Op, state: ProjectorState): """Projects a collective op over D devices that has D inputs and D outputs, one on each device.""" assert len(op.inputs) == len(op.outputs) + devices = {int(v.type.device.device_id) for v in op.inputs + op.outputs} + attributes = { + **(op.attributes if op.attributes is not None else {}), + "group": tuple(devices), + } for in_v, out_v in zip(op.inputs, op.outputs): assert 
in_v.type.device == out_v.type.device d = in_v.type.device @@ -46,7 +52,7 @@ def _collective_projector(op: Op, state: ProjectorState): op.op_type, inputs=(in_v,), output_values=(out_v,), - attributes=op.attributes, + attributes=attributes, ) state.per_rank_fns[d].ops.append(new_op) @@ -109,6 +115,11 @@ def semantics(op: Op, state: AbstractState): # Project op and add to appropriate per-rank function projector(op, state) + # If op involves more than one device, create a group + devices = {int(v.type.device.device_id) for v in op.inputs + op.outputs} + if len(devices) > 1: + state.groups.add(tuple(devices)) + return semantics signatures = set(projector_register.keys()).intersection(type_prop_register.keys()) @@ -161,6 +172,7 @@ def project( ) ) new_fn.set_outputs(tuple(value_map[v] for v in per_rank_fn.outputs)) + # TODO fix off-by-one discrepancy between DistIR device ID and torch rank result_fns[d.device_id - 1] = new_fn.finalize() - return result_fns + return result_fns, state.groups diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index 59da7817..7b2f6ea5 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -144,8 +144,10 @@ def test_owt(num_devices, num_layers): def test_mlp_grid_search(): - batch_sizes = [2 ** i for i in range(10, 15)] - hidden_dims = [2 ** i for i in range(8, 13)] + # batch_sizes = [2 ** i for i in range(10, 15)] + # hidden_dims = [2 ** i for i in range(8, 13)] + batch_sizes = [2 ** 10] + hidden_dims = [2 ** 10] world_sizes = [2, 4] all_num_layers = [8, 16, 32] @@ -153,12 +155,6 @@ def test_mlp_grid_search(): for (batch_size, hidden_dim, num_layers, d, h, p, m) in gen_configurations( hidden_dims, world_sizes, all_num_layers, batch_sizes ): - # TODO this is just for debugging, remove - batch_size = 1024 - hidden_dim = 256 - num_layers = 8 - d = 1 - h = p = m = 2 world_size = d * h * p # TODO reuse seq_mlp topology = Topology() From a79ab67bbd9911b4ef409b872f40668e79b73b02 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Tue, 11 May 2021 11:30:00 -0700 Subject: [PATCH 47/61] Bug fixes for distributed process groups (#24) Uses device IDs to index into `ctx.groups` for `send` and `recv` backend ops. 
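For illustration, the intended lookup on the receiving rank is along these lines
(a sketch, not the exact diff below; this backend maps torch rank r to DistIR
device id r + 1):

    # `device` is the sender's DistIR device id (an op attribute); the group key
    # is the sorted pair of device ids recorded during projection.
    src_rank = device - 1
    dst_device_id = dist.get_rank() + 1
    group = ctx.groups[tuple(sorted((device, dst_device_id)))]
    dist.broadcast(x, src_rank, group=group)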
Also ensures tuple of device IDs in `ctx.groups --- dist_ir/backend/torch.py | 6 ++++-- dist_ir/executor/rank_projector.py | 6 ++++-- examples/__init__.py | 0 3 files changed, 8 insertions(+), 4 deletions(-) create mode 100644 examples/__init__.py diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 36c564e1..14ec9ccf 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -75,7 +75,8 @@ def _recv(shape=None, device=None, ctx=None): x = x.cuda(dist.get_rank()) src_rank = device - 1 dst_rank = dist.get_rank() - group = ctx.groups[tuple(sorted((src_rank, dst_rank)))] + group_key = tuple(sorted(device, dst_rank + 1)) + group = ctx.groups[group_key] dist.broadcast(x, src_rank, group=group) else: dist.recv(x, device - 1) @@ -97,7 +98,8 @@ def _send(x, device=None, ctx=None): if ctx.use_gpu: src_rank = dist.get_rank() dst_rank = device - 1 - group = ctx.groups[tuple(sorted((src_rank, dst_rank)))] + group_key = tuple(sorted((src_rank - 1, device))) + group = ctx.groups[group_key] dist.broadcast(x, src_rank, group=group) else: dist.send(x, device - 1) diff --git a/dist_ir/executor/rank_projector.py b/dist_ir/executor/rank_projector.py index 0ee92c7d..a1f65ca1 100644 --- a/dist_ir/executor/rank_projector.py +++ b/dist_ir/executor/rank_projector.py @@ -116,9 +116,11 @@ def semantics(op: Op, state: AbstractState): projector(op, state) # If op involves more than one device, create a group - devices = {int(v.type.device.device_id) for v in op.inputs + op.outputs} + devices = {v.device.device_id for v in outputs}.union( + {int(v.type.device.device_id) for v in op.inputs} + ) if len(devices) > 1: - state.groups.add(tuple(devices)) + state.groups.add(tuple(sorted(devices))) return semantics diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 00000000..e69de29b From 0c0f7f5d7972549d88e6da7b09691f34eeec97a3 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Thu, 13 May 2021 20:09:20 +0100 Subject: [PATCH 48/61] Debugging MLP deadlock --- test/test_pytorch_backend.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index 7b2f6ea5..b1b66ab3 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -146,15 +146,19 @@ def test_owt(num_devices, num_layers): def test_mlp_grid_search(): # batch_sizes = [2 ** i for i in range(10, 15)] # hidden_dims = [2 ** i for i in range(8, 13)] - batch_sizes = [2 ** 10] - hidden_dims = [2 ** 10] - world_sizes = [2, 4] - all_num_layers = [8, 16, 32] + batch_sizes = [64] + hidden_dims = [64] + world_sizes = [1, 2, 4, 8] + all_num_layers = [32] results = [] for (batch_size, hidden_dim, num_layers, d, h, p, m) in gen_configurations( hidden_dims, world_sizes, all_num_layers, batch_sizes ): + # TODO this is to debug mlp_1_2_4_2. 
Remove when fixed + d = 1 + h = m = 2 + p = 4 world_size = d * h * p # TODO reuse seq_mlp topology = Topology() @@ -210,6 +214,7 @@ def test_mlp_grid_search(): actual_time, ) ) + return # TODO remove after debugging fieldnames = [ "world_size", From 33cefced5e8d2df89922a6a6e7d693ce2234bcb0 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Thu, 13 May 2021 22:14:16 -0700 Subject: [PATCH 49/61] Fix collective projector --- dist_ir/executor/rank_projector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dist_ir/executor/rank_projector.py b/dist_ir/executor/rank_projector.py index a1f65ca1..1af81029 100644 --- a/dist_ir/executor/rank_projector.py +++ b/dist_ir/executor/rank_projector.py @@ -42,7 +42,7 @@ def _collective_projector(op: Op, state: ProjectorState): devices = {int(v.type.device.device_id) for v in op.inputs + op.outputs} attributes = { **(op.attributes if op.attributes is not None else {}), - "group": tuple(devices), + "group": tuple(sorted(devices)), } for in_v, out_v in zip(op.inputs, op.outputs): assert in_v.type.device == out_v.type.device From a5c8b34877e9237a39f7fd191d4636d44fd29049 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Thu, 13 May 2021 23:45:22 +0100 Subject: [PATCH 50/61] Enable grid search test again --- test/test_pytorch_backend.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index b1b66ab3..697f12c2 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -155,10 +155,6 @@ def test_mlp_grid_search(): for (batch_size, hidden_dim, num_layers, d, h, p, m) in gen_configurations( hidden_dims, world_sizes, all_num_layers, batch_sizes ): - # TODO this is to debug mlp_1_2_4_2. Remove when fixed - d = 1 - h = m = 2 - p = 4 world_size = d * h * p # TODO reuse seq_mlp topology = Topology() @@ -214,7 +210,6 @@ def test_mlp_grid_search(): actual_time, ) ) - return # TODO remove after debugging fieldnames = [ "world_size", From ca68ec416b8ad3ca3cd0d83d75882ba140e376cf Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Fri, 21 May 2021 00:15:22 +0100 Subject: [PATCH 51/61] Some comments, debugging code, cuda sync earlier --- dist_ir/backend/torch.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 14ec9ccf..83c08f3b 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -1,7 +1,7 @@ from functools import partial -import logging from operator import getitem import os +import sys from tempfile import TemporaryDirectory from time import perf_counter from typing import Any, Dict, Iterable, List, NamedTuple, Tuple @@ -11,7 +11,7 @@ from torch import fx from ..executor.rank_projector import project -from ..ir import Function, cpprint +from ..ir import Function, cpprint, pformat DistributedContext = NamedTuple( @@ -24,6 +24,9 @@ ) +# TODO organize by category + + def _add(x, y, ctx=None): return torch.add(x, y) @@ -162,8 +165,6 @@ def function_to_module(fn: Function) -> torch.nn.Module: g = fx.Graph() value_map = {} - # TODO need to check that fn has unique value names - # Convert inputs for v in fn.inputs: value_map[v] = g.placeholder(v.name) @@ -203,12 +204,9 @@ def run_function( # Run ops for op in fn.ops: - first_output = ( - op.outputs[0].name - if op.outputs is not None and len(op.outputs) > 0 - else "None" - ) - logging.info(f"{rank}: {first_output} {op.op_type}") + # op_str = pformat(op).replace("\n", " ") + # print(f"{rank}: 
{op_str}") + # sys.stdout.flush() inputs = tuple(value_map[v] for v in op.inputs) kwargs = {} if op.attributes is None else {**op.attributes} kwargs["ctx"] = ctx @@ -219,7 +217,8 @@ def run_function( value_map[v] = output[i] elif len(op.outputs) == 1: value_map[op.outputs[0]] = output - logging.info(f"{rank}: {first_output} {op.op_type}") + # print(f"{rank}: {op_str}") + # sys.stdout.flush() # Return outputs return tuple(value_map[v] for v in fn.outputs) @@ -243,7 +242,7 @@ def run_process(ctx, world_size, io_dir, num_warmup_steps, num_repetitions, rank if ctx.use_gpu: # Move module and inputs to GPU - # TODO how to move interpreted non-module code to GPU? + # TODO check if interpreted code is running on GPU (check all inputs?) # module = module.cuda(rank) per_rank_inputs = [t.cuda(rank) for t in per_rank_inputs] @@ -272,10 +271,10 @@ def add_event(): torch.save(res, os.path.join(io_dir.name, f"out.{rank}.pt")) if ctx.use_gpu: + torch.cuda.synchronize() runtimes = [ events[i].elapsed_time(events[i + 1]) / 1e3 for i in range(len(events) - 1) ] - torch.cuda.synchronize() else: runtimes = [events[i + 1] - events[i] for i in range(len(events) - 1)] From b31209d03f03ab18b60bce86ccf9b419bcb078a9 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Fri, 21 May 2021 00:16:46 +0100 Subject: [PATCH 52/61] Projector for gather --- dist_ir/executor/rank_projector.py | 20 +++++++++++++++++ test/test_pytorch_backend.py | 35 ++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/dist_ir/executor/rank_projector.py b/dist_ir/executor/rank_projector.py index 1af81029..09029687 100644 --- a/dist_ir/executor/rank_projector.py +++ b/dist_ir/executor/rank_projector.py @@ -57,6 +57,25 @@ def _collective_projector(op: Op, state: ProjectorState): state.per_rank_fns[d].ops.append(new_op) +def _gather_projector(op: Op, state: ProjectorState): + devices = set(v.type.device for v in op.inputs) + assert len(op.inputs) == len(devices) + assert len(op.outputs) == 1 and op.outputs[0].type.device in devices + attributes = { + **(op.attributes if op.attributes is not None else {}), + "group": tuple(sorted(devices)), + } + for in_v in op.inputs: + d = in_v.type.device + new_op = Op( + op.op_type, + inputs=(in_v,), + output_values=op.outputs, # TODO only on dst device! 
+ attributes=attributes, + ) + state.per_rank_fns[d].ops.append(new_op) + + def _send_projector(op: Op, state: ProjectorState): from_d = op.inputs[0].type.device to_d = op.attributes["device"] @@ -88,6 +107,7 @@ def _send_projector(op: Op, state: ProjectorState): ("MPIAllreduce", (Tensor,) * 4): _collective_projector, ("MPIAllreduce", (Tensor,) * 8): _collective_projector, ("MPIAllreduce", (Tensor,) * 16): _collective_projector, + ("MPIGather", (Tensor,) * 2): _gather_projector, ("Relu", (Tensor,)): _identity_projector, ("ReluGrad", (Tensor, Tensor)): _identity_projector, ("Send", (Tensor,)): _send_projector, diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index 697f12c2..122ecbfa 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -143,6 +143,41 @@ def test_owt(num_devices, num_layers): assert all(np.allclose(y[0], o) for y, o in zip(per_rank_outputs, output_arrays)) +def test_dp_mp_matmuls(): + fn = FunctionMaker("dp_mp_matmuls") + B = 64 + d0 = Device(0, "gpu") + d1 = Device(1, "gpu") + x_0 = fn.add_input_value("x_0", Tensor(Float(), (B // 2, B), d0)) + x_1 = fn.add_input_value("x_1", Tensor(Float(), (B // 2, B), d1)) + wA_0 = fn.add_input_value("wA_0", Tensor(Float(), (B, B), d0)) + wA_1 = fn.add_input_value("wA_1", Tensor(Float(), (B, B), d1)) + wB_0 = fn.add_input_value("wB_0", Tensor(Float(), (B, B), d0)) + wC_1 = fn.add_input_value("wC_1", Tensor(Float(), (B, B), d1)) + a0_0 = fn.add_op("MatMul", inputs=[x_0, wA_0], output_names=["a0"]) + a1_1 = fn.add_op("MatMul", inputs=[x_1, wA_1], output_names=["a1"]) + a_0 = fn.add_op( + "MPIGather", + inputs=[a0_0, a1_1], + output_names=["a_0"], + attributes={"device": d0, "dim": 0}, + ) + b_0 = fn.add_op("MatMul", inputs=[a_0, wB_0], output_names=["b_0"]) + b_1 = fn.add_op( + "Send", inputs=[b_0], output_names=["b_1"], attributes={"device": d1} + ) + c_1 = fn.add_op("MatMul", inputs=[b_1, wC_1], output_names=["c_1"]) + fn = fn.finalize() + fn = infer_types(fn, fn.inputs) + cpprint(fn) + + from dist_ir.executor.rank_projector import project + + per_rank_fns, groups = project(fn, tuple(v.type for v in fn.inputs), 2) + for per_rank_fn in per_rank_fns: + cpprint(per_rank_fn) + + def test_mlp_grid_search(): # batch_sizes = [2 ** i for i in range(10, 15)] # hidden_dims = [2 ** i for i in range(8, 13)] From 54ce0d2160fae6e36aa2e25d3f349506aede8771 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Tue, 25 May 2021 15:20:56 +0100 Subject: [PATCH 53/61] Don't save input/outputs to file in torch backend --- dist_ir/backend/torch.py | 52 ++++++++++++++-------------------------- 1 file changed, 18 insertions(+), 34 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 83c08f3b..4fc23741 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -2,7 +2,6 @@ from operator import getitem import os import sys -from tempfile import TemporaryDirectory from time import perf_counter from typing import Any, Dict, Iterable, List, NamedTuple, Tuple @@ -224,8 +223,12 @@ def run_function( return tuple(value_map[v] for v in fn.outputs) -def run_process(ctx, world_size, io_dir, num_warmup_steps, num_repetitions, rank, fn): - """The Python function on rank `rank` that runs module `module`.""" +def run_process(ctx, world_size, num_warmup_steps, num_repetitions, rank, fn, inputs): + """The Python function on rank `rank` that runs DistIR function `fn` on + (torch) inputs `inputs`. The function is run + `num_warmup_steps + num_repetitions` times. 
The outputs of the last run are + returned, along with the last `num_repetitions` runtimes. + """ os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "29500" backend = "nccl" if ctx.use_gpu else "gloo" @@ -235,16 +238,11 @@ def run_process(ctx, world_size, io_dir, num_warmup_steps, num_repetitions, rank ranks = tuple(d - 1 for d in group) # TODO fixme ctx.groups[group] = dist.new_group(ranks) - per_rank_inputs = torch.load(os.path.join(io_dir.name, f"in.{rank}.pt")) - - # # Convert per-rank DistIR function to torch.nn.Module: - # module = function_to_module(fn) - if ctx.use_gpu: # Move module and inputs to GPU # TODO check if interpreted code is running on GPU (check all inputs?) # module = module.cuda(rank) - per_rank_inputs = [t.cuda(rank) for t in per_rank_inputs] + inputs = [t.cuda(rank) for t in inputs] events = [] @@ -258,17 +256,15 @@ def add_event(): # Time a bunch of executions, then execute once for output values add_event() for _ in range(num_warmup_steps + num_repetitions): - # res = module(*per_rank_inputs) - res = run_function(ctx, rank, fn, per_rank_inputs) + # res = module(*inputs) + outputs = run_function(ctx, rank, fn, inputs) if world_size > 1: torch.distributed.barrier() add_event() if ctx.use_gpu: # Move outputs back to cpu - res = [t.cpu() for t in res] - - torch.save(res, os.path.join(io_dir.name, f"out.{rank}.pt")) + outputs = [t.cpu() for t in outputs] if ctx.use_gpu: torch.cuda.synchronize() @@ -279,7 +275,7 @@ def add_event(): runtimes = [events[i + 1] - events[i] for i in range(len(events) - 1)] dist.destroy_process_group() - return runtimes[num_warmup_steps:] + return outputs, runtimes[num_warmup_steps:] def run_mock_multiprocess( @@ -315,29 +311,17 @@ def run_multiprocesses( ): assert len(per_rank_functions) == len(per_rank_inputs) world_size = len(per_rank_functions) + args = [ + (r, f, x) for (r, (f, x)) in enumerate(zip(per_rank_functions, per_rank_inputs)) + ] - # TODO just pass tensors instead - # Save inputs for each per-rank function: - io_dir = TemporaryDirectory() - # print("run_multiprocess: saving I/O to:", io_dir.name) - # TODO lowered pytorch file numbers devices 0...num_devices-1 - for d, inps in enumerate(per_rank_inputs): - torch.save(inps, os.path.join(io_dir.name, f"in.{d}.pt")) - - global run_process - per_rank_runner = partial( - run_process, ctx, world_size, io_dir, num_warmup, num_repetitions - ) + global run_process # TODO needed? + per_rank_runner = partial(run_process, ctx, world_size, num_warmup, num_repetitions) mp = torch.multiprocessing.get_context("spawn") with mp.Pool(world_size) as p: - runtimes = p.starmap(per_rank_runner, enumerate(per_rank_functions)) - - # Load outputs: - per_rank_outputs = [ - torch.load(os.path.join(io_dir.name, f"out.{d}.pt")) for d in range(world_size) - ] - io_dir.cleanup() + outputs = p.starmap(per_rank_runner, args) + per_rank_outputs, runtimes = zip(*outputs) return per_rank_outputs, runtimes From 5d1b63f2072326ee535cefaae93b5e3281a5af6f Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Tue, 25 May 2021 15:23:20 +0100 Subject: [PATCH 54/61] Remove unnecessary global --- dist_ir/backend/torch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 4fc23741..367cc3dd 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -315,7 +315,6 @@ def run_multiprocesses( (r, f, x) for (r, (f, x)) in enumerate(zip(per_rank_functions, per_rank_inputs)) ] - global run_process # TODO needed? 
per_rank_runner = partial(run_process, ctx, world_size, num_warmup, num_repetitions) mp = torch.multiprocessing.get_context("spawn") with mp.Pool(world_size) as p: From 3466627192fc2511a3c95285b2085a26a07ab757 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Tue, 25 May 2021 19:09:26 +0100 Subject: [PATCH 55/61] Free tensors after use --- dist_ir/backend/torch.py | 10 +++++++++- dist_ir/ir/function.py | 4 ++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 367cc3dd..12cef301 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -192,7 +192,6 @@ def run_function( inputs: List[Any], debug_mock=False, ): - # TODO free values when no longer needed op_to_torch = _mock_op_to_torch if debug_mock else _op_to_torch value_map = {} @@ -209,13 +208,21 @@ def run_function( inputs = tuple(value_map[v] for v in op.inputs) kwargs = {} if op.attributes is None else {**op.attributes} kwargs["ctx"] = ctx + output = op_to_torch[op.op_type](*inputs, **kwargs) + if len(op.outputs) > 1: assert isinstance(output, tuple) for i, v in enumerate(op.outputs): value_map[v] = output[i] elif len(op.outputs) == 1: value_map[op.outputs[0]] = output + + # Free tensors that are not used again + for v in op.inputs: + if v in value_map and fn.last_use(v) == op and not (v in fn.outputs): + del value_map[v] + # print(f"{rank}: {op_str}") # sys.stdout.flush() @@ -236,6 +243,7 @@ def run_process(ctx, world_size, num_warmup_steps, num_repetitions, rank, fn, in # Create the process groups used by fn's communication ops for group in ctx.groups_list: ranks = tuple(d - 1 for d in group) # TODO fixme + # TODO ctx is copied or shared among threads? ctx.groups[group] = dist.new_group(ranks) if ctx.use_gpu: diff --git a/dist_ir/ir/function.py b/dist_ir/ir/function.py index ab5bb89e..44d86479 100644 --- a/dist_ir/ir/function.py +++ b/dist_ir/ir/function.py @@ -62,6 +62,10 @@ def __str__(self): # TODO can we use the prettyprint output as __str__? 
return self.get_summary() + def last_use(self, value): + """Returns the last op that uses the given value `value`.""" + return self.consumers[value][-1] + def get_summary(self): output = "" output += "Function inputs:\n" From 7d16c90235909d8934deb9dc14eb53cd0dd175fa Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Wed, 26 May 2021 16:29:31 +0100 Subject: [PATCH 56/61] Map DistIR devices to pytorch backend ranks --- dist_ir/backend/torch.py | 86 ++++++++++++++++++------------ dist_ir/executor/rank_projector.py | 45 ++++++++++------ test/test_pytorch_backend.py | 18 +++---- 3 files changed, 86 insertions(+), 63 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 12cef301..033a7c8d 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -3,6 +3,7 @@ import os import sys from time import perf_counter +from traceback import print_exc from typing import Any, Dict, Iterable, List, NamedTuple, Tuple import torch @@ -11,15 +12,19 @@ from ..executor.rank_projector import project from ..ir import Function, cpprint, pformat +from ..ir.device import Device DistributedContext = NamedTuple( "DistributedContext", + world_size=int, use_gpu=bool, - groups=Dict[Tuple[int, int], Any], # Maps tuple of ranks to ProcessGroup - groups_list=Iterable[ - Tuple[int] - ], # to store group IDs until threads can create ProcessGroups + # Map from DistIR device to PyTorch backend rank + device_to_rank=Dict[Device, int], + # Maps tuple of ranks to ProcessGroup + groups=Dict[Tuple[int], Any], + # Temp store of group IDs until threads can create ProcessGroups + groups_list=Iterable[Tuple[int]], ) @@ -70,18 +75,14 @@ def _matmul_grad(x, y, dz, ctx=None): return (torch.matmul(dz, y.T), torch.matmul(x.T, dz)) -def _recv(shape=None, device=None, ctx=None): +def _recv(shape=None, from_d=None, group=None, ctx=None): x = torch.zeros(shape) - # TODO pytorch rank = device_id - 1 + src_rank = ctx.device_to_rank[from_d] if ctx.use_gpu: x = x.cuda(dist.get_rank()) - src_rank = device - 1 - dst_rank = dist.get_rank() - group_key = tuple(sorted(device, dst_rank + 1)) - group = ctx.groups[group_key] - dist.broadcast(x, src_rank, group=group) + dist.broadcast(x, src_rank, group=ctx.groups[group]) else: - dist.recv(x, device - 1) + dist.recv(x, src_rank) return x @@ -95,16 +96,13 @@ def _relu_grad(x, dy, ctx=None): return dx -def _send(x, device=None, ctx=None): - # TODO pytorch rank = device_id - 1 +def _send(x, to_d=None, group=None, ctx=None): if ctx.use_gpu: src_rank = dist.get_rank() - dst_rank = device - 1 - group_key = tuple(sorted((src_rank - 1, device))) - group = ctx.groups[group_key] - dist.broadcast(x, src_rank, group=group) + dist.broadcast(x, src_rank, group=ctx.groups[group]) else: - dist.send(x, device - 1) + dst_rank = ctx.device_to_rank[to_d] + dist.send(x, dst_rank) # Note: in a proper backend, might want to concatenate multiple tensors into # a single buffer and call a single send op @@ -230,7 +228,7 @@ def run_function( return tuple(value_map[v] for v in fn.outputs) -def run_process(ctx, world_size, num_warmup_steps, num_repetitions, rank, fn, inputs): +def run_process(ctx, num_warmup_steps, num_repetitions, rank, fn, inputs): """The Python function on rank `rank` that runs DistIR function `fn` on (torch) inputs `inputs`. The function is run `num_warmup_steps + num_repetitions` times. 
The outputs of the last run are @@ -239,10 +237,11 @@ def run_process(ctx, world_size, num_warmup_steps, num_repetitions, rank, fn, in os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "29500" backend = "nccl" if ctx.use_gpu else "gloo" - dist.init_process_group(backend, rank=rank, world_size=world_size) + dist.init_process_group(backend, rank=rank, world_size=ctx.world_size) + # Create the process groups used by fn's communication ops for group in ctx.groups_list: - ranks = tuple(d - 1 for d in group) # TODO fixme + ranks = [ctx.device_to_rank[d] for d in group] # TODO ctx is copied or shared among threads? ctx.groups[group] = dist.new_group(ranks) @@ -265,8 +264,13 @@ def add_event(): add_event() for _ in range(num_warmup_steps + num_repetitions): # res = module(*inputs) + # try: + # outputs = run_function(ctx, rank, fn, inputs) + # except Exception as e: + # print_exc() + # sys.exit(1) outputs = run_function(ctx, rank, fn, inputs) - if world_size > 1: + if ctx.world_size > 1: torch.distributed.barrier() add_event() @@ -318,14 +322,13 @@ def run_multiprocesses( num_warmup=0, ): assert len(per_rank_functions) == len(per_rank_inputs) - world_size = len(per_rank_functions) args = [ (r, f, x) for (r, (f, x)) in enumerate(zip(per_rank_functions, per_rank_inputs)) ] - per_rank_runner = partial(run_process, ctx, world_size, num_warmup, num_repetitions) + per_rank_runner = partial(run_process, ctx, num_warmup, num_repetitions) mp = torch.multiprocessing.get_context("spawn") - with mp.Pool(world_size) as p: + with mp.Pool(ctx.world_size) as p: outputs = p.starmap(per_rank_runner, args) per_rank_outputs, runtimes = zip(*outputs) @@ -333,7 +336,6 @@ def run_multiprocesses( def run_pytorch( - num_devices, fn, inputs, use_gpu=False, @@ -344,18 +346,32 @@ def run_pytorch( """Project `fn` and run on `inputs` over `num_devices` devices using the PyTorch backend. """ - # TODO check that fn uses devices [0...num_devices), - # or run through and find max device used - # print(*(x.shape for x in inputs)) # cpprint(fn) - per_rank_fns, groups = project(fn, tuple(v.type for v in fn.inputs), num_devices) - ctx = DistributedContext(use_gpu=use_gpu, groups={}, groups_list=groups) - - per_rank_inputs = [[] for _ in range(num_devices)] + device_to_fns, groups = project(fn, tuple(v.type for v in fn.inputs)) + + # Map between DistIR devices and pytorch ranks: + device_to_rank = {} + world_size = 0 + per_rank_fns = [] + for d in device_to_fns: + device_to_rank[d] = world_size + per_rank_fns.append(device_to_fns[d]) + world_size += 1 + + ctx = DistributedContext( + world_size=world_size, + use_gpu=use_gpu, + groups={}, + groups_list=list(groups), + device_to_rank=device_to_rank, + ) + + per_rank_inputs = [[] for _ in range(world_size)] for v, a in zip(fn.inputs, inputs): - per_rank_inputs[v.type.device.device_id - 1].append(a) + per_rank_inputs[device_to_rank[v.type.device]].append(a) + # for xs, per_rank_fn in zip(per_rank_inputs, per_rank_fns): # print(*(x.shape for x in xs)) # cpprint(per_rank_fn) diff --git a/dist_ir/executor/rank_projector.py b/dist_ir/executor/rank_projector.py index 09029687..85a2b36e 100644 --- a/dist_ir/executor/rank_projector.py +++ b/dist_ir/executor/rank_projector.py @@ -21,6 +21,13 @@ def _get_input_devices(op: Op): return list(set(x.type.device for x in op.inputs)) +def _make_group(devices): + """Return a hashable representation of a group of devices. This is needed by + the backend, which maps them to process groups for communication primitives. 
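    For illustration, a sketch of how the backend consumes these groups (the
    last two lines mirror run_process in the torch backend):

        group = _make_group(v.type.device for v in op.inputs + op.outputs)
        # `group` is a sorted tuple of Devices; each rank then builds one
        # torch.distributed ProcessGroup per group:
        ranks = [ctx.device_to_rank[d] for d in group]
        ctx.groups[group] = dist.new_group(ranks)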
+ """ + return tuple(sorted(set(devices))) + + # TODO should projectors just get the per_rank_fns dict instead of full state? @@ -39,10 +46,10 @@ def _collective_projector(op: Op, state: ProjectorState): """Projects a collective op over D devices that has D inputs and D outputs, one on each device.""" assert len(op.inputs) == len(op.outputs) - devices = {int(v.type.device.device_id) for v in op.inputs + op.outputs} + group = _make_group(v.type.device for v in op.inputs + op.outputs) attributes = { **(op.attributes if op.attributes is not None else {}), - "group": tuple(sorted(devices)), + "group": group, } for in_v, out_v in zip(op.inputs, op.outputs): assert in_v.type.device == out_v.type.device @@ -63,7 +70,7 @@ def _gather_projector(op: Op, state: ProjectorState): assert len(op.outputs) == 1 and op.outputs[0].type.device in devices attributes = { **(op.attributes if op.attributes is not None else {}), - "group": tuple(sorted(devices)), + "group": _make_group(devices), } for in_v in op.inputs: d = in_v.type.device @@ -79,14 +86,23 @@ def _gather_projector(op: Op, state: ProjectorState): def _send_projector(op: Op, state: ProjectorState): from_d = op.inputs[0].type.device to_d = op.attributes["device"] + group = _make_group((from_d, to_d)) state.per_rank_fns[from_d].ops.append( - Op("SendP2P", inputs=op.inputs, attributes={"device": to_d.device_id}) + Op( + "SendP2P", + inputs=op.inputs, + attributes={"to_d": to_d, "group": group}, + ) ) state.per_rank_fns[to_d].ops.append( Op( "RecvP2P", output_values=(op.outputs[0],), - attributes={"shape": op.inputs[0].type.shape, "device": from_d.device_id}, + attributes={ + "shape": op.inputs[0].type.shape, + "from_d": from_d, + "group": group, + }, ) ) @@ -136,11 +152,10 @@ def semantics(op: Op, state: AbstractState): projector(op, state) # If op involves more than one device, create a group - devices = {v.device.device_id for v in outputs}.union( - {int(v.type.device.device_id) for v in op.inputs} - ) - if len(devices) > 1: - state.groups.add(tuple(sorted(devices))) + devices = [v.device for v in outputs] + [v.type.device for v in op.inputs] + group = _make_group(devices) + if len(group) > 1: + state.groups.add(group) return semantics @@ -158,9 +173,7 @@ def semantics(op: Op, state: AbstractState): ) -def project( - fn: Function, input_types: Sequence[Type], num_devices: int -) -> Tuple[Function]: +def project(fn: Function, input_types: Sequence[Type]) -> Tuple[Function]: """Project fn to a sequence of per-rank functions.""" state = ProjectorState(fn, input_types) @@ -171,8 +184,8 @@ def project( state = Projector.interpret(fn, input_types, state=state) # Erase all types in per_rank_fns: - # TODO do this during projection? 
- result_fns = [Function(fn.name, (), (), ()) for _ in range(num_devices)] + # TODO don't use singleton types, and remove this + result_fns = {} for d, per_rank_fn in state.per_rank_fns.items(): value_map = {} new_fn = FunctionMaker(name=f"{fn.name}_{d.device_id-1}") @@ -195,6 +208,6 @@ def project( ) new_fn.set_outputs(tuple(value_map[v] for v in per_rank_fn.outputs)) # TODO fix off-by-one discrepancy between DistIR device ID and torch rank - result_fns[d.device_id - 1] = new_fn.finalize() + result_fns[d] = new_fn.finalize() return result_fns, state.groups diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index 122ecbfa..2d3edbf9 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -135,9 +135,7 @@ def test_owt(num_devices, num_layers): assert all(np.allclose(y, o) for y, o in zip(ys, output_arrays)) # Run per-rank modules using PyTorch backend: - per_rank_outputs, _ = run_pytorch( - num_devices, fn, [torch.tensor(a) for a in input_arrays] - ) + per_rank_outputs, _ = run_pytorch(fn, [torch.tensor(a) for a in input_arrays]) # Check outputs: assert all(np.allclose(y[0], o) for y, o in zip(per_rank_outputs, output_arrays)) @@ -224,7 +222,6 @@ def test_mlp_grid_search(): # TODO check outputs match? # _, runtimes = run_pytorch(world_size, fn, dist_input_data) _, runtimes = run_pytorch( - world_size, fn, dist_input_data, use_gpu=False, @@ -359,9 +356,8 @@ def plot_mlp_grid_search_results(): ) -def test_empty_device(): +def test_single_device(): d1 = Device(1, "gpu") - d2 = Device(2, "gpu") fn = FunctionMaker() x = fn.add_input_value("x", Tensor(Float(), (4, 4), d1)) y = fn.add_op("MatMul", inputs=(x, x)) @@ -371,7 +367,7 @@ def test_empty_device(): x = torch.randn(4, 4) inputs = (x,) - outputs, _ = run_pytorch(2, fn, inputs) + outputs, _ = run_pytorch(fn, inputs) print(outputs) assert torch.allclose(torch.matmul(x, x), outputs[0][0]) @@ -388,7 +384,7 @@ def test_send_recv(): x = torch.randn(4, 4) inputs = (x,) - outputs, _ = run_pytorch(2, fn, inputs) + outputs, _ = run_pytorch(fn, inputs) assert torch.allclose(x, outputs[1][0]) @@ -424,9 +420,7 @@ def new_inputs(): y = torch.relu(y) # Project and run on backend: - per_rank_outputs, runtimes = run_pytorch( - num_devices, fn, convert_inputs_dp(weights, x) - ) + per_rank_outputs, runtimes = run_pytorch(fn, convert_inputs_dp(weights, x)) # Check outputs: assert torch.allclose(y, torch.cat([o[0] for o in per_rank_outputs], 0)) @@ -438,7 +432,7 @@ def new_inputs(): # test_owt(2, 4) # test_dp_mlp() # test_send_recv() - # test_empty_device() + # test_single_device() test_mlp_grid_search() # plot_mlp_grid_search_results() From 8c4f5e1d82528ecdfade768be2ae24debf9843f2 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Wed, 26 May 2021 16:41:16 +0100 Subject: [PATCH 57/61] Fix tests --- test/test_pytorch_backend.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index 2d3edbf9..74985ad5 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -171,8 +171,8 @@ def test_dp_mp_matmuls(): from dist_ir.executor.rank_projector import project - per_rank_fns, groups = project(fn, tuple(v.type for v in fn.inputs), 2) - for per_rank_fn in per_rank_fns: + per_rank_fns, groups = project(fn, tuple(v.type for v in fn.inputs)) + for per_rank_fn in per_rank_fns.values(): cpprint(per_rank_fn) @@ -433,6 +433,7 @@ def new_inputs(): # test_dp_mlp() # test_send_recv() # test_single_device() + test_dp_mp_matmuls() 
test_mlp_grid_search() # plot_mlp_grid_search_results() From 79d4a656bf10ad397595d500ec132769dcf240f6 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Wed, 26 May 2021 16:57:50 +0100 Subject: [PATCH 58/61] Some documentation and cleanup --- dist_ir/backend/torch.py | 13 +++++++++---- dist_ir/executor/rank_projector.py | 17 +++++++++++++---- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 033a7c8d..b785a5a3 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -159,6 +159,9 @@ def _mock_send(x, device=None, ctx=None): def function_to_module(fn: Function) -> torch.nn.Module: + """Deprecated. Converts a DistIR Function to a PyTorch nn.Module using + torch.fx. + """ g = fx.Graph() value_map = {} @@ -185,11 +188,13 @@ def function_to_module(fn: Function) -> torch.nn.Module: def run_function( ctx: DistributedContext, - rank: int, fn: Function, inputs: List[Any], debug_mock=False, ): + """Runs DistIR Function `fn` on `inputs` in a distributed context `ctx` by + converting each DistIR op to its torch implementation as given in _op_to_torch. + """ op_to_torch = _mock_op_to_torch if debug_mock else _op_to_torch value_map = {} @@ -265,11 +270,11 @@ def add_event(): for _ in range(num_warmup_steps + num_repetitions): # res = module(*inputs) # try: - # outputs = run_function(ctx, rank, fn, inputs) + # outputs = run_function(ctx, fn, inputs) # except Exception as e: # print_exc() # sys.exit(1) - outputs = run_function(ctx, rank, fn, inputs) + outputs = run_function(ctx, fn, inputs) if ctx.world_size > 1: torch.distributed.barrier() add_event() @@ -302,7 +307,7 @@ def run_mock_multiprocess( ctx = DistributedContext(use_gpu=False, groups=None) per_rank_outputs = [ - run_function(ctx, rank, fn, inputs, debug_mock=True) + run_function(ctx, fn, inputs, debug_mock=True) for rank, fn, inputs in zip( range(_mock_world_size), per_rank_functions, per_rank_inputs ) diff --git a/dist_ir/executor/rank_projector.py b/dist_ir/executor/rank_projector.py index 85a2b36e..cfb2862c 100644 --- a/dist_ir/executor/rank_projector.py +++ b/dist_ir/executor/rank_projector.py @@ -11,10 +11,15 @@ class ProjectorState(AbstractState): + """The Abstract Interpreter state for projection. It keeps a mapping from + Devices to per-rank Functions, and a set of Device groups that perform + collective communication. + """ + def __init__(self, function: Function, inputs: Sequence[Any]): AbstractState.__init__(self, function, inputs) self.per_rank_fns: Dict[Device, FunctionMaker] = defaultdict(FunctionMaker) - self.groups: Set[Tuple[int]] = set() + self.groups: Set[Tuple[Device]] = set() def _get_input_devices(op: Op): @@ -173,8 +178,13 @@ def semantics(op: Op, state: AbstractState): ) -def project(fn: Function, input_types: Sequence[Type]) -> Tuple[Function]: - """Project fn to a sequence of per-rank functions.""" +def project( + fn: Function, input_types: Sequence[Type] +) -> Tuple[Dict[Device, Function], Set[Tuple[Device]]]: + """Project `fn` to per-rank functions. Returns a mapping from Devices to + per-rank Functions, and a set of Device groups that perform collective + communications in `fn`. 
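    Typical usage, mirroring the torch backend and tests (a sketch; assumes
    `fn` has had types inferred, e.g. via infer_types, since projection reads
    value types):

        per_rank_fns, groups = project(fn, tuple(v.type for v in fn.inputs))
        for d, per_rank_fn in per_rank_fns.items():
            cpprint(per_rank_fn)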
+ """ state = ProjectorState(fn, input_types) # Project fn's inputs to each per-rank fn: @@ -207,7 +217,6 @@ def project(fn: Function, input_types: Sequence[Type]) -> Tuple[Function]: ) ) new_fn.set_outputs(tuple(value_map[v] for v in per_rank_fn.outputs)) - # TODO fix off-by-one discrepancy between DistIR device ID and torch rank result_fns[d] = new_fn.finalize() return result_fns, state.groups From da7ff5dbe7873ba8e16722d74c2e4b381f055384 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Wed, 26 May 2021 17:05:39 +0100 Subject: [PATCH 59/61] Fix comment --- dist_ir/backend/torch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index b785a5a3..9cdf3cbb 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -247,7 +247,7 @@ def run_process(ctx, num_warmup_steps, num_repetitions, rank, fn, inputs): # Create the process groups used by fn's communication ops for group in ctx.groups_list: ranks = [ctx.device_to_rank[d] for d in group] - # TODO ctx is copied or shared among threads? + # ctx is a curried arg, hence is thread-local and can be modified: ctx.groups[group] = dist.new_group(ranks) if ctx.use_gpu: From 72589b2976f51981e64f343f44fdc7244f934ff6 Mon Sep 17 00:00:00 2001 From: Siddharth Krishna Date: Wed, 26 May 2021 17:36:10 +0100 Subject: [PATCH 60/61] Remove experiment code and dead code --- dist_ir/backend/torch.py | 5 +- test/test_pytorch_backend.py | 123 ----------------------------------- 2 files changed, 1 insertion(+), 127 deletions(-) diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py index 9cdf3cbb..822a99d9 100644 --- a/dist_ir/backend/torch.py +++ b/dist_ir/backend/torch.py @@ -251,9 +251,7 @@ def run_process(ctx, num_warmup_steps, num_repetitions, rank, fn, inputs): ctx.groups[group] = dist.new_group(ranks) if ctx.use_gpu: - # Move module and inputs to GPU - # TODO check if interpreted code is running on GPU (check all inputs?) 
- # module = module.cuda(rank) + # Move inputs to GPU inputs = [t.cuda(rank) for t in inputs] events = [] @@ -268,7 +266,6 @@ def add_event(): # Time a bunch of executions, then execute once for output values add_event() for _ in range(num_warmup_steps + num_repetitions): - # res = module(*inputs) # try: # outputs = run_function(ctx, fn, inputs) # except Exception as e: diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py index 74985ad5..5d456a2a 100644 --- a/test/test_pytorch_backend.py +++ b/test/test_pytorch_backend.py @@ -232,128 +232,6 @@ def test_mlp_grid_search(): actual_time = max(np.median(times) for times in runtimes) print(fn.name, simulated_time, actual_time) - results.append( - ( - world_size, - num_layers, - batch_size, - hidden_dim, - simulated_time, - actual_time, - ) - ) - - fieldnames = [ - "world_size", - "num_layers", - "batch_size", - "hidden_dim", - "simulated_time", - "actual_time", - ] - - with open("mlp_grid_search.csv", "w") as f: - writer = csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - for ( - world_size, - num_layers, - batch_size, - hidden_dim, - simulated_time, - actual_time, - ) in results: - writer.writerow( - { - "world_size": world_size, - "num_layers": num_layers, - "batch_size": batch_size, - "hidden_dim": hidden_dim, - "simulated_time": simulated_time, - "actual_time": actual_time, - } - ) - - -def plot_mlp_grid_search_results(): - import matplotlib as mpl - import matplotlib.pyplot as plt - from scipy.interpolate import interp1d - from scipy.stats import pearsonr, spearmanr - - results = [] - with open("mlp_grid_search.csv", "r") as f: - reader = csv.DictReader(f) - for row in reader: - results.append( - ( - int(row["world_size"]), - int(row["num_layers"]), - int(row["batch_size"]), - int(row["hidden_dim"]), - float(row["simulated_time"]), - float(row["actual_time"]), - ) - ) - real_throughputs = defaultdict(list) - simulated_throughputs = defaultdict(list) - for world_size, _, batch_size, _, simulated_time, actual_time in results: - real_throughputs[world_size].append(batch_size / actual_time / 1000) - simulated_throughputs[world_size].append(batch_size / simulated_time / 1000) - plt.rcParams["font.size"] = 12 - all_simulated_throughputs = [] - all_real_throughputs = [] - lines = [] - labels = ["Ideal", "Best fit"] - for world_size in simulated_throughputs: - all_real_throughputs += real_throughputs[world_size] - for world_size in simulated_throughputs: - all_simulated_throughputs += simulated_throughputs[world_size] - all_simulated_throughputs = np.array(all_simulated_throughputs) - all_real_throughputs = np.array(all_real_throughputs) - r, p = pearsonr(all_simulated_throughputs, all_real_throughputs) - print(f"Pearson's correlation: {r} (p={p})") - r, p = spearmanr(all_simulated_throughputs, all_real_throughputs) - print(f"Spearman's correlation: {r} (p={p})") - x_new = np.linspace( - min(all_simulated_throughputs.min(), all_real_throughputs.min()), - max(all_simulated_throughputs.max(), all_real_throughputs.max()), - 500, - ) - lines.append( - plt.plot(x_new, x_new, color="black", linestyle="--", label="Ideal")[0] - ) - m, b = np.polyfit(all_simulated_throughputs, all_real_throughputs, 1) - f = interp1d( - all_simulated_throughputs, m * all_simulated_throughputs + b, kind="linear" - ) - x_new = np.linspace( - all_simulated_throughputs.min(), all_simulated_throughputs.max(), 500 - ) - y_smooth = f(x_new) - lines.append( - plt.plot(x_new, y_smooth, color="orange", linestyle="-.", label="Best fit")[0] - ) - 
colors = ["b", "orange", "g", "purple"] - markers = ["x", "o", "^"] - plt.scatter(all_simulated_throughputs, all_real_throughputs, marker="x") - plt.grid() - plt.xticks([0, 200, 400, 600, 800, 1000]) - plt.yticks([0, 200, 400, 600, 800, 1000]) - plt.xlabel("Simulated throughput\n(1000 samples / second)") - plt.ylabel("Real throughput\n(1000 samples / second)") - plt.gca().set_aspect("equal", adjustable="box") - leg = plt.figlegend(lines, labels, loc="upper center", ncol=2) - leg.get_frame().set_linewidth(0.0) - bb = leg.get_bbox_to_anchor().transformed(plt.gca().transAxes.inverted()) - yOffset = 0 - bb.y0 += yOffset - bb.y1 += yOffset - leg.set_bbox_to_anchor(bb, transform=plt.gca().transAxes) - plt.tight_layout() - plt.savefig( - "data_parallel_simulation_performance.pdf", dpi=600, bbox_inches="tight" - ) def test_single_device(): @@ -436,4 +314,3 @@ def new_inputs(): test_dp_mp_matmuls() test_mlp_grid_search() - # plot_mlp_grid_search_results() From d832fa518691946894738d1c809298f1ce2e7d88 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Fri, 28 May 2021 00:12:50 -0700 Subject: [PATCH 61/61] Add protobuf definition for DistIR --- dist_ir/ir/device.py | 10 +++++++++ dist_ir/ir/function.py | 15 ++++++++++++++ dist_ir/ir/op.py | 14 +++++++++++++ dist_ir/ir/type.py | 38 +++++++++++++++++++++++++++++++++++ dist_ir/ir/value.py | 20 +++++++++++++++++- dist_ir/proto/Makefile | 9 +++++++++ dist_ir/proto/device.proto | 10 +++++++++ dist_ir/proto/function.proto | 14 +++++++++++++ dist_ir/proto/op.proto | 15 ++++++++++++++ dist_ir/proto/type.proto | 39 ++++++++++++++++++++++++++++++++++++ dist_ir/proto/value.proto | 17 ++++++++++++++++ 11 files changed, 200 insertions(+), 1 deletion(-) create mode 100644 dist_ir/proto/Makefile create mode 100644 dist_ir/proto/device.proto create mode 100644 dist_ir/proto/function.proto create mode 100644 dist_ir/proto/op.proto create mode 100644 dist_ir/proto/type.proto create mode 100644 dist_ir/proto/value.proto diff --git a/dist_ir/ir/device.py b/dist_ir/ir/device.py index 6aac12e1..b99df571 100644 --- a/dist_ir/ir/device.py +++ b/dist_ir/ir/device.py @@ -1,6 +1,8 @@ from dataclasses import dataclass from typing import ClassVar +from ..proto import device_pb2 + @dataclass(frozen=True) class Device: @@ -24,3 +26,11 @@ def get_new_device_variable(cls, device_type): device_id = f"d{cls.device_variable_id}" cls.device_variable_id += 1 return Device(device_id, device_type, is_variable=True) + + def serialize_to_proto(self): + device_proto = device_pb2.Device() + device_proto.device_id = self.device_id + device_proto.device_type = self.device_type + device_proto.throughput = self.throughput + device_proto.dram_bandwidth = self.dram_bandwidth + return device_proto diff --git a/dist_ir/ir/function.py b/dist_ir/ir/function.py index 44d86479..deec3635 100644 --- a/dist_ir/ir/function.py +++ b/dist_ir/ir/function.py @@ -8,6 +8,7 @@ from .op import Op from .value import Value +from ..proto import function_pb2 @dataclass(frozen=True) @@ -137,6 +138,20 @@ def get_subfunction( subfunction.set_outputs(outputs) return subfunction.finalize() + def serialize_to_proto(self): + function_proto = function_pb2.Function() + function_proto.name = self.name + for op in self.ops: + op_proto = op.serialize_to_proto() + function_proto.ops.append(op_proto) + for inp in self.inputs: + inp_proto = inp.serialize_to_proto() + function_proto.inputs.append(inp_proto) + for output in self.outputs: + output_proto = output.serialize_to_proto() + function_proto.outputs.append(output_proto) + return 
function_proto + @dataclass class FunctionMaker: diff --git a/dist_ir/ir/op.py b/dist_ir/ir/op.py index 3fd2e57d..cbbbcf58 100644 --- a/dist_ir/ir/op.py +++ b/dist_ir/ir/op.py @@ -6,6 +6,7 @@ from .op_register import OpRegister from .value import Value from .type import Type +from ..proto import op_pb2 @dataclass(frozen=True) @@ -93,3 +94,16 @@ def __post_init__(self, output_names, output_types, output_values): object.__setattr__( self, "outputs", output_values ) # Can't assign to frozen field + + def serialize_to_proto(self): + op_proto = op_pb2.Op() + op_proto.op_type = self.op_type + op_proto.name = self.name + for inp in self.inputs: + inp_proto = inp.serialize_to_proto() + op_proto.inputs.append(inp_proto) + # TODO: Serialize attributes + for output in self.outputs: + output_proto = output.serialize_to_proto() + op_proto.outputs.append(output_proto) + return op_proto diff --git a/dist_ir/ir/type.py b/dist_ir/ir/type.py index 67092510..410bd933 100644 --- a/dist_ir/ir/type.py +++ b/dist_ir/ir/type.py @@ -5,6 +5,7 @@ from .device import Device from .utils import singleton +from ..proto import type_pb2 @dataclass(frozen=True) @@ -37,6 +38,9 @@ def __repr__(self): def size(self): return 4 + def serialize_to_proto(self): + return type_pb2.Int32() + @singleton class Int64(Type): @@ -49,6 +53,9 @@ def __repr__(self): def size(self): return 8 + def serialize_to_proto(self): + return type_pb2.Int64() + @singleton class Float(Type): @@ -73,6 +80,9 @@ def __repr__(self): def size(self): return 1 + def serialize_to_proto(self): + return type_pb2.Bool() + @dataclass(frozen=True) class Tensor(Type): @@ -99,6 +109,34 @@ def __repr__(self): def size(self): return reduce(mul, self.shape) * self.dtype.size + def serialize_to_proto(self): + tensor_proto = type_pb2.Tensor() + if self.device is not None: + tensor_proto.device.CopyFrom(self.device.serialize_to_proto()) + else: + # TODO: Anything? + pass + if self.shape is not None: + for dim in self.shape: + tensor_proto.shape.append(dim) + if self.dtype is not None: + if isinstance(self.dtype, Int32): + tensor_proto.dtype = type_pb2.DataType.INT32 + elif isinstance(self.dtype, Int64): + tensor_proto.dtype = type_pb2.DataType.INT64 + elif isinstance(self.dtype, Bool): + tensor_proto.dtype = type_pb2.DataType.BOOL + elif isinstance(self.dtype, Float): + raise NotImplementedError( + "Float type will be deprecated in favor of Float16 and Float32" + ) + else: + raise ValueError(f"Unknown dtype {type(self.dtype)}") + else: + # TODO: Anything? 
+            pass
+        return tensor_proto
+

 @dataclass(frozen=True)
 class TupleType(Type):
diff --git a/dist_ir/ir/value.py b/dist_ir/ir/value.py
index a22db06c..d6708247 100644
--- a/dist_ir/ir/value.py
+++ b/dist_ir/ir/value.py
@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 
-from .type import Type
+from .type import Bool, Int32, Int64, Tensor, Type
+from ..proto import value_pb2
 
 
 @dataclass(frozen=True, eq=False)
@@ -12,3 +13,20 @@ class Value:
 
     name: str
     type: Type
+
+    def serialize_to_proto(self):
+        value_proto = value_pb2.Value()
+        value_proto.name = self.name
+        if self.type is not None:
+            type_proto = self.type.serialize_to_proto()
+            if isinstance(self.type, type(Int32())):
+                value_proto.i32.CopyFrom(type_proto)
+            elif isinstance(self.type, type(Int64())):
+                value_proto.i64.CopyFrom(type_proto)
+            elif isinstance(self.type, type(Bool())):
+                value_proto.bool.CopyFrom(type_proto)
+            elif isinstance(self.type, Tensor):
+                value_proto.tensor.CopyFrom(type_proto)
+            else:
+                raise ValueError(f"Unknown type {type(self.type)}")
+        return value_proto
diff --git a/dist_ir/proto/Makefile b/dist_ir/proto/Makefile
new file mode 100644
index 00000000..a133d6fe
--- /dev/null
+++ b/dist_ir/proto/Makefile
@@ -0,0 +1,9 @@
+all:
+	protoc -I=. --python_out=. ./device.proto
+	protoc -I=. --python_out=. ./type.proto
+	protoc -I=. --python_out=. ./value.proto
+	protoc -I=. --python_out=. ./op.proto
+	protoc -I=. --python_out=. ./function.proto
+
+clean:
+	rm *_pb2.py
diff --git a/dist_ir/proto/device.proto b/dist_ir/proto/device.proto
new file mode 100644
index 00000000..f2efe64e
--- /dev/null
+++ b/dist_ir/proto/device.proto
@@ -0,0 +1,10 @@
+syntax = "proto3";
+
+package dist_ir;
+
+message Device {
+  string device_id = 1;
+  string device_type = 2;
+  float throughput = 3;
+  float dram_bandwidth = 4;
+}
diff --git a/dist_ir/proto/function.proto b/dist_ir/proto/function.proto
new file mode 100644
index 00000000..cab9582e
--- /dev/null
+++ b/dist_ir/proto/function.proto
@@ -0,0 +1,14 @@
+syntax = "proto3";
+
+package dist_ir;
+
+import "op.proto";
+import "value.proto";
+
+message Function {
+  string name = 1;
+  repeated Op ops = 2;
+  repeated Value inputs = 3;
+  repeated Value outputs = 4;
+  // TODO: Add consumers?
+}
diff --git a/dist_ir/proto/op.proto b/dist_ir/proto/op.proto
new file mode 100644
index 00000000..fe8a3cd6
--- /dev/null
+++ b/dist_ir/proto/op.proto
@@ -0,0 +1,15 @@
+syntax = "proto3";
+
+package dist_ir;
+
+import "google/protobuf/any.proto";
+import "value.proto";
+
+message Op {
+  string op_type = 1;
+  string name = 2;
+  repeated Value inputs = 3;
+  map<string, google.protobuf.Any> attributes = 4;
+  repeated Value outputs = 5;
+  // TODO: Add subfunctions?
+} diff --git a/dist_ir/proto/type.proto b/dist_ir/proto/type.proto new file mode 100644 index 00000000..8cadbcad --- /dev/null +++ b/dist_ir/proto/type.proto @@ -0,0 +1,39 @@ +syntax = "proto3"; + +package dist_ir; + +import "device.proto"; + +enum DataType { + INT32 = 0; + INT64 = 1; + FLOAT16 = 2; + FLOAT32 = 3; + BOOL = 4; +} + +message Int32 { + Device device = 1; +} + +message Int64 { + Device device = 1; +} + +message Float16 { + Device device = 1; +} + +message Float32 { + Device device = 1; +} + +message Bool { + Device device = 1; +} + +message Tensor { + Device device = 1; + DataType dtype = 2; + repeated uint32 shape = 3; +} diff --git a/dist_ir/proto/value.proto b/dist_ir/proto/value.proto new file mode 100644 index 00000000..fe5724e0 --- /dev/null +++ b/dist_ir/proto/value.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package dist_ir; + +import "type.proto"; + +message Value { + string name = 1; + oneof type { + Int32 i32 = 2; + Int64 i64 = 3; + Float16 f16 = 4; + Float32 f32 = 5; + Bool bool = 6; + Tensor tensor = 7; + } +}
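
The patch above only adds serialization; nothing in it reads the protos back into DistIR objects. As a rough usage sketch (not part of the diff), the new API can be exercised as follows, assuming the *_pb2 modules have been generated with the Makefile and that Device's throughput and dram_bandwidth fields have numeric defaults; the device values here are made up for illustration:

    from dist_ir.ir import Device
    from dist_ir.proto import device_pb2

    # Convert an IR device into the generated protobuf message.
    d = Device("d0", "gpu")                # throughput / dram_bandwidth keep their defaults
    device_proto = d.serialize_to_proto()  # a dist_ir.proto.device_pb2.Device

    # Round-trip through the standard protobuf binary encoding.
    data = device_proto.SerializeToString()
    decoded = device_pb2.Device()
    decoded.ParseFromString(data)
    assert decoded.device_id == "d0" and decoded.device_type == "gpu"

Value.serialize_to_proto() follows the same pattern; the oneof in value.proto selects the i32/i64/bool/tensor field matching the value's Python type, so a consumer can dispatch on WhichOneof("type") when decoding.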