diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 45a6f3ca..c838a8e9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -52,4 +52,4 @@ jobs: run: python setup.py install - name: Test with pytest - run: pytest + run: python -m pytest diff --git a/dist_ir/backend/__init__.py b/dist_ir/backend/__init__.py new file mode 100644 index 00000000..abef05ae --- /dev/null +++ b/dist_ir/backend/__init__.py @@ -0,0 +1 @@ +from . import torch diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py new file mode 100644 index 00000000..822a99d9 --- /dev/null +++ b/dist_ir/backend/torch.py @@ -0,0 +1,390 @@ +from functools import partial +from operator import getitem +import os +import sys +from time import perf_counter +from traceback import print_exc +from typing import Any, Dict, Iterable, List, NamedTuple, Tuple + +import torch +import torch.distributed as dist +from torch import fx + +from ..executor.rank_projector import project +from ..ir import Function, cpprint, pformat +from ..ir.device import Device + + +DistributedContext = NamedTuple( + "DistributedContext", + world_size=int, + use_gpu=bool, + # Map from DistIR device to PyTorch backend rank + device_to_rank=Dict[Device, int], + # Maps tuple of ranks to ProcessGroup + groups=Dict[Tuple[int], Any], + # Temp store of group IDs until threads can create ProcessGroups + groups_list=Iterable[Tuple[int]], +) + + +# TODO organize by category + + +def _add(x, y, ctx=None): + return torch.add(x, y) + + +# TODO kwargs of these functions are required, enforce this somewhere +def _allgather(x_i, dim=0, group=None, ctx=None): + xs = [torch.zeros_like(x_i) for _ in range(len(group))] + if ctx.use_gpu: + xs = [x.cuda(dist.get_rank()) for x in xs] + + dist.all_gather(xs, x_i, group=ctx.groups[group]) + x = torch.cat(xs, dim=dim) + return x + + +def _allreduce(x, group=None, ctx=None): + dist.all_reduce(x, group=ctx.groups[group]) + return x + + +def _concat2(x, y, dim=None, ctx=None): + return torch.cat((x, y), dim=dim) + + +def _identity(x, ctx=None): + return x + + +def _loss(x, y, N=None, ctx=None): + return torch.square(x - y) / N + + +def _loss_grad(x, y, N=None, ctx=None): + return 2 * (x - y) / N + + +def _matmul(x, y, ctx=None): + return torch.matmul(x, y) + + +def _matmul_grad(x, y, dz, ctx=None): + return (torch.matmul(dz, y.T), torch.matmul(x.T, dz)) + + +def _recv(shape=None, from_d=None, group=None, ctx=None): + x = torch.zeros(shape) + src_rank = ctx.device_to_rank[from_d] + if ctx.use_gpu: + x = x.cuda(dist.get_rank()) + dist.broadcast(x, src_rank, group=ctx.groups[group]) + else: + dist.recv(x, src_rank) + return x + + +def _relu(x, ctx=None): + return torch.relu(x) + + +def _relu_grad(x, dy, ctx=None): + dx = dy.clone() + dx[x <= 0] = 0 + return dx + + +def _send(x, to_d=None, group=None, ctx=None): + if ctx.use_gpu: + src_rank = dist.get_rank() + dist.broadcast(x, src_rank, group=ctx.groups[group]) + else: + dst_rank = ctx.device_to_rank[to_d] + dist.send(x, dst_rank) + # Note: in a proper backend, might want to concatenate multiple tensors into + # a single buffer and call a single send op + + +_op_to_torch = { + "Add": _add, + "Concat": _concat2, + "Identity": _identity, + "Loss": _loss, + "LossGrad": _loss_grad, + "MatMul": _matmul, + "MatMulGrad": _matmul_grad, + "RecvP2P": _recv, + "Relu": _relu, + "ReluGrad": _relu_grad, + "SendP2P": _send, + "MPIAllgather": _allgather, + "MPIAllreduce": _allreduce, +} + +# Some mock communication ops that return zero tensors of appropriate shape +# to be used in the sequential runner for debugging + +_mock_world_size = None + + +def _mock_allgather(x_i, dim=0, ctx=None): + xs = [torch.zeros_like(x_i) for _ in range(_mock_world_size)] + x = torch.cat(xs, dim=dim) + return x + + +def _mock_allreduce(x, ctx=None): + return x + + +def _mock_recv(shape=None, device=None, ctx=None): + x = torch.zeros(shape) + return x + + +def _mock_send(x, device=None, ctx=None): + pass + + +_mock_comm_ops = { + "RecvP2P": _mock_recv, + "SendP2P": _mock_send, + "MPIAllgather": _mock_allgather, + "MPIAllreduce": _mock_allreduce, +} + +_mock_op_to_torch = {**_op_to_torch, **_mock_comm_ops} + + +def function_to_module(fn: Function) -> torch.nn.Module: + """Deprecated. Converts a DistIR Function to a PyTorch nn.Module using + torch.fx. + """ + g = fx.Graph() + value_map = {} + + # Convert inputs + for v in fn.inputs: + value_map[v] = g.placeholder(v.name) + + # Convert ops + for op in fn.ops: + inputs = tuple(value_map[v] for v in op.inputs) + kwargs = None if op.attributes is None else {**op.attributes} + output = g.call_function(_op_to_torch[op.op_type], inputs, kwargs) + if len(op.outputs) > 1: + for i, v in enumerate(op.outputs): + value_map[v] = g.call_function(getitem, (output, i)) + elif len(op.outputs) == 1: + value_map[op.outputs[0]] = output + + # Convert outputs + g.output(tuple(value_map[v] for v in fn.outputs)) + + return fx.GraphModule({}, g) + + +def run_function( + ctx: DistributedContext, + fn: Function, + inputs: List[Any], + debug_mock=False, +): + """Runs DistIR Function `fn` on `inputs` in a distributed context `ctx` by + converting each DistIR op to its torch implementation as given in _op_to_torch. + """ + op_to_torch = _mock_op_to_torch if debug_mock else _op_to_torch + value_map = {} + + # Add inputs to value_map + for v, x in zip(fn.inputs, inputs): + value_map[v] = x + assert len(fn.inputs) == len(inputs) + + # Run ops + for op in fn.ops: + # op_str = pformat(op).replace("\n", " ") + # print(f"{rank}: {op_str}") + # sys.stdout.flush() + inputs = tuple(value_map[v] for v in op.inputs) + kwargs = {} if op.attributes is None else {**op.attributes} + kwargs["ctx"] = ctx + + output = op_to_torch[op.op_type](*inputs, **kwargs) + + if len(op.outputs) > 1: + assert isinstance(output, tuple) + for i, v in enumerate(op.outputs): + value_map[v] = output[i] + elif len(op.outputs) == 1: + value_map[op.outputs[0]] = output + + # Free tensors that are not used again + for v in op.inputs: + if v in value_map and fn.last_use(v) == op and not (v in fn.outputs): + del value_map[v] + + # print(f"{rank}: {op_str}") + # sys.stdout.flush() + + # Return outputs + return tuple(value_map[v] for v in fn.outputs) + + +def run_process(ctx, num_warmup_steps, num_repetitions, rank, fn, inputs): + """The Python function on rank `rank` that runs DistIR function `fn` on + (torch) inputs `inputs`. The function is run + `num_warmup_steps + num_repetitions` times. The outputs of the last run are + returned, along with the last `num_repetitions` runtimes. + """ + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29500" + backend = "nccl" if ctx.use_gpu else "gloo" + dist.init_process_group(backend, rank=rank, world_size=ctx.world_size) + + # Create the process groups used by fn's communication ops + for group in ctx.groups_list: + ranks = [ctx.device_to_rank[d] for d in group] + # ctx is a curried arg, hence is thread-local and can be modified: + ctx.groups[group] = dist.new_group(ranks) + + if ctx.use_gpu: + # Move inputs to GPU + inputs = [t.cuda(rank) for t in inputs] + + events = [] + + def add_event(): + if ctx.use_gpu: + events.append(torch.cuda.Event(enable_timing=True)) + events[-1].record() + else: + events.append(perf_counter()) + + # Time a bunch of executions, then execute once for output values + add_event() + for _ in range(num_warmup_steps + num_repetitions): + # try: + # outputs = run_function(ctx, fn, inputs) + # except Exception as e: + # print_exc() + # sys.exit(1) + outputs = run_function(ctx, fn, inputs) + if ctx.world_size > 1: + torch.distributed.barrier() + add_event() + + if ctx.use_gpu: + # Move outputs back to cpu + outputs = [t.cpu() for t in outputs] + + if ctx.use_gpu: + torch.cuda.synchronize() + runtimes = [ + events[i].elapsed_time(events[i + 1]) / 1e3 for i in range(len(events) - 1) + ] + else: + runtimes = [events[i + 1] - events[i] for i in range(len(events) - 1)] + + dist.destroy_process_group() + return outputs, runtimes[num_warmup_steps:] + + +def run_mock_multiprocess( + per_rank_functions: Tuple[Function], + per_rank_inputs: Tuple[Any], + num_repetitions=1, + num_warmup=0, +): + assert len(per_rank_functions) == len(per_rank_inputs) + global _mock_world_size + _mock_world_size = len(per_rank_functions) + ctx = DistributedContext(use_gpu=False, groups=None) + + per_rank_outputs = [ + run_function(ctx, fn, inputs, debug_mock=True) + for rank, fn, inputs in zip( + range(_mock_world_size), per_rank_functions, per_rank_inputs + ) + ] + mock_runtimes = [ + [0.0 for _ in range(num_warmup + num_repetitions)] + for _ in range(_mock_world_size) + ] + return (per_rank_outputs, mock_runtimes) + + +def run_multiprocesses( + ctx, + per_rank_functions: Tuple[Function], + per_rank_inputs: Tuple[Any], + num_repetitions=1, + num_warmup=0, +): + assert len(per_rank_functions) == len(per_rank_inputs) + args = [ + (r, f, x) for (r, (f, x)) in enumerate(zip(per_rank_functions, per_rank_inputs)) + ] + + per_rank_runner = partial(run_process, ctx, num_warmup, num_repetitions) + mp = torch.multiprocessing.get_context("spawn") + with mp.Pool(ctx.world_size) as p: + outputs = p.starmap(per_rank_runner, args) + + per_rank_outputs, runtimes = zip(*outputs) + return per_rank_outputs, runtimes + + +def run_pytorch( + fn, + inputs, + use_gpu=False, + num_repetitions=1, + num_warmup=0, + debug_mock=False, +): + """Project `fn` and run on `inputs` over `num_devices` devices using the + PyTorch backend. + """ + # print(*(x.shape for x in inputs)) + # cpprint(fn) + + device_to_fns, groups = project(fn, tuple(v.type for v in fn.inputs)) + + # Map between DistIR devices and pytorch ranks: + device_to_rank = {} + world_size = 0 + per_rank_fns = [] + for d in device_to_fns: + device_to_rank[d] = world_size + per_rank_fns.append(device_to_fns[d]) + world_size += 1 + + ctx = DistributedContext( + world_size=world_size, + use_gpu=use_gpu, + groups={}, + groups_list=list(groups), + device_to_rank=device_to_rank, + ) + + per_rank_inputs = [[] for _ in range(world_size)] + for v, a in zip(fn.inputs, inputs): + per_rank_inputs[device_to_rank[v.type.device]].append(a) + + # for xs, per_rank_fn in zip(per_rank_inputs, per_rank_fns): + # print(*(x.shape for x in xs)) + # cpprint(per_rank_fn) + + if debug_mock: + return run_mock_multiprocess(per_rank_fns, per_rank_inputs) + else: + return run_multiprocesses( + ctx, + per_rank_fns, + per_rank_inputs, + num_repetitions=num_repetitions, + num_warmup=num_warmup, + ) diff --git a/dist_ir/executor/__init__.py b/dist_ir/executor/__init__.py index 65509f53..8fa99963 100644 --- a/dist_ir/executor/__init__.py +++ b/dist_ir/executor/__init__.py @@ -2,3 +2,4 @@ from .sequential_executor import SequentialExecutor from .type_inference import infer_types from .absint import AbstractInterpreter, AbstractState +from .rank_projector import project diff --git a/dist_ir/executor/rank_projector.py b/dist_ir/executor/rank_projector.py new file mode 100644 index 00000000..cfb2862c --- /dev/null +++ b/dist_ir/executor/rank_projector.py @@ -0,0 +1,222 @@ +from collections import defaultdict +from dist_ir.executor.type_inference import TypePropRegister +from typing import Any, Dict, Sequence, Set, Tuple + +from ..ir import Function, FunctionMaker, Device, Op, Value +from ..ir.type import Type, Tensor +from .absint import AbstractState, AbstractInterpreter + + +# TODO merge this with torch backend -- it breaks semantics to have P2P send/recv + + +class ProjectorState(AbstractState): + """The Abstract Interpreter state for projection. It keeps a mapping from + Devices to per-rank Functions, and a set of Device groups that perform + collective communication. + """ + + def __init__(self, function: Function, inputs: Sequence[Any]): + AbstractState.__init__(self, function, inputs) + self.per_rank_fns: Dict[Device, FunctionMaker] = defaultdict(FunctionMaker) + self.groups: Set[Tuple[Device]] = set() + + +def _get_input_devices(op: Op): + return list(set(x.type.device for x in op.inputs)) + + +def _make_group(devices): + """Return a hashable representation of a group of devices. This is needed by + the backend, which maps them to process groups for communication primitives. + """ + return tuple(sorted(set(devices))) + + +# TODO should projectors just get the per_rank_fns dict instead of full state? + + +def _identity_projector(op: Op, state: ProjectorState): + """Projects op unchanged to its device's per-rank program. + The inputs of op must all be on a single device. + """ + devices = _get_input_devices(op) + assert len(devices) == 1 and devices[0] is not None + + state.per_rank_fns[devices[0]].ops.append(op) + # state.per_rank_fns[d].add_op(op.op_type, name=op.name, inputs=op.inputs, ) + + +def _collective_projector(op: Op, state: ProjectorState): + """Projects a collective op over D devices that has D inputs and D outputs, + one on each device.""" + assert len(op.inputs) == len(op.outputs) + group = _make_group(v.type.device for v in op.inputs + op.outputs) + attributes = { + **(op.attributes if op.attributes is not None else {}), + "group": group, + } + for in_v, out_v in zip(op.inputs, op.outputs): + assert in_v.type.device == out_v.type.device + d = in_v.type.device + + new_op = Op( + op.op_type, + inputs=(in_v,), + output_values=(out_v,), + attributes=attributes, + ) + state.per_rank_fns[d].ops.append(new_op) + + +def _gather_projector(op: Op, state: ProjectorState): + devices = set(v.type.device for v in op.inputs) + assert len(op.inputs) == len(devices) + assert len(op.outputs) == 1 and op.outputs[0].type.device in devices + attributes = { + **(op.attributes if op.attributes is not None else {}), + "group": _make_group(devices), + } + for in_v in op.inputs: + d = in_v.type.device + new_op = Op( + op.op_type, + inputs=(in_v,), + output_values=op.outputs, # TODO only on dst device! + attributes=attributes, + ) + state.per_rank_fns[d].ops.append(new_op) + + +def _send_projector(op: Op, state: ProjectorState): + from_d = op.inputs[0].type.device + to_d = op.attributes["device"] + group = _make_group((from_d, to_d)) + state.per_rank_fns[from_d].ops.append( + Op( + "SendP2P", + inputs=op.inputs, + attributes={"to_d": to_d, "group": group}, + ) + ) + state.per_rank_fns[to_d].ops.append( + Op( + "RecvP2P", + output_values=(op.outputs[0],), + attributes={ + "shape": op.inputs[0].type.shape, + "from_d": from_d, + "group": group, + }, + ) + ) + + +ProjectorRegister = { + ("Add", (Tensor, Tensor)): _identity_projector, + ("Concat", (Tensor, Tensor)): _identity_projector, + ("Identity", (Tensor,)): _identity_projector, + ("Loss", (Tensor, Tensor)): _identity_projector, + ("LossGrad", (Tensor, Tensor)): _identity_projector, + ("MatMul", (Tensor, Tensor)): _identity_projector, + ("MatMulGrad", (Tensor, Tensor, Tensor)): _identity_projector, + ("MPIAllgather", (Tensor,) * 2): _collective_projector, + ("MPIAllgather", (Tensor,) * 4): _collective_projector, + ("MPIAllgather", (Tensor,) * 8): _collective_projector, + ("MPIAllgather", (Tensor,) * 16): _collective_projector, + ("MPIAllreduce", (Tensor,) * 2): _collective_projector, + ("MPIAllreduce", (Tensor,) * 4): _collective_projector, + ("MPIAllreduce", (Tensor,) * 8): _collective_projector, + ("MPIAllreduce", (Tensor,) * 16): _collective_projector, + ("MPIGather", (Tensor,) * 2): _gather_projector, + ("Relu", (Tensor,)): _identity_projector, + ("ReluGrad", (Tensor, Tensor)): _identity_projector, + ("Send", (Tensor,)): _send_projector, +} + + +def _create_semantics(type_prop_register, projector_register): + """Creates a semantics for AbstractInterpreter by combining a register of + projector functions and the type propagation register. + """ + + def convert_impl(type_prop_fn, projector): + def semantics(op: Op, state: AbstractState): + # Find the op's inputs in state's environment + inputs = tuple(state.env[v] for v in op.inputs) + # Run the type propagation function + outputs = type_prop_fn(op, *inputs) + + # Write outputs to state's environment + if not isinstance(outputs, tuple): + outputs = (outputs,) + for x, val in zip(op.outputs, outputs): + state.env[x] = val + + # Project op and add to appropriate per-rank function + projector(op, state) + + # If op involves more than one device, create a group + devices = [v.device for v in outputs] + [v.type.device for v in op.inputs] + group = _make_group(devices) + if len(group) > 1: + state.groups.add(group) + + return semantics + + signatures = set(projector_register.keys()).intersection(type_prop_register.keys()) + + return { + f: convert_impl(type_prop_register[f], projector_register[f]) + for f in signatures + } + + +Projector = AbstractInterpreter( + AbstractState=ProjectorState, + semantics=_create_semantics(TypePropRegister, ProjectorRegister), +) + + +def project( + fn: Function, input_types: Sequence[Type] +) -> Tuple[Dict[Device, Function], Set[Tuple[Device]]]: + """Project `fn` to per-rank functions. Returns a mapping from Devices to + per-rank Functions, and a set of Device groups that perform collective + communications in `fn`. + """ + state = ProjectorState(fn, input_types) + + # Project fn's inputs to each per-rank fn: + for v in fn.inputs: + state.per_rank_fns[v.type.device].inputs.append(v) + + state = Projector.interpret(fn, input_types, state=state) + + # Erase all types in per_rank_fns: + # TODO don't use singleton types, and remove this + result_fns = {} + for d, per_rank_fn in state.per_rank_fns.items(): + value_map = {} + new_fn = FunctionMaker(name=f"{fn.name}_{d.device_id-1}") + for v in per_rank_fn.inputs: + value_map[v] = new_fn.add_input_value(v.name, None) + for op in per_rank_fn.ops: + new_inputs = tuple(value_map[v] for v in op.inputs) + for v in op.outputs: + value_map[v] = Value(v.name, None) + new_outputs = tuple(value_map[v] for v in op.outputs) + new_fn.ops.append( + Op( + op.op_type, + name=op.name, + inputs=new_inputs, + attributes=op.attributes, + subfunctions=op.subfunctions, + output_values=new_outputs, + ) + ) + new_fn.set_outputs(tuple(value_map[v] for v in per_rank_fn.outputs)) + result_fns[d] = new_fn.finalize() + + return result_fns, state.groups diff --git a/dist_ir/executor/sequential_executor.py b/dist_ir/executor/sequential_executor.py index 9193daf5..5bd209a7 100644 --- a/dist_ir/executor/sequential_executor.py +++ b/dist_ir/executor/sequential_executor.py @@ -47,7 +47,7 @@ def compute(self, function: Function, inputs: Sequence[Any]) -> Dict[Value, Any] inputs: A sequence of input data represented in the specified backend. Returns: - A map from output value to output data. + A tuple of outputs. """ state = self.interpreter.interpret(function, inputs) return tuple(state.env[v] for v in function.outputs) diff --git a/dist_ir/executor/type_inference.py b/dist_ir/executor/type_inference.py index 3b0f29c9..c6393d20 100644 --- a/dist_ir/executor/type_inference.py +++ b/dist_ir/executor/type_inference.py @@ -576,7 +576,7 @@ def semantics(op: Op, state: AbstractState): def _type_function(function: Function, type_map: Dict[Value, Type]) -> Function: """Create a typed version of function, using the types given in type map.""" - new_function = FunctionMaker() + new_function = FunctionMaker(name=function.name) # A Map from function's values to new_function's (typed) values: value_map: Dict[Value, Value] = {} diff --git a/dist_ir/ir/device.py b/dist_ir/ir/device.py index 6aac12e1..b99df571 100644 --- a/dist_ir/ir/device.py +++ b/dist_ir/ir/device.py @@ -1,6 +1,8 @@ from dataclasses import dataclass from typing import ClassVar +from ..proto import device_pb2 + @dataclass(frozen=True) class Device: @@ -24,3 +26,11 @@ def get_new_device_variable(cls, device_type): device_id = f"d{cls.device_variable_id}" cls.device_variable_id += 1 return Device(device_id, device_type, is_variable=True) + + def serialize_to_proto(self): + device_proto = device_pb2.Device() + device_proto.device_id = self.device_id + device_proto.device_type = self.device_type + device_proto.throughput = self.throughput + device_proto.dram_bandwidth = self.dram_bandwidth + return device_proto diff --git a/dist_ir/ir/function.py b/dist_ir/ir/function.py index ab5bb89e..deec3635 100644 --- a/dist_ir/ir/function.py +++ b/dist_ir/ir/function.py @@ -8,6 +8,7 @@ from .op import Op from .value import Value +from ..proto import function_pb2 @dataclass(frozen=True) @@ -62,6 +63,10 @@ def __str__(self): # TODO can we use the prettyprint output as __str__? return self.get_summary() + def last_use(self, value): + """Returns the last op that uses the given value `value`.""" + return self.consumers[value][-1] + def get_summary(self): output = "" output += "Function inputs:\n" @@ -133,6 +138,20 @@ def get_subfunction( subfunction.set_outputs(outputs) return subfunction.finalize() + def serialize_to_proto(self): + function_proto = function_pb2.Function() + function_proto.name = self.name + for op in self.ops: + op_proto = op.serialize_to_proto() + function_proto.ops.append(op_proto) + for inp in self.inputs: + inp_proto = inp.serialize_to_proto() + function_proto.inputs.append(inp_proto) + for output in self.outputs: + output_proto = output.serialize_to_proto() + function_proto.outputs.append(output_proto) + return function_proto + @dataclass class FunctionMaker: diff --git a/dist_ir/ir/op.py b/dist_ir/ir/op.py index bd1cae25..cbbbcf58 100644 --- a/dist_ir/ir/op.py +++ b/dist_ir/ir/op.py @@ -6,6 +6,7 @@ from .op_register import OpRegister from .value import Value from .type import Type +from ..proto import op_pb2 @dataclass(frozen=True) @@ -23,6 +24,20 @@ class Op: output_values: InitVar[Tuple[Value]] = None def __post_init__(self, output_names, output_types, output_values): + # Check output_{names,types,values} have same length + given_num_outputs = set( + len(x) for x in [output_names, output_types, output_values] if x is not None + ) + if len(given_num_outputs) == 0: + given_num_outputs = None + elif len(given_num_outputs) == 1: + given_num_outputs = list(given_num_outputs)[0] + else: + raise ValueError( + "output_{names,types,values} must have same length. Got:\n" + f"{output_names}\n{output_types}\n{output_values}" + ) + if self.op_type == "Pmap": # Handle pmap specially assert len(self.subfunctions) == 1 @@ -46,22 +61,18 @@ def __post_init__(self, output_names, output_types, output_values): ) # Number of outputs is given by OpRegister if OpRegister[self.op_type].variadic_outputs: - if output_names is None: + if given_num_outputs is None: raise ValueError( f"Op {self.name} ({self.op_type}) has variadic " - f"outputs, so output names must be specified" + "outputs, so one of output_{names,values} must be specified" ) - num_outputs = len(output_names) + num_outputs = given_num_outputs else: num_outputs = OpRegister[self.op_type].num_outputs + assert given_num_outputs is None or num_outputs == given_num_outputs - if output_values is not None: - object.__setattr__( - self, "outputs", output_values - ) # Can't assign to frozen field - else: + if output_values is None: # Create the correct number of output values with appropriate types - # if self.outputs is None: if output_names is None: output_names = [f"{self.name}_out_{i}" for i in range(num_outputs)] elif len(output_names) != num_outputs: @@ -83,3 +94,16 @@ def __post_init__(self, output_names, output_types, output_values): object.__setattr__( self, "outputs", output_values ) # Can't assign to frozen field + + def serialize_to_proto(self): + op_proto = op_pb2.Op() + op_proto.op_type = self.op_type + op_proto.name = self.name + for inp in self.inputs: + inp_proto = inp.serialize_to_proto() + op_proto.inputs.append(inp_proto) + # TODO: Serialize attributes + for output in self.outputs: + output_proto = output.serialize_to_proto() + op_proto.outputs.append(output_proto) + return op_proto diff --git a/dist_ir/ir/op_register.py b/dist_ir/ir/op_register.py index 66096c58..99227f4f 100644 --- a/dist_ir/ir/op_register.py +++ b/dist_ir/ir/op_register.py @@ -57,6 +57,7 @@ class OpRegisterEntry: "MPIScatterToTupleType": OpRegisterEntry(num_inputs=1, num_outputs=1), "Mul": OpRegisterEntry(num_inputs=2, num_outputs=1), "Opt": OpRegisterEntry(num_inputs=2, num_outputs=1), + "RecvP2P": OpRegisterEntry(num_inputs=0, num_outputs=1), "ReduceAllL2": OpRegisterEntry(variadic_inputs=True, num_outputs=1), "ReduceSum": OpRegisterEntry(num_inputs=1, num_outputs=1), "ReduceSumTraining": OpRegisterEntry(num_inputs=2, num_outputs=1), @@ -64,7 +65,9 @@ class OpRegisterEntry: "ReluGrad": OpRegisterEntry(num_inputs=2, num_outputs=1), "Reshape": OpRegisterEntry(num_inputs=2, num_outputs=1), "Select": OpRegisterEntry(num_inputs=1, num_outputs=1), + # TODO call the combined one SendRecv? "Send": OpRegisterEntry(num_inputs=1, num_outputs=1), + "SendP2P": OpRegisterEntry(num_inputs=1, num_outputs=0), "SGDOptimizer": OpRegisterEntry(num_inputs=3, num_outputs=2), "Shape": OpRegisterEntry(num_inputs=1, num_outputs=1), # TODO allow optional inputs for things like slice diff --git a/dist_ir/ir/prettyprint.py b/dist_ir/ir/prettyprint.py index 94dd671e..007adfee 100644 --- a/dist_ir/ir/prettyprint.py +++ b/dist_ir/ir/prettyprint.py @@ -47,7 +47,7 @@ ) from prettyprinter.utils import intersperse -from .function import Function +from .function import Function, FunctionMaker from .value import Value from .type import Type, Int32, Int64, Float, Tensor, TupleType from .device import Device @@ -120,10 +120,32 @@ def _(function: Function, ctx): ) +@register_pretty(FunctionMaker) +def _(function: FunctionMaker, ctx): + ops = _pprint_function_body(function, ctx) + return concat( + [ + annotate(Token.KEYWORD_CONSTANT, "function* "), + pretty_call(ctx, pp_fnname(function.name), *function.inputs), + nest(ctx.indent, concat([COLON, HARDLINE, interline(*ops)])), + ] + ) + + @register_pretty(Op) def _(op: Op, ctx): - results = concat(_join(*(pretty_dispatch(r, ctx) for r in op.outputs))) - args = concat(_join(*(v.name for v in op.inputs))) + attributes = () + if op.attributes is not None: + attributes = ( + concat((key, ASSIGN_OP, pretty_dispatch(value, ctx))) + for key, value in op.attributes.items() + ) + args = concat(_join(*(v.name for v in op.inputs), *attributes)) + + if len(op.outputs) == 0: + results = "_" + else: + results = concat(_join(*(pretty_dispatch(r, ctx) for r in op.outputs))) if op.op_type == "Pmap": lambda_body = _pprint_function_body(op.subfunctions[0], ctx) diff --git a/dist_ir/ir/type.py b/dist_ir/ir/type.py index 67092510..410bd933 100644 --- a/dist_ir/ir/type.py +++ b/dist_ir/ir/type.py @@ -5,6 +5,7 @@ from .device import Device from .utils import singleton +from ..proto import type_pb2 @dataclass(frozen=True) @@ -37,6 +38,9 @@ def __repr__(self): def size(self): return 4 + def serialize_to_proto(self): + return type_pb2.Int32() + @singleton class Int64(Type): @@ -49,6 +53,9 @@ def __repr__(self): def size(self): return 8 + def serialize_to_proto(self): + return type_pb2.Int64() + @singleton class Float(Type): @@ -73,6 +80,9 @@ def __repr__(self): def size(self): return 1 + def serialize_to_proto(self): + return type_pb2.Bool() + @dataclass(frozen=True) class Tensor(Type): @@ -99,6 +109,34 @@ def __repr__(self): def size(self): return reduce(mul, self.shape) * self.dtype.size + def serialize_to_proto(self): + tensor_proto = type_pb2.Tensor() + if self.device is not None: + tensor_proto.device.CopyFrom(self.device.serialize_to_proto()) + else: + # TODO: Anything? + pass + if self.shape is not None: + for dim in self.shape: + tensor_proto.shape.append(dim) + if self.dtype is not None: + if isinstance(self.dtype, Int32): + tensor_proto.dtype = type_pb2.DataType.INT32 + elif isinstance(self.dtype, Int64): + tensor_proto.dtype = type_pb2.DataType.INT64 + elif isinstance(self.dtype, Bool): + tensor_proto.dtype = type_pb2.DataType.BOOL + elif isinstance(self.dtype, Float): + raise NotImplementedError( + "Float type will be deprecated in favor of Float16 and Float32" + ) + else: + raise ValueError(f"Unknown dtype {type(self.dtype)}") + else: + # TODO: Anything? + pass + return tensor_proto + @dataclass(frozen=True) class TupleType(Type): diff --git a/dist_ir/ir/value.py b/dist_ir/ir/value.py index a22db06c..d6708247 100644 --- a/dist_ir/ir/value.py +++ b/dist_ir/ir/value.py @@ -1,6 +1,7 @@ from dataclasses import dataclass -from .type import Type +from .type import Bool, Int32, Int64, Tensor, Type +from ..proto import value_pb2 @dataclass(frozen=True, eq=False) @@ -12,3 +13,20 @@ class Value: name: str type: Type + + def serialize_to_proto(self): + value_proto = value_pb2.Value() + value_proto.name = self.name + if self.type is not None: + type_proto = self.type.serialize_to_proto() + if isinstance(self.type, type(Int32())): + value_proto.i32.CopyFrom(type_proto) + elif isinstance(self.type, type(Int64())): + value_proto.i64.CopyFrom(type_proto) + elif isinstance(self.type, type(Bool())): + value_proto.bool.CopyFrom(type_proto) + elif isinstance(self.type, Tensor): + value_proto.tensor.CopyFrom(type_proto) + else: + raise ValueError(f"Unknown type {type(self.type)}") + return value_proto diff --git a/dist_ir/proto/Makefile b/dist_ir/proto/Makefile new file mode 100644 index 00000000..a133d6fe --- /dev/null +++ b/dist_ir/proto/Makefile @@ -0,0 +1,9 @@ +all: + protoc -I=. --python_out=. ./device.proto + protoc -I=. --python_out=. ./type.proto + protoc -I=. --python_out=. ./value.proto + protoc -I=. --python_out=. ./op.proto + protoc -I=. --python_out=. ./function.proto + +clean: + rm *_pb2.py diff --git a/dist_ir/proto/device.proto b/dist_ir/proto/device.proto new file mode 100644 index 00000000..f2efe64e --- /dev/null +++ b/dist_ir/proto/device.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package dist_ir; + +message Device { + string device_id = 1; + string device_type = 2; + float throughput = 3; + float dram_bandwidth = 4; +} diff --git a/dist_ir/proto/function.proto b/dist_ir/proto/function.proto new file mode 100644 index 00000000..cab9582e --- /dev/null +++ b/dist_ir/proto/function.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package dist_ir; + +import "op.proto"; +import "value.proto"; + +message Function { + string name = 1; + repeated Op ops = 2; + repeated Value inputs = 3; + repeated Value outputs = 4; + // TODO: Add consumers? +} diff --git a/dist_ir/proto/op.proto b/dist_ir/proto/op.proto new file mode 100644 index 00000000..fe8a3cd6 --- /dev/null +++ b/dist_ir/proto/op.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package dist_ir; + +import "google/protobuf/any.proto"; +import "value.proto"; + +message Op { + string op_type = 1; + string name = 2; + repeated Value inputs = 3; + map attributes = 4; + repeated Value outputs = 5; + // TODO: Add subfunctions? +} diff --git a/dist_ir/proto/type.proto b/dist_ir/proto/type.proto new file mode 100644 index 00000000..8cadbcad --- /dev/null +++ b/dist_ir/proto/type.proto @@ -0,0 +1,39 @@ +syntax = "proto3"; + +package dist_ir; + +import "device.proto"; + +enum DataType { + INT32 = 0; + INT64 = 1; + FLOAT16 = 2; + FLOAT32 = 3; + BOOL = 4; +} + +message Int32 { + Device device = 1; +} + +message Int64 { + Device device = 1; +} + +message Float16 { + Device device = 1; +} + +message Float32 { + Device device = 1; +} + +message Bool { + Device device = 1; +} + +message Tensor { + Device device = 1; + DataType dtype = 2; + repeated uint32 shape = 3; +} diff --git a/dist_ir/proto/value.proto b/dist_ir/proto/value.proto new file mode 100644 index 00000000..fe5724e0 --- /dev/null +++ b/dist_ir/proto/value.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package dist_ir; + +import "type.proto"; + +message Value { + string name = 1; + oneof type { + Int32 i32 = 2; + Int64 i64 = 3; + Float16 f16 = 4; + Float32 f32 = 5; + Bool bool = 6; + Tensor tensor = 7; + } +} diff --git a/dist_ir/transforms/mlp_dhp_transform.py b/dist_ir/transforms/mlp_dhp_transform.py index 85a82b68..ffe8df53 100644 --- a/dist_ir/transforms/mlp_dhp_transform.py +++ b/dist_ir/transforms/mlp_dhp_transform.py @@ -309,7 +309,8 @@ def mlp_dhp_transform( function, dp_degree, hp_degree, pp_degree, devices, num_microbatches ): """Automatically distributes an MLP function using D/H/P hybrid parallelism.""" - transformed_function = FunctionMaker(name=function.name) + fn_name = f"{function.name}_{dp_degree}_{hp_degree}_{pp_degree}_{num_microbatches}" + transformed_function = FunctionMaker(name=fn_name) device_tree = _get_device_tree(dp_degree, hp_degree, pp_degree, devices) device_tree_root = tuple(device_tree.keys())[0] dp_devices = tuple(sorted(device_tree[device_tree_root].keys())) @@ -324,22 +325,28 @@ def mlp_dhp_transform( ) ) - # Add inputs to the transformed function. + # An init function that moves weights/inputs to correct devices + init_function = FunctionMaker(name=fn_name + "_init") transformed_inputs = {} for inp in function.inputs: - v = transformed_function.add_input_value(inp.name, inp.type) + v = init_function.add_input_value(inp.name, inp.type) transformed_inputs[inp] = v # Partition inputs across each parallelism dimension. - dp_inputs = _partition_inputs_dp(transformed_function, device_tree) - hp_inputs = _partition_inputs_hp(transformed_function, device_tree, dp_inputs) + dp_inputs = _partition_inputs_dp(init_function, device_tree) + hp_inputs = _partition_inputs_hp(init_function, device_tree, dp_inputs) pp_inputs = _partition_inputs_pp( - transformed_function, + init_function, device_tree, dp_inputs, hp_inputs, num_microbatches, ) + init_function = init_function.finalize() + + # Inputs of transformed_function are outputs of init_function + for v in init_function.outputs: + transformed_function.inputs.append(v) dp_outputs = defaultdict(list) for i, dp_device in enumerate(device_tree[device_tree_root]): @@ -670,7 +677,9 @@ def mlp_dhp_transform( hp_group, transformed_function, output_names=[ - f"{output.name}_dp_all_hp_{hp_device_group_str}_pp_all" + # TODO how to get device? + f"{output.name}_dp_all_hp_{hp_device_group_str}_pp_all_{j}" + for j in range(len(hp_group)) ], ) else: @@ -679,6 +688,10 @@ def mlp_dhp_transform( hp_group, transformed_function, dim=0, - output_names=[f"{output.name}_dp_all_hp_all_pp_all"], + output_names=[ + f"{output.name}_dp_all_hp_all_pp_all_{j}" + for j in range(len(hp_group)) + ], ) - return transformed_function.finalize() + # TODO transformed_function should output loss/grads? + return init_function, transformed_function.finalize() diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/grid_search.py b/examples/grid_search.py index 276df088..ec82f231 100644 --- a/examples/grid_search.py +++ b/examples/grid_search.py @@ -1,24 +1,13 @@ -import argparse -from collections import defaultdict, OrderedDict import csv -import logging +from itertools import product import numpy as np -import time -import matplotlib as mpl -import matplotlib.pyplot as plt from multiprocessing import Pool -import dist_ir -from dist_ir.importer import import_from_onnx, parse_tensor_from_file -from dist_ir.ir import FunctionMaker, cpprint, pformat, Device, Topology, Value -from dist_ir.executor import infer_types, SequentialExecutor, Simulator +from dist_ir.ir import Topology +from dist_ir.executor import infer_types, Simulator from dist_ir.executor.cost_model import CostModel -from dist_ir.ir.type import Bool, Float, Int64, Tensor -from dist_ir.transforms import ( - parallel_transform_3d, - PipeDreamScheduler, -) -from mlp import mlp +from dist_ir.transforms import mlp_dhp_transform +from .mlp import mlp DGX_BANDWIDTH_GBPS = 200 @@ -30,7 +19,6 @@ def add_devices_to_topology(topology, num_devices): for i in range(0, len(devices)): for j in range(i + 1, len(devices)): topology.set_bandwidth(devices[i], devices[j], DGX_BANDWIDTH_GBPS) - return topology def get_all_degrees(n): @@ -78,7 +66,7 @@ def run_experiment(config): world_size = dp_degree * hp_degree * pp_degree add_devices_to_topology(topology, world_size) - transformed_function = parallel_transform_3d( + transformed_function = mlp_dhp_transform( function, dp_degree, hp_degree, @@ -101,43 +89,62 @@ def run_experiment(config): return throughput -def grid_search(): - input_dim = 8192 - hidden_dim = input_dim - output_dim = input_dim - all_cluster_sizes = [1, 2, 4, 8, 16, 32] - all_num_hidden_layers = [64] - all_batch_sizes = [8192] - configs = [] - for num_hidden_layers in all_num_hidden_layers: - for batch_size in all_batch_sizes: - for i, cluster_size in enumerate(all_cluster_sizes): - all_degrees = get_all_degrees(cluster_size) - for (dp_degree, hp_degree, pp_degree) in all_degrees: - if num_hidden_layers % pp_degree != 0: - continue - dp_batch_size = batch_size // dp_degree - if pp_degree == 1: - all_num_microbatches = [1] - else: - all_num_microbatches = [ - int(2 ** k) - for k in range(1, int(np.floor(np.log2(dp_batch_size) / 2))) - ] - for num_microbatches in all_num_microbatches: - if pp_degree == 1: - num_microbatches == 1 - configs.append( - ( - batch_size, - input_dim, - num_hidden_layers, - dp_degree, - hp_degree, - pp_degree, - num_microbatches, - ) - ) +def mlp_dist( + mlp_fn, + dp_degree, + hp_degree, + pp_degree, + num_microbatches, + topology, +): + init_function, transformed_function = mlp_dhp_transform( + mlp_fn, + dp_degree, + hp_degree, + pp_degree, + topology.devices, + num_microbatches, + ) + init_function = infer_types(init_function, init_function.inputs) + # init_function.outputs = transformed_function.inputs, so get types from there: + transformed_function = infer_types(transformed_function, init_function.outputs) + return init_function, transformed_function + + +def gen_configurations(hidden_dims, cluster_sizes, all_num_layers, all_batch_sizes): + for hidden_dim, num_hidden_layers, batch_size, cluster_size in product( + hidden_dims, all_num_layers, all_batch_sizes, cluster_sizes + ): + all_degrees = get_all_degrees(cluster_size) + for (dp_degree, hp_degree, pp_degree) in all_degrees: + if num_hidden_layers % pp_degree != 0: + continue + dp_batch_size = batch_size // dp_degree + if pp_degree == 1: + all_num_microbatches = [1] + else: + all_num_microbatches = [ + int(2 ** k) + for k in range(1, int(np.floor(np.log2(dp_batch_size) / 2))) + ] + for num_microbatches in all_num_microbatches: + if pp_degree == 1: + num_microbatches == 1 + yield ( + batch_size, + hidden_dim, + num_hidden_layers, + dp_degree, + hp_degree, + pp_degree, + num_microbatches, + ) + + +def grid_search(hidden_dims, cluster_sizes, all_num_layers, all_batch_sizes): + configs = list( + gen_configurations(hidden_dims, cluster_sizes, all_num_layers, all_batch_sizes) + ) with Pool() as p: results = p.map(run_experiment, configs) @@ -174,4 +181,10 @@ def grid_search(): if __name__ == "__main__": - grid_search() + # grid_search( + # hidden_dims=[8192], + # cluster_sizes=[1, 2, 4, 8, 16, 32], + # all_num_layers=[64], + # all_batch_sizes=[8192], + # ) + pass diff --git a/examples/mlp.py b/examples/mlp.py index 8962bd5b..7e79643f 100644 --- a/examples/mlp.py +++ b/examples/mlp.py @@ -54,3 +54,72 @@ def mlp(batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, device output_names=[f"dy{i}", f"dw{chr(ord('A')+i)}"], ) return function.finalize() + + +def mlp_inference( + batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, device +): + function = FunctionMaker(name="mlp") + weights = [] + for i in range(num_hidden_layers - 1): + w = function.add_input_value( + f"w{chr(ord('A')+i)}", + Tensor(dtype=Float(), shape=(input_dim, hidden_dim), device=device), + ) + weights.append(w) + w = function.add_input_value( + f"w{chr(ord('A')+i+1)}", + Tensor(dtype=Float(), shape=(hidden_dim, output_dim), device=device), + ) + weights.append(w) + x = function.add_input_value( + "x", + Tensor(dtype=Float(), shape=(batch_size, input_dim), device=device), + ) + + a = x + for i, weight in enumerate(weights): + y = function.add_op("MatMul", inputs=[a, weight], output_names=[f"y{i}"]) + a = function.add_op("Relu", inputs=[y], output_names=[f"a{i}"]) + + return function.finalize() + + +def mlp_inference_dp( + batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, devices +): + num_devices = len(devices) + assert batch_size % num_devices == 0 + function = FunctionMaker(name="mlp") + weights = {} + x = {} + for d in devices: + for i in range(num_hidden_layers - 1): + weights[i, d] = function.add_input_value( + f"w{chr(ord('A')+i)}_{d.device_id}", + Tensor(dtype=Float(), shape=(input_dim, hidden_dim), device=d), + ) + weights[num_hidden_layers - 1, d] = function.add_input_value( + f"w{chr(ord('A')+i+1)}_{d.device_id}", + Tensor(dtype=Float(), shape=(hidden_dim, output_dim), device=d), + ) + x[d] = function.add_input_value( + f"x_{d.device_id}", + Tensor( + dtype=Float(), shape=(batch_size // num_devices, input_dim), device=d + ), + ) + + a = x + for i in range(num_hidden_layers): + for d in devices: + y = function.add_op( + "MatMul", + inputs=[a[d], weights[i, d]], + output_names=[f"y{i}_{d.device_id}"], + ) + a[d] = function.add_op( + "Relu", inputs=[y], output_names=[f"a{i}_{d.device_id}"] + ) + + return function.finalize() diff --git a/requirements.txt b/requirements.txt index fccdd866..be6b197a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ frozendict >= 1.2 numpy >= 1.19 onnx >= 1.7.0 -torch >= 1.6.0 +torch >= 1.8.0 prettyprinter >= 0.18.0 diff --git a/test/test_mlp_dhp_transform.py b/test/test_mlp_dhp_transform.py index a3de44a5..52149a89 100644 --- a/test/test_mlp_dhp_transform.py +++ b/test/test_mlp_dhp_transform.py @@ -88,11 +88,12 @@ def add_devices_to_topology(topology, num_devices): def _verify_no_hp(outputs, transformed_outputs, dp=False): - for output, transformed_output in zip(outputs, transformed_outputs): - if dp: - np.testing.assert_array_almost_equal(output, transformed_output[0]) + for i in range(len(outputs)): + if not dp: + j = i else: - np.testing.assert_array_almost_equal(output, transformed_output) + j = 2 * i + np.testing.assert_array_almost_equal(outputs[i], transformed_outputs[j]) def _verify_hp(function, transformed_function, outputs, transformed_outputs, dp=False): @@ -105,10 +106,7 @@ def _verify_hp(function, transformed_function, outputs, transformed_outputs, dp= match = re.search(f"(.*)_dp_(.*)_hp_(.*)_pp_(.*){device_suffix}", output.name) assert match is not None key = (match.group(1), match.group(2), match.group(4)) - if dp: - aggregated_outputs[key].append(v[0]) - else: - aggregated_outputs[key].append(v) + aggregated_outputs[key].append(v) for key in aggregated_outputs: output_name = key[0] if "dw" in output_name: @@ -138,7 +136,7 @@ def _test_helper( world_size = dp_degree * hp_degree * pp_degree add_devices_to_topology(topology, world_size) - transformed_function = mlp_dhp_transform( + init_function, transformed_function = mlp_dhp_transform( function, dp_degree, hp_degree, @@ -146,14 +144,15 @@ def _test_helper( topology.devices, num_microbatches, ) - transformed_function = infer_types( - transformed_function, transformed_function.inputs - ) + init_function = infer_types(init_function, init_function.inputs) + # init_function.outputs = transformed_function.inputs, so get types from there: + transformed_function = infer_types(transformed_function, init_function.outputs) input_data = [np.random.normal(size=inp.type.shape) for inp in function.inputs] ex = SequentialExecutor("numpy") outputs = ex.compute(function, input_data) - transformed_outputs = ex.compute(transformed_function, input_data) + dist_input_data = ex.compute(init_function, input_data) + transformed_outputs = ex.compute(transformed_function, dist_input_data) if hp_degree > 1: _verify_hp( @@ -189,3 +188,7 @@ def test_hp_pp(): def test_dp_hp_pp(): _test_helper(dp_degree=2, hp_degree=2, pp_degree=2, num_microbatches=2) + + +if __name__ == "__main__": + test_dp_only() diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py new file mode 100644 index 00000000..5d456a2a --- /dev/null +++ b/test/test_pytorch_backend.py @@ -0,0 +1,316 @@ +from collections import defaultdict +import csv +import numpy as np +import pytest +import torch + +from dist_ir.backend.torch import run_pytorch +from dist_ir.executor import SequentialExecutor +from dist_ir.executor.cost_model import CostModel +from dist_ir.executor.simulator import Simulator +from dist_ir.executor.type_inference import infer_types +from dist_ir.ir import Device, FunctionMaker, cpprint, Value +from dist_ir.ir.type import Float, Tensor +from dist_ir.ir.topology import Topology + +# TODO make examples submodule of dist_ir? +from examples.grid_search import add_devices_to_topology, gen_configurations, mlp_dist +from examples.mlp import mlp, mlp_inference_dp + + +def create_owt_model(num_devices, num_layers): + assert num_layers % 2 == 0 + + fn = FunctionMaker() + + # Inputs + weights = {} + xs = {} + for l in range(num_layers): + for d in range(1, num_devices + 1): + weights[l, d] = fn.add_input_value(f"w{l}_{d}", None) + for d in range(1, num_devices + 1): + xs[d] = fn.add_input_value(f"x_{d}", None) + + # Data parallel conv blocks: (using MatMuls for now) + hs = [] + for d in range(1, num_devices + 1): + h = xs[d] + for l in range(num_layers // 2): + h = fn.add_op( + "MatMul", inputs=[h, weights[l, d]], output_names=[f"h{l}_{d}"] + ) + hs.append(h) + + # Allgather the activations + as_names = [f"hh{num_layers//2-1}_{d}" for d in range(1, num_devices + 1)] + hs = fn.add_op( + "MPIAllgather", + inputs=hs, + output_names=as_names, + attributes={"dim": 0}, + ) + + # Model parallel fully-connected layers: (again, MatMuls for now) + hs = hs + for l in range(num_layers // 2, num_layers): + h_is = [] + for d in range(1, num_devices + 1): + h_is.append( + fn.add_op( + "MatMul", + inputs=[hs[d - 1], weights[l, d]], + output_names=[f"h{l}_{d}"], + ) + ) + if l == num_layers - 1: + hs = h_is + else: + out_names = [f"hh{l}_{d}" for d in range(1, num_devices + 1)] + hs = fn.add_op( + "MPIAllgather", + inputs=h_is, + output_names=out_names, + attributes={"dim": 1}, + ) + + fn.set_outputs(hs) + return fn.finalize() + + +@pytest.mark.parametrize(["num_devices", "num_layers"], [(2, 4)]) +def test_owt(num_devices, num_layers): + fn = create_owt_model(num_devices, num_layers) + + devices = [Device(0, "cpu")] + for d in range(1, num_devices + 1): + devices.append(Device(d, "gpu")) + + batch_size = 8 + hidden_dim = 4 # using this for input/output dim also + + input_vals = [] + for l in range(num_layers): + for d in range(1, num_devices + 1): + if l < num_layers // 2: + shape = (hidden_dim, hidden_dim) + else: + shape = (hidden_dim, hidden_dim // num_devices) + # w{l}_{d}: + input_vals.append(Value("", Tensor(Float(), shape, devices[d]))) + for d in range(1, num_devices + 1): + # x_{d}: + shape = (batch_size // num_devices, hidden_dim) + input_vals.append(Value("", Tensor(Float(), shape, devices[d]))) + + # Test type inference: + fn = infer_types(fn, input_vals) + cpprint(fn) + assert all( + v.type.shape == (batch_size, hidden_dim // num_devices) for v in fn.outputs + ) + + # Test with sequential executor: + np.random.seed(0) + weights = [np.random.randn(hidden_dim, hidden_dim) for l in range(num_layers)] + x = np.random.randn(batch_size, hidden_dim) + + # Split inputs for distributed function + input_arrays = [] + for l in range(num_layers): + if l < num_layers // 2: + for d in range(1, num_devices + 1): + input_arrays.append(weights[l]) + else: + input_arrays += np.split(weights[l], num_devices, axis=1) + input_arrays += np.split(x, num_devices) + ex = SequentialExecutor("numpy") + output_arrays = ex.compute(fn, input_arrays) + + # Expected results + y = x + for l in range(num_layers): + y = np.matmul(y, weights[l]) + ys = np.split(y, num_devices, axis=1) + assert all(np.allclose(y, o) for y, o in zip(ys, output_arrays)) + + # Run per-rank modules using PyTorch backend: + per_rank_outputs, _ = run_pytorch(fn, [torch.tensor(a) for a in input_arrays]) + + # Check outputs: + assert all(np.allclose(y[0], o) for y, o in zip(per_rank_outputs, output_arrays)) + + +def test_dp_mp_matmuls(): + fn = FunctionMaker("dp_mp_matmuls") + B = 64 + d0 = Device(0, "gpu") + d1 = Device(1, "gpu") + x_0 = fn.add_input_value("x_0", Tensor(Float(), (B // 2, B), d0)) + x_1 = fn.add_input_value("x_1", Tensor(Float(), (B // 2, B), d1)) + wA_0 = fn.add_input_value("wA_0", Tensor(Float(), (B, B), d0)) + wA_1 = fn.add_input_value("wA_1", Tensor(Float(), (B, B), d1)) + wB_0 = fn.add_input_value("wB_0", Tensor(Float(), (B, B), d0)) + wC_1 = fn.add_input_value("wC_1", Tensor(Float(), (B, B), d1)) + a0_0 = fn.add_op("MatMul", inputs=[x_0, wA_0], output_names=["a0"]) + a1_1 = fn.add_op("MatMul", inputs=[x_1, wA_1], output_names=["a1"]) + a_0 = fn.add_op( + "MPIGather", + inputs=[a0_0, a1_1], + output_names=["a_0"], + attributes={"device": d0, "dim": 0}, + ) + b_0 = fn.add_op("MatMul", inputs=[a_0, wB_0], output_names=["b_0"]) + b_1 = fn.add_op( + "Send", inputs=[b_0], output_names=["b_1"], attributes={"device": d1} + ) + c_1 = fn.add_op("MatMul", inputs=[b_1, wC_1], output_names=["c_1"]) + fn = fn.finalize() + fn = infer_types(fn, fn.inputs) + cpprint(fn) + + from dist_ir.executor.rank_projector import project + + per_rank_fns, groups = project(fn, tuple(v.type for v in fn.inputs)) + for per_rank_fn in per_rank_fns.values(): + cpprint(per_rank_fn) + + +def test_mlp_grid_search(): + # batch_sizes = [2 ** i for i in range(10, 15)] + # hidden_dims = [2 ** i for i in range(8, 13)] + batch_sizes = [64] + hidden_dims = [64] + world_sizes = [1, 2, 4, 8] + all_num_layers = [32] + + results = [] + for (batch_size, hidden_dim, num_layers, d, h, p, m) in gen_configurations( + hidden_dims, world_sizes, all_num_layers, batch_sizes + ): + world_size = d * h * p + # TODO reuse seq_mlp + topology = Topology() + d0 = topology.add_device("gpu") + add_devices_to_topology(topology, world_size) + simulator = Simulator(CostModel(topology)) + seq_executor = SequentialExecutor("numpy") + seq_mlp = mlp(batch_size, hidden_dim, hidden_dim, hidden_dim, num_layers, d0) + seq_mlp = infer_types(seq_mlp, seq_mlp.inputs) + + # Create random input data + input_data = tuple( + np.random.randn(*v.type.shape).astype(np.float32) for v in seq_mlp.inputs + ) + + init_fn, fn = mlp_dist(seq_mlp, d, h, p, m, topology) + print(fn.name) + + # Simulate + simulation = simulator.interpret(fn, (v.type for v in fn.inputs)) + simulated_time = max([simulation.timestamps[d] for d in simulation.timestamps]) + print(simulated_time) + + # Reference-execute init_fn to get inputs for fn + dist_input_data = seq_executor.compute(init_fn, input_data) + dist_input_data = tuple(torch.tensor(t) for t in dist_input_data) + assert all( + t.shape == v.type.shape for (t, v) in zip(dist_input_data, fn.inputs) + ) + + # Measure actual execution time + # TODO check outputs match? + # _, runtimes = run_pytorch(world_size, fn, dist_input_data) + _, runtimes = run_pytorch( + fn, + dist_input_data, + use_gpu=False, + num_repetitions=1, # TODO use 100 + num_warmup=1, + ) + # TODO or median of max? + actual_time = max(np.median(times) for times in runtimes) + + print(fn.name, simulated_time, actual_time) + + +def test_single_device(): + d1 = Device(1, "gpu") + fn = FunctionMaker() + x = fn.add_input_value("x", Tensor(Float(), (4, 4), d1)) + y = fn.add_op("MatMul", inputs=(x, x)) + fn.set_outputs((y,)) + fn = fn.finalize() + cpprint(fn) + + x = torch.randn(4, 4) + inputs = (x,) + outputs, _ = run_pytorch(fn, inputs) + print(outputs) + assert torch.allclose(torch.matmul(x, x), outputs[0][0]) + + +def test_send_recv(): + d1 = Device(1, "gpu") + d2 = Device(2, "gpu") + fn = FunctionMaker() + x = fn.add_input_value("x", Tensor(Float(), (4, 4), d1)) + y = fn.add_op("Send", inputs=(x,), attributes={"device": d2}) + fn.set_outputs((x, y)) + fn = fn.finalize() + cpprint(fn) + + x = torch.randn(4, 4) + inputs = (x,) + outputs, _ = run_pytorch(fn, inputs) + assert torch.allclose(x, outputs[1][0]) + + +def test_dp_mlp(): + num_devices = 2 + num_layers = 4 + batch_size = 4 + hidden_dim = 6 # Also input/output dim for simplicity + devices = [Device(d, "gpu") for d in range(num_devices + 1)] + + fn = mlp_inference_dp( + batch_size, hidden_dim, hidden_dim, hidden_dim, num_layers, devices[1:] + ) + fn = infer_types(fn, fn.inputs) + cpprint(fn) + + def convert_inputs_dp(weights, x): + xs = torch.split(x, num_devices) + + def new_inputs(): + for d in range(num_devices): + yield from weights + yield xs[d] + + return list(new_inputs()) + + # Make random input/expected data: + weights = [torch.randn(hidden_dim, hidden_dim) for _ in range(num_layers)] + x = torch.randn(batch_size, hidden_dim) + y = x + for l in range(num_layers): + y = torch.matmul(y, weights[l]) + y = torch.relu(y) + + # Project and run on backend: + per_rank_outputs, runtimes = run_pytorch(fn, convert_inputs_dp(weights, x)) + + # Check outputs: + assert torch.allclose(y, torch.cat([o[0] for o in per_rank_outputs], 0)) + + return runtimes + + +if __name__ == "__main__": + # test_owt(2, 4) + # test_dp_mlp() + # test_send_recv() + # test_single_device() + test_dp_mp_matmuls() + + test_mlp_grid_search()