diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py
index ccba2ab7..5a4665b4 100644
--- a/dist_ir/backend/torch.py
+++ b/dist_ir/backend/torch.py
@@ -1,3 +1,5 @@
+import itertools
+import json
from functools import partial
import numpy as np
from operator import getitem
@@ -6,6 +8,7 @@
from time import perf_counter
from traceback import print_exc
from typing import Any, Dict, Iterable, List, NamedTuple, Sequence, Tuple
+import time
import torch
import torch.distributed as dist
@@ -30,10 +33,14 @@
groups=Dict[Tuple[int], Any],
# Temp store of group IDs until threads can create ProcessGroups
groups_list=Iterable[Tuple[int]],
+ # Group encompassing all devices
+ global_group=Tuple[int],
# Debug flag
debug_stacktrace=bool,
# Profile flag
profile=bool,
+ # List of op execution events
+ trace=list,
)
@@ -73,7 +80,7 @@ def _cast(x, to, ctx=None):
raise NotImplementedError()
-def _concat2(*args, axis=None, ctx=None):
+def _concat(*args, axis=None, ctx=None):
return torch.cat(args, dim=axis)
@@ -162,14 +169,25 @@ def _reshape(x, y, ctx=None):
def _recv(shape=None, from_d=None, group=None, dtype=None, ctx=None):
- if isinstance(dtype, Int32):
- x = torch.zeros(shape).int()
- elif isinstance(dtype, Int64):
- x = torch.zeros(shape).long()
- elif isinstance(dtype, Float32):
- x = torch.zeros(shape).float()
+ # torch.distributed.barrier(group=ctx.groups[group])
+    if len(shape) == 0:
+        if isinstance(dtype, Int32):
+            x = torch.tensor(0).int()
+        elif isinstance(dtype, Int64):
+            x = torch.tensor(0).long()
+        elif isinstance(dtype, Float32):
+            x = torch.tensor(0).float()
+        else:
+            raise NotImplementedError(dtype)
     else:
-        raise NotImplementedError(dtype)
+        if isinstance(dtype, Int32):
+            x = torch.zeros(shape).int()
+        elif isinstance(dtype, Int64):
+            x = torch.zeros(shape).long()
+        elif isinstance(dtype, Float32):
+            x = torch.zeros(shape).float()
+        else:
+            raise NotImplementedError(dtype)
src_rank = ctx.device_to_rank[from_d]
if ctx.use_gpu:
@@ -191,6 +209,7 @@ def _relu_grad(x, dy, ctx=None):
def _send(x, to_d=None, group=None, ctx=None):
+ # torch.distributed.barrier(group=ctx.groups[group])
if ctx.use_gpu:
src_rank = dist.get_rank()
dist.broadcast(x, src_rank, group=ctx.groups[group])
@@ -277,7 +296,7 @@ def _unsqueeze(x, axes, ctx=None):
"Add": torch.add,
"Cast": _cast,
"Add": _add,
- "Concat": _concat2,
+ "Concat": _concat,
"Constant": _constant,
"ConstantOfShape": _constant_of_shape,
"Div": _div,
@@ -379,16 +398,26 @@ def function_to_module(fn: Function) -> torch.nn.Module:
return fx.GraphModule({}, g)
+def add_event(ctx, events):
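+    # Record a timing marker: a CUDA event (for accurate GPU timing) when running
+    # on GPU, otherwise a perf_counter() timestamp.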
+ if ctx.use_gpu:
+ events.append(torch.cuda.Event(enable_timing=True))
+ events[-1].record()
+ else:
+ events.append(perf_counter())
+
+
def run_function(
ctx: DistributedContext,
fn: Function,
inputs: List[Any],
rank: int,
debug_mock=False,
+ op_runtimes_ts: float = None,
):
"""Runs DistIR Function `fn` on `inputs` in a distributed context `ctx` by
converting each DistIR op to its torch implementation as given in _op_to_torch.
"""
+ record_op_runtimes = op_runtimes_ts is not None
op_to_torch = _mock_op_to_torch if debug_mock else _op_to_torch
value_map = {}
@@ -403,13 +432,39 @@ def print_memory_usage():
a = torch.cuda.memory_allocated(0)
print(f"Total: {t} Reserved: {r} Allocated: {a} Free: {r-a}")
+ if record_op_runtimes:
+ op_runtimes = []
+
# Run ops
for op in fn.ops:
inputs = tuple(value_map[v] for v in op.inputs)
kwargs = {} if op.attributes is None else {**op.attributes}
kwargs["ctx"] = ctx
+ # TODO: Consider adding this to mitigate network contention:
+ # if "MPI" in op.op_type or op.op_type == "Send":
+ # torch.cuda.synchronize()
+
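+        # Wall-clock timing per op; the GPU is synchronized after the op so the
+        # measured interval covers kernel execution, not just the launch.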
+ if record_op_runtimes:
+ start = time.time()
output = op_to_torch[op.op_type](*inputs, **kwargs)
+ if record_op_runtimes:
+ if ctx.use_gpu:
+ torch.cuda.synchronize(device=rank)
+ end = time.time()
+ if op.op_type == "SendP2P":
+ x = inputs[0]
+ src_rank = dist.get_rank()
+ dst_rank = ctx.device_to_rank[kwargs["to_d"]]
+ group = ctx.groups[kwargs["group"]]
+ latency = end - start
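+                # Throughput is reported in Gbit/s and assumes a 2D float32 tensor
+                # (4 bytes per element).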
+ print(
+ f"Sending tensor of size {x.size()} on device {x.device} with dtype "
+ f"{x.dtype} from device {src_rank} to {dst_rank}: latency={latency}, "
+ f"throughput={x.shape[0] * x.shape[1] * 4 / 1.25e8 / latency}"
+ )
+
+ op_runtimes.append(end - start)
if len(op.outputs) > 1:
assert isinstance(output, tuple)
@@ -423,6 +478,24 @@ def print_memory_usage():
if v in value_map and fn.last_use(v) == op and not (v in fn.outputs):
del value_map[v]
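+    # Convert per-op runtimes into Chrome-trace "complete" events ("ph": "X"),
+    # laid end to end starting at the caller-provided timestamp offset (microseconds).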
+ if record_op_runtimes:
+ trace = []
+ ts = op_runtimes_ts
+ assert len(fn.ops) == len(op_runtimes)
+ for op, runtime in zip(fn.ops, op_runtimes):
+ trace.append(
+ {
+ "name": op.op_type,
+ "ph": "X",
+ "ts": ts,
+ "dur": runtime * 1e6,
+ "pid": 0,
+ "tid": rank + 1,
+ }
+ )
+ ts += runtime * 1e6
+ ctx.trace[rank] += trace
+
# Return outputs
return tuple(value_map[v] for v in fn.outputs)
@@ -443,6 +516,8 @@ def run_process(ctx, num_warmup_steps, num_repetitions, rank, fn, inputs):
ranks = [ctx.device_to_rank[d] for d in group]
# ctx is a curried arg, hence is thread-local and can be modified:
ctx.groups[group] = dist.new_group(ranks)
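+    # Process group spanning every rank, used for global barriers between iterations.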
+ global_group_ranks = sorted([ctx.device_to_rank[d] for d in ctx.global_group])
+ global_group = dist.new_group(global_group_ranks)
if ctx.use_gpu:
# Move inputs to GPU
@@ -450,13 +525,6 @@ def run_process(ctx, num_warmup_steps, num_repetitions, rank, fn, inputs):
events = []
- def add_event():
- if ctx.use_gpu:
- events.append(torch.cuda.Event(enable_timing=True))
- events[-1].record()
- else:
- events.append(perf_counter())
-
if ctx.profile:
num_wait_steps = 0
else:
@@ -466,7 +534,7 @@ def add_event():
try:
outputs = run_function(ctx, fn, inputs, rank)
if ctx.world_size > 1:
- torch.distributed.barrier()
+ torch.distributed.barrier(group=global_group)
except Exception as e:
print_exc()
print(f"{rank}: PyTorch backend exiting after 1 run in debug mode.")
@@ -482,18 +550,32 @@ def add_event():
schedule=torch.profiler.schedule(
wait=num_wait_steps, warmup=num_warmup_steps, active=num_repetitions
),
- on_trace_ready=torch.profiler.tensorboard_trace_handler(
- f"{fn.name}_{rank}_profile"
- ),
+ on_trace_ready=torch.profiler.tensorboard_trace_handler(f"{fn.name}_profile"),
) as p:
+ op_runtimes_ts = None
for i in range(num_warmup_steps + num_repetitions):
- add_event()
+ record_op_runtimes = ctx.profile and i >= num_warmup_steps
+ if record_op_runtimes and op_runtimes_ts is None:
+ op_runtimes_ts = 0.0
+ add_event(ctx, events)
# TODO: Handle failures here?
- outputs = run_function(ctx, fn, inputs, rank)
+ outputs = run_function(
+ ctx,
+ fn,
+ inputs,
+ rank,
+ op_runtimes_ts=op_runtimes_ts,
+ )
if ctx.world_size > 1:
- torch.distributed.barrier()
- add_event()
+ torch.distributed.barrier(group=global_group)
+ if i == (num_warmup_steps + num_repetitions - 1):
+ add_event(ctx, events)
p.step()
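+            # Start the next iteration's trace events where the slowest rank's previous
+            # events ended, so per-iteration traces do not overlap in the merged timeline.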
+ if record_op_runtimes:
+ op_runtimes_ts = max(
+ ctx.trace[rank][-1]["ts"] + ctx.trace[rank][-1]["dur"]
+ for rank in ctx.trace.keys()
+ )
if ctx.use_gpu:
# Move outputs back to cpu
@@ -553,6 +635,11 @@ def run_multiprocesses(
if ctx.debug_stacktrace:
sys.exit(1)
+ if ctx.profile:
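+        # Merge the per-rank op traces into a single Chrome-trace-format JSON file
+        # in the profiler output directory.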
+ trace = list(itertools.chain.from_iterable(list(ctx.trace.values())))
+ with open(f"{per_rank_functions[0].name}_profile/trace.json", "w") as f:
+ json.dump(trace, f, indent=0)
+
per_rank_outputs, runtimes = zip(*outputs)
return per_rank_outputs, runtimes
@@ -592,23 +679,42 @@ def run_pytorch(
device_to_fns, groups = project(fn, input_types)
+ if len(device_to_fns) > torch.cuda.device_count():
+ raise ValueError(
+ f"Received {len(device_to_fns)} projected functions, "
+ f"but only {torch.cuda.device_count()} GPUs available"
+ )
+
# Map between DistIR devices and pytorch ranks:
device_to_rank = {}
world_size = 0
per_rank_fns = []
for d in device_to_fns:
- device_to_rank[d] = world_size
+ rank = world_size
+ device_to_rank[d] = rank
per_rank_fns.append(device_to_fns[d])
world_size += 1
+ global_group = tuple(sorted(device_to_fns.keys()))
+
+ if profile:
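+        # Multiprocessing-managed dict (one list per rank) so worker processes can
+        # report their op traces back to the parent process.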
+ manager = torch.multiprocessing.Manager()
+ trace = manager.dict()
+ for d in sorted(device_to_rank.keys()):
+ trace[device_to_rank[d]] = []
+ else:
+ trace = None
+
ctx = DistributedContext(
world_size=world_size,
use_gpu=use_gpu,
groups={},
groups_list=list(groups),
+ global_group=global_group,
device_to_rank=device_to_rank,
debug_stacktrace=debug_stacktrace,
profile=profile,
+ trace=trace,
)
per_rank_inputs = [[] for _ in range(world_size)]
diff --git a/dist_ir/executor/__init__.py b/dist_ir/executor/__init__.py
index 83556bc3..056a93af 100644
--- a/dist_ir/executor/__init__.py
+++ b/dist_ir/executor/__init__.py
@@ -1,4 +1,10 @@
from .absint import AbstractInterpreter, AbstractState
+from .calibrate_simulator import (
+ calibrate_device_parameters,
+ calibrate_network_bandwidth,
+ calibrate_allreduce_parameters,
+ network_bandwidth_debug, # TODO: Remove
+)
from .concrete_value import ConcreteValue
from .cost_model import CostModel
from .simulator import Simulator
diff --git a/dist_ir/executor/cost_model.py b/dist_ir/executor/cost_model.py
index 025b6f2c..b731442d 100644
--- a/dist_ir/executor/cost_model.py
+++ b/dist_ir/executor/cost_model.py
@@ -5,7 +5,6 @@
from ..ir.type import Float32, Float64, Int32, Int64, Tensor, TupleType
BYTES_IN_Gb = 1.25e8
-KERNEL_LAUNCH_OVERHEAD = 10e-6
class CostModel:
@@ -178,20 +177,22 @@ def _elementwise_cost_fn(self, op, x, y=None):
flops = n
communication_cost = data_size / x.device.dram_bandwidth
computation_cost = flops / x.device.throughput
- latency = KERNEL_LAUNCH_OVERHEAD + communication_cost + computation_cost
+ latency = (
+ x.device.kernel_launch_overhead + communication_cost + computation_cost
+ )
return {x.device: latency}
def _concat_cost_fn(self, op, *xs):
# TODO: Compute cost properly
devices = [x.device for x in xs]
- return {device: KERNEL_LAUNCH_OVERHEAD for device in devices}
+        return {device: device.kernel_launch_overhead for device in devices}
def _constant_of_shape_cost_fn(self, op, x):
- return {x.device: KERNEL_LAUNCH_OVERHEAD}
+ return {x.device: x.device.kernel_launch_overhead}
def _gather_cost_fn(self, op, x, y):
# TODO: Compute cost properly
- return {x.device: KERNEL_LAUNCH_OVERHEAD}
+ return {x.device: x.device.kernel_launch_overhead}
def _gemm_cost_fn(self, op, x, y, z):
gemm_costs = self._matmul_cost_fn(op, x, y)
@@ -210,10 +211,12 @@ def _join_cost_fn(self, op, *xs):
def _matmul_cost_fn(self, op, x, y):
data_size = x.dtype.size() * (x.shape[0] * x.shape[1] + y.shape[0] * y.shape[1])
- flops = 2 * x.shape[0] * x.shape[1] * y.shape[1]
+ flops = (2 * x.shape[1] - 1) * x.shape[0] * y.shape[1]
communication_cost = data_size / x.device.dram_bandwidth
computation_cost = flops / x.device.throughput
- latency = KERNEL_LAUNCH_OVERHEAD + communication_cost + computation_cost
+ latency = (
+ x.device.kernel_launch_overhead + communication_cost + computation_cost
+ )
return {x.device: latency}
def _matmul_grad_cost_fn(self, op, x, y, dz):
@@ -241,7 +244,7 @@ def _mpi_allgather_cost_fn(self, op, *xs):
)
average_bandwidth = np.mean(all_bandwidths)
average_input_size = np.mean([x.size() for x in xs])
- per_device_data = 2 * average_input_size * (len(devices) - 1) / len(devices)
+ per_device_data = 2 * average_input_size * (len(devices) - 1)
per_device_data_gb = per_device_data / BYTES_IN_Gb
cost = per_device_data_gb / average_bandwidth
return {device: cost for device in devices}
@@ -306,14 +309,14 @@ def _mpi_scatter_cost_fn(self, op, x):
return {d: cost for d in op.attributes["devices"]}
def _nonzero_cost_fn(self, op, x):
- return {x.device: KERNEL_LAUNCH_OVERHEAD}
+ return {x.device: x.device.kernel_launch_overhead}
def _reduce_mean_cost_fn(self, op, x):
# TODO: Repace with more accurate function?
return self._elementwise_cost_fn(op, x)
def _reshape_cost_fn(self, op, x, y):
- return {x.device: KERNEL_LAUNCH_OVERHEAD}
+ return {x.device: x.device.kernel_launch_overhead}
def _select_cost_fn(self, op, xs):
costs = {}
@@ -332,8 +335,8 @@ def _send_cost_fn(self, op, x):
transfer_time = input_size_gb / bandwidth
# NOTE: This assumes all tensors can be sent concurrently
# TODO: Do we need to model the link capacity?
- costs[input_device] = transfer_time
- costs[output_device] = transfer_time
+ costs[input_device] = transfer_time + input_device.kernel_launch_overhead
+ costs[output_device] = transfer_time + output_device.kernel_launch_overhead
return costs
@@ -346,24 +349,24 @@ def _sgd_cost_fn(self, op, *xs):
return costs
def _shape_cost_fn(self, op, x):
- return {x.device: KERNEL_LAUNCH_OVERHEAD}
+ return {x.device: x.device.kernel_launch_overhead}
def _slice_cost_fn(self, op, x, starts, ends, axes, steps=None):
- return {x.device: KERNEL_LAUNCH_OVERHEAD} # TODO is this accurate?
+ return {x.device: x.device.kernel_launch_overhead} # TODO is this accurate?
def _softmax_cost_fn(self, op, x):
# TODO: Repace with more accurate function?
return self._elementwise_cost_fn(op, x)
def _split_cost_fn(self, op, x):
- return {x.device: KERNEL_LAUNCH_OVERHEAD}
+ return {x.device: x.device.kernel_launch_overhead}
def _squeeze_cost_fn(self, op, x):
- return {x.device: KERNEL_LAUNCH_OVERHEAD}
+ return {x.device: x.device.kernel_launch_overhead}
def _transpose_cost_fn(self, op, x):
# TODO: Repace with more accurate function?
return self._elementwise_cost_fn(op, x)
def _unsqueeze_cost_fn(self, op, x):
- return {x.device: KERNEL_LAUNCH_OVERHEAD}
+ return {x.device: x.device.kernel_launch_overhead}
diff --git a/dist_ir/executor/simulator.py b/dist_ir/executor/simulator.py
index f80a17e8..0f09edd7 100644
--- a/dist_ir/executor/simulator.py
+++ b/dist_ir/executor/simulator.py
@@ -13,7 +13,7 @@
dispatch,
)
from .concrete_value import ConcreteValue
-from .cost_model import CostModel, KERNEL_LAUNCH_OVERHEAD
+from .cost_model import CostModel
SECONDS_TO_MICROSECONDS = 1e6
@@ -201,7 +201,7 @@ def simulate(self, function: Function, inputs: Tuple[Any]) -> SimulatorState:
if signature is None:
# Use default cost function if signature not in cost_functions
devices = _get_all_devices(inputs + outputs)
- costs = {device: KERNEL_LAUNCH_OVERHEAD for device in devices}
+ costs = {device: device.kernel_launch_overhead for device in devices}
else:
# Abstract inputs if necessary
abstracted_inputs = abstract_values(inputs, signature)
diff --git a/dist_ir/executor/type_register.py b/dist_ir/executor/type_register.py
index 2e527205..1520b0ea 100644
--- a/dist_ir/executor/type_register.py
+++ b/dist_ir/executor/type_register.py
@@ -538,7 +538,7 @@ def _split_uniform_prop_fn(op, x):
output_shape = list(x.shape)
# TODO: Move this check to attribute error function?
assert output_shape[split_dim] % num_splits == 0
- output_shape[split_dim] //= num_splits
+ output_shape[split_dim] //= int(num_splits)
output_shape = tuple(output_shape)
output_types = tuple(
Tensor(dtype=x.dtype, shape=output_shape, device=x.device)
diff --git a/dist_ir/ir/topology.py b/dist_ir/ir/topology.py
index 10783740..c29232a8 100644
--- a/dist_ir/ir/topology.py
+++ b/dist_ir/ir/topology.py
@@ -22,7 +22,11 @@ def add_device(
device_id = self._device_id_counter
self._device_id_counter += 1
device = Device(
- device_id, device_type, throughput, dram_bandwidth, kernel_launch_overhead
+ device_id,
+ device_type,
+ throughput=throughput,
+ dram_bandwidth=dram_bandwidth,
+ kernel_launch_overhead=kernel_launch_overhead,
)
self._devices.append(device)
self._bandwidths[device] = {}
@@ -52,7 +56,7 @@ def get_uniform_topology(
network_bandwidth=constants.DEFAULT_NETWORK_BANDWIDTH,
):
topology = Topology()
- d0 = topology.add_device("gpu")
+ d0 = topology.add_device("cpu")
for i in range(1, world_size + 1):
topology.add_device(
"gpu",
diff --git a/dist_ir/transforms/mlp_dhp_transform.py b/dist_ir/transforms/mlp_dhp_transform.py
index 93597770..bb177d78 100644
--- a/dist_ir/transforms/mlp_dhp_transform.py
+++ b/dist_ir/transforms/mlp_dhp_transform.py
@@ -36,12 +36,15 @@ def _identity(v, function, output_name):
def _split_value(v, function, num_splits, parallelism_level, dim=0):
output_names = [f"{v.name}_{parallelism_level}_{i}" for i in range(num_splits)]
- return function.add_op(
+ split_values = function.add_op(
"SplitUniform",
inputs=[v],
attributes={"axis": dim, "num_splits": num_splits},
output_names=output_names,
)
+ if not isinstance(split_values, tuple):
+ split_values = (split_values,)
+ return split_values
def _mpi_allgather_values(vs, function, dim, output_names):
@@ -214,6 +217,7 @@ def _partition_inputs_pp(
dim=0,
)
elif k == 1:
+            # Labels will be used on the downstream device
consumer_devices = _get_consumer_devices_for_pp_value(
orig_inp,
function,
diff --git a/examples/calibrate_simulator.py b/examples/calibrate_simulator.py
new file mode 100644
index 00000000..e936986a
--- /dev/null
+++ b/examples/calibrate_simulator.py
@@ -0,0 +1,28 @@
+from dist_ir.executor import (
+ calibrate_device_parameters,
+ calibrate_network_bandwidth,
+ network_bandwidth_debug,
+)
+
+
+def main():
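+    # The triple-quoted block below keeps the full device-parameter calibration flow
+    # for reference; only the network bandwidth calibration is currently exercised.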
+ """
+ (
+ dram_bandwidth,
+ device_throughput,
+ kernel_launch_overhead,
+ ) = calibrate_device_parameters()
+ print(f"Device throughput: {device_throughput:e}")
+ print(f"DRAM bandwidth: {dram_bandwidth:.2e}")
+ print(f"Kernel launch overhead: {kernel_launch_overhead}")
+ network_bandwidth = calibrate_network_bandwidth()
+ print(f"Network bandwidth: {network_bandwidth}")
+ """
+ bandwidths = calibrate_network_bandwidth()
+ for k, v in bandwidths.items():
+ print(f"{k}: {v}")
+ # network_bandwidth_debug()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/grid_search.py b/examples/grid_search.py
index 47b0dd75..92e5e9ea 100644
--- a/examples/grid_search.py
+++ b/examples/grid_search.py
@@ -231,6 +231,7 @@ def run(self, config: DHPConfig):
# TODO: Measure peak memory?
peak_memory = 0
except Exception as e:
+            # TODO: Move this after catching RuntimeError, or remove the RuntimeError catch block?
print(f"Failed to run the configuration {config}:")
traceback.print_exc()
latency = -1
diff --git a/examples/mlp_benchmark.py b/examples/mlp_benchmark.py
new file mode 100644
index 00000000..3ccd1863
--- /dev/null
+++ b/examples/mlp_benchmark.py
@@ -0,0 +1,605 @@
+import argparse
+import csv
+import itertools
+import numpy as np
+import os
+import pickle
+import time
+import tqdm
+import traceback
+import torch
+
+from dist_ir.ir import cpprint
+from dist_ir.backend.torch import run_pytorch
+from dist_ir.executor import (
+ CostModel,
+ Simulator,
+ SequentialExecutor,
+ calibrate_device_parameters,
+ calibrate_network_bandwidth,
+ infer_types,
+)
+from dist_ir.transforms import mlp_dhp_transform
+from examples import mlp, mlp_grid_search
+
+torch.manual_seed(42)
+
+
+def get_inputs(batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers):
+ x = torch.randn(size=(batch_size, input_dim), dtype=torch.float32)
+ z = torch.randn(size=(batch_size, output_dim), dtype=torch.float32)
+ weights = [torch.randn(size=(input_dim, hidden_dim), dtype=torch.float32)]
+ for i in range(1, num_hidden_layers - 1):
+ weights.append(torch.randn(size=(hidden_dim, hidden_dim), dtype=torch.float32))
+ weights.append(torch.randn(size=(hidden_dim, output_dim), dtype=torch.float32))
+ return x, z, weights
+
+
+def mlp_dist_ir_simulation(
+ batch_size,
+ input_dim,
+ hidden_dim,
+ output_dim,
+ num_hidden_layers,
+ x,
+ z,
+ weights,
+ device_throughput,
+ dram_bandwidth,
+ kernel_launch_overhead,
+ network_bandwidth,
+ d,
+ t,
+ p,
+ k,
+ max_memory_GB=10,
+ warmup_steps=5,
+ active_steps=50,
+ verbose=False,
+):
+ world_size = d * t * p
+ topology = mlp.get_topology(
+ world_size,
+ device_throughput=device_throughput,
+ dram_bandwidth=dram_bandwidth,
+ kernel_launch_overhead=kernel_launch_overhead,
+ network_bandwidth=network_bandwidth,
+ )
+ fn = mlp.mlp(
+ batch_size,
+ input_dim,
+ hidden_dim,
+ output_dim,
+ num_hidden_layers,
+ device=topology.devices[0],
+ )
+ if world_size > 1:
+ init_fn, fn = mlp_dhp_transform(fn, d, t, p, k, topology.devices)
+ init_fn = infer_types(init_fn, init_fn.inputs)
+ input_types = tuple(output.type for output in init_fn.outputs)
+ else:
+ input_types = tuple(inp.type for inp in fn.inputs)
+ if verbose:
+ init_fn = infer_types(init_fn, init_fn.inputs)
+ fn = infer_types(fn, init_fn.outputs)
+ cpprint(fn)
+
+ simulator = Simulator(CostModel(topology))
+ simulation = simulator.interpret(fn, input_types)
+ simulated_time = max([simulation.timestamps[d] for d in simulation.timestamps])
+ peak_memory = max([simulation.peak_memory[d] for d in simulation.peak_memory])
+ return simulated_time, peak_memory
+
+
+def mlp_dist_ir_pytorch_backend(
+ batch_size,
+ input_dim,
+ hidden_dim,
+ output_dim,
+ num_hidden_layers,
+ x,
+ z,
+ weights,
+ d,
+ t,
+ p,
+ k,
+ warmup_steps=5,
+ active_steps=50,
+ profile=False,
+ verbose=False,
+):
+ world_size = d * t * p
+ topology = mlp.get_topology(world_size)
+ fn = mlp.mlp(
+ batch_size,
+ input_dim,
+ hidden_dim,
+ output_dim,
+ num_hidden_layers,
+ device=topology.devices[0],
+ )
+ input_data = [x, z] + weights
+ if world_size > 1:
+ init_fn, fn = mlp_dhp_transform(fn, d, t, p, k, topology.devices)
+ init_fn = infer_types(init_fn, init_fn.inputs)
+ fn = infer_types(fn, init_fn.outputs)
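+        # Run init_fn on CPU via the numpy executor to shard/replicate the inputs into
+        # the per-device values the transformed function expects, then convert back to torch.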
+ ex = SequentialExecutor("numpy")
+ input_data = [
+ torch.from_numpy(v).to(torch.float32)
+ for v in ex.compute(init_fn, [v.numpy() for v in input_data])
+ ]
+ if verbose:
+ fn = infer_types(fn, fn.inputs)
+ cpprint(fn)
+
+ # Measure actual execution time
+ per_rank_outputs, runtimes = run_pytorch(
+ fn,
+ input_data,
+ use_gpu=True,
+ num_repetitions=active_steps,
+ num_warmup=warmup_steps,
+ profile=profile,
+ )
+    # TODO: Take the max of per-rank medians, or the median of per-rank maxes?
+ actual_time = max(np.median(times) for times in runtimes)
+
+ if world_size == 1:
+ gradients = [
+ per_rank_outputs[0][i] for i, v in enumerate(fn.outputs) if "dw" in v.name
+ ]
+ else:
+ gradients = None
+
+ return gradients, actual_time
+
+
+def mlp_pure_pytorch(x, z, weights, warmup_steps=5, active_steps=50, profile=False):
+ batch_size = x.shape[0]
+ x = x.cuda()
+ z = z.cuda()
+ weights = [w.cuda() for w in weights]
+ events = []
+
+ if active_steps < 10:
+ print(
+            "WARNING: The first active step includes large overhead; "
+            "record more steps for a more accurate measurement."
+ )
+
+ def add_event():
+ events.append(torch.cuda.Event(enable_timing=True))
+ events[-1].record()
+
+ if profile:
+ wait_steps = 0
+ else:
+ wait_steps = warmup_steps + active_steps
+
+ with torch.profiler.profile(
+ activities=[
+ torch.profiler.ProfilerActivity.CPU,
+ torch.profiler.ProfilerActivity.CUDA,
+ ],
+ schedule=torch.profiler.schedule(
+ wait=wait_steps, warmup=warmup_steps, active=active_steps
+ ),
+ on_trace_ready=torch.profiler.tensorboard_trace_handler("mlp_pytorch_profile"),
+ ) as p:
+ for i in range(warmup_steps + active_steps):
+ add_event()
+ x_ = x.clone()
+ z_ = z.clone()
+ activations = [x_]
+ matmul_outputs = []
+ torch.cuda.empty_cache()
+ add_event()
+ for w_ in weights:
+ x_ = torch.matmul(x_, w_)
+ matmul_outputs.append(x_)
+ x_ = torch.relu(x_)
+ activations.append(x_)
+
+ loss = torch.square(x_ - z_) / batch_size
+ dy_ = 2 * (x_ - z_) / batch_size
+
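+            # Manual backward pass: zero dy where the pre-activation was <= 0 (ReLU
+            # gradient), then backpropagate through each matmul, collecting weight
+            # gradients layer by layer.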
+ gradients = []
+ for j, w_ in enumerate(reversed(weights)):
+ x_ = matmul_outputs[len(matmul_outputs) - 1 - j]
+ dy_[x_ <= 0] = 0
+ a_ = activations[len(activations) - 2 - j]
+ da_, dw_ = torch.matmul(dy_, w_.T), torch.matmul(a_.T, dy_)
+ dy_ = da_
+ gradients.append(dw_)
+ if i == (warmup_steps + active_steps - 1):
+ add_event()
+ p.step()
+ torch.cuda.synchronize()
+ runtimes = [
+ events[i].elapsed_time(events[i + 1]) / 1e3
+ for i in range(1, len(events) - 1, 2)
+ ]
+
+ return gradients, np.median(runtimes[warmup_steps:])
+
+
+def benchmark(
+ batch_size,
+ input_dim,
+ hidden_dim,
+ output_dim,
+ num_hidden_layers,
+ device_throughput,
+ dram_bandwidth,
+ kernel_launch_overhead,
+ network_bandwidth,
+ d=1,
+ t=1,
+ p=1,
+ k=1,
+ max_memory_GB=10,
+):
+ world_size = d * t * p
+ x, z, weights = get_inputs(
+ batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers
+ )
+ simulated_time, peak_memory = mlp_dist_ir_simulation(
+ batch_size,
+ input_dim,
+ hidden_dim,
+ output_dim,
+ num_hidden_layers,
+ x,
+ z,
+ weights,
+ device_throughput,
+ dram_bandwidth,
+ kernel_launch_overhead,
+ network_bandwidth,
+ d,
+ t,
+ p,
+ k,
+ )
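+    # NOTE: Return the simulated time only; the memory check and backend/PyTorch
+    # measurements below are currently skipped.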
+ return simulated_time, -1
+
+ if peak_memory / (1024 ** 3) > max_memory_GB:
+ if world_size == 1:
+ return -1, -1, -1
+ else:
+ return -1, -1
+
+ dist_ir_gradients, pytorch_backend_time = mlp_dist_ir_pytorch_backend(
+ batch_size,
+ input_dim,
+ hidden_dim,
+ output_dim,
+ num_hidden_layers,
+ x,
+ z,
+ weights,
+ d,
+ t,
+ p,
+ k,
+ )
+ torch.cuda.empty_cache()
+
+ if world_size == 1:
+ pytorch_gradients, pure_pytorch_time = mlp_pure_pytorch(x, z, weights)
+
+ for x, y in zip(pytorch_gradients, dist_ir_gradients):
+ np.testing.assert_array_almost_equal(
+ x.detach().cpu().numpy(), y.detach().cpu().numpy(), decimal=2
+ )
+
+ return simulated_time, pytorch_backend_time, pure_pytorch_time
+ else:
+ return simulated_time, pytorch_backend_time
+
+
+def distributed_grid_search(
+ device_throughput, dram_bandwidth, kernel_launch_overhead, network_bandwidth
+):
+ batch_size = 8192
+ all_dims = [1024, 2048, 4096]
+ all_num_layers = [8, 16]
+ world_size = 8 # torch.cuda.device_count()
+ all_degrees = mlp_grid_search.get_all_degrees(world_size)
+ configs = []
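+    # Enumerate (d, t, p) parallelism degrees; the microbatch count k is only varied
+    # when pipeline parallelism is used (p > 1).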
+ for (dim, num_layers) in itertools.product(all_dims, all_num_layers):
+ for (d, t, p) in all_degrees:
+ if p == 1:
+ k = 1
+ configs.append((d, t, p, k, dim, num_layers))
+ else:
+ for i in range(1, 5):
+ k = int(2 ** i)
+ configs.append((d, t, p, k, dim, num_layers))
+
+ fieldnames = [
+ "Dim",
+ "Layers",
+ "Data parallel degree",
+ "Tensor model parallel degree",
+ "Pipeline parallel degree",
+ "Microbatches",
+ "Simulated time",
+ "PyTorch backend time",
+ ]
+
+ with open("mlp_benchmark_dgx_simulation.csv", "w") as f:
+ writer = csv.writer(f)
+ writer.writerow(fieldnames)
+ # for (d, t, p, k, dim, layers) in configs:
+ for (d, t, p, k, dim, layers) in tqdm.tqdm(configs):
+ try:
+ assert d > 1 or t > 1 or p > 1
+ simulated_time, pytorch_backend_time = benchmark(
+ batch_size,
+ dim,
+ dim,
+ dim,
+ layers,
+ device_throughput,
+ dram_bandwidth,
+ kernel_launch_overhead,
+ network_bandwidth,
+ d,
+ t,
+ p,
+ k,
+ )
+ except Exception as e:
+ traceback.print_exc()
+ simulated_time = -1
+ pytorch_backend_time = -1
+ pure_pytorch_time = -1
+ writer.writerow(
+ [
+ dim,
+ layers,
+ d,
+ t,
+ p,
+ k,
+ simulated_time,
+ pytorch_backend_time,
+ ]
+ )
+ f.flush()
+ torch.cuda.empty_cache()
+
+
+def grid_search(device_throughput, dram_bandwidth, kernel_launch_overhead):
+ all_batch_sizes = [1024, 2048, 4096]
+ all_dims = [1024, 2048, 4096]
+ all_num_hidden_layers = [8, 12, 16]
+ fieldnames = [
+ "Batch size",
+ "Dim",
+ "Layers",
+ "Simulated time",
+ "PyTorch backend time",
+ "Pure PyTorch time",
+ ]
+
+ with open("mlp_benchmark.csv", "w") as f:
+ writer = csv.writer(f)
+ writer.writerow(fieldnames)
+ for (batch_size, dim, num_hidden_layers) in tqdm.tqdm(
+ list(itertools.product(all_batch_sizes, all_dims, all_num_hidden_layers))
+ ):
+ try:
+ simulated_time, pytorch_backend_time, pure_pytorch_time = benchmark(
+ batch_size,
+ dim,
+ dim,
+ dim,
+ num_hidden_layers,
+ device_throughput,
+ dram_bandwidth,
+ kernel_launch_overhead,
+ )
+ except Exception as e:
+ traceback.print_exc()
+ simulated_time = -1
+ pytorch_backend_time = -1
+ pure_pytorch_time = -1
+ writer.writerow(
+ [
+ batch_size,
+ dim,
+ num_hidden_layers,
+ simulated_time,
+ pytorch_backend_time,
+ pure_pytorch_time,
+ ]
+ )
+ f.flush()
+ torch.cuda.empty_cache()
+
+
+def main(args):
+ if args.simulation_parameters_file is not None and os.path.exists(
+ args.simulation_parameters_file
+ ):
+ with open(args.simulation_parameters_file, "rb") as f:
+ simulation_parameters = pickle.load(f)
+ print(
+ f"Reading simulation parameters from {args.simulation_parameters_file}..."
+ )
+ args.device_throughput = simulation_parameters["device_throughput"]
+ args.dram_bandwidth = simulation_parameters["dram_bandwidth"]
+ args.kernel_launch_overhead = simulation_parameters["kernel_launch_overhead"]
+ args.network_bandwidth = simulation_parameters["network_bandwidth"]
+ else:
+ simulation_parameters = {}
+ update_simulation_parameters = False
+ if args.calibrate_device_parameters and (
+ args.mode == "simulate"
+ or args.mode == "grid_search"
+ or args.mode == "distributed_grid_search"
+ ):
+ print("Calibrating device parameters...")
+ (
+ args.dram_bandwidth,
+ args.device_throughput,
+ args.kernel_launch_overhead,
+ ) = calibrate_device_parameters()
+ simulation_parameters["dram_bandwidth"] = args.dram_bandwidth
+ simulation_parameters["device_throughput"] = args.device_throughput
+ simulation_parameters["kernel_launch_overhead"] = args.kernel_launch_overhead
+ update_simulation_parameters = True
+ print(f"DRAM bandwidth: {args.dram_bandwidth:.2e}")
+ print(f"Device throughput: {args.device_throughput:.2e}")
+ print(f"Kernel launch overhead: {args.kernel_launch_overhead:.2e}")
+ if args.calibrate_network_bandwidth and (
+ args.mode == "simulate"
+ or args.mode == "grid_search"
+ or args.mode == "distributed_grid_search"
+ ):
+ args.network_bandwidth = calibrate_network_bandwidth()
+ simulation_parameters["network_bandwidth"] = args.network_bandwidth
+ print(f"Network bandwidth: {args.network_bandwidth}")
+ update_simulation_parameters = True
+ if update_simulation_parameters and args.simulation_parameters_file is not None:
+ with open(args.simulation_parameters_file, "wb") as f:
+ pickle.dump(simulation_parameters, f)
+ if args.mode == "grid_search":
+ grid_search(
+ args.device_throughput,
+ args.dram_bandwidth,
+ args.kernel_launch_overhead,
+ )
+ elif args.mode == "distributed_grid_search":
+ distributed_grid_search(
+ args.device_throughput,
+ args.dram_bandwidth,
+ args.kernel_launch_overhead,
+ args.network_bandwidth,
+ )
+ elif args.mode == "simulate":
+ x, z, weights = get_inputs(
+ args.batch_size, args.dim, args.dim, args.dim, args.layers
+ )
+ simulated_time, peak_memory = mlp_dist_ir_simulation(
+ args.batch_size,
+            args.dim,
+ args.dim,
+ args.dim,
+ args.layers,
+ x,
+ z,
+ weights,
+ args.device_throughput,
+ args.dram_bandwidth,
+ args.kernel_launch_overhead,
+ args.network_bandwidth,
+ args.d,
+ args.t,
+ args.p,
+ args.k,
+ verbose=args.verbose,
+ )
+ print(f"Simulated latency: {simulated_time * 1000:.2f} ms")
+ print(f"Simulated peak memory: {peak_memory / (1024 ** 3):.2f} GB")
+ elif args.mode == "backend":
+ x, z, weights = get_inputs(
+ args.batch_size, args.dim, args.dim, args.dim, args.layers
+ )
+ _, pytorch_backend_time = mlp_dist_ir_pytorch_backend(
+ args.batch_size,
+ args.dim,
+ args.dim,
+ args.dim,
+ args.layers,
+ x,
+ z,
+ weights,
+ args.d,
+ args.t,
+ args.p,
+ args.k,
+ warmup_steps=args.warmup_steps,
+ active_steps=args.active_steps,
+ profile=args.profile,
+ verbose=args.verbose,
+ )
+ print(f"PyTorch backend latency: {pytorch_backend_time * 1000:.2f} ms")
+ elif args.mode == "pytorch":
+ x, z, weights = get_inputs(
+ args.batch_size, args.dim, args.dim, args.dim, args.layers
+ )
+ _, pure_pytorch_time = mlp_pure_pytorch(
+ x,
+ z,
+ weights,
+ warmup_steps=args.warmup_steps,
+ active_steps=args.active_steps,
+ profile=args.profile,
+ )
+
+ print(f"Pure PyTorch latency: {pure_pytorch_time * 1000:.2f} ms")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="MLP benchmark")
+ parser.add_argument(
+ "--mode",
+ choices=[
+ "grid_search",
+ "distributed_grid_search",
+ "pytorch",
+ "simulate",
+ "backend",
+ ],
+ required=True,
+ )
+ parser.add_argument("--batch_size", type=int, default=128, help="Batch size")
+ parser.add_argument("--dim", type=int, default=256, help="Weight dim")
+ parser.add_argument("--layers", type=int, default=16, help="# layers")
+ parser.add_argument("--warmup_steps", type=int, default=5, help="# warmup steps")
+ parser.add_argument("--active_steps", type=int, default=100, help="# active steps")
+ parser.add_argument(
+ "--calibrate_device_parameters",
+ action="store_true",
+ default=False,
+ help="Calibrate device parameters",
+ )
+ parser.add_argument(
+ "--calibrate_network_bandwidth",
+ action="store_true",
+ default=False,
+ help="Calibrate network bandwidth",
+ )
+ parser.add_argument(
+ "--simulation_parameters_file",
+ type=str,
+ default=None,
+ help="File to load/save simulation parameters from/to",
+ )
+ parser.add_argument("--profile", action="store_true", default=False, help="Profile")
+ parser.add_argument(
+ "--device_throughput", type=float, default=1.4e13, help="Device throughput"
+ )
+ parser.add_argument(
+ "--dram_bandwidth", type=float, default=9e11, help="DRAM Bandwidth"
+ )
+ parser.add_argument(
+ "--network_bandwidth", type=float, default=64, help="Network bandwidth in Gbps"
+ )
+ parser.add_argument(
+ "--kernel_launch_overhead",
+ type=float,
+ default=1e-5,
+ help="Kernel launch overhead",
+ )
+ parser.add_argument("-d", type=int, default=1, help="Data parallel degree")
+ parser.add_argument("-t", type=int, default=1, help="Tensor model parallel degree")
+ parser.add_argument("-p", type=int, default=1, help="Pipeline parallel degree")
+ parser.add_argument("-k", type=int, default=1, help="# microbatches")
+ parser.add_argument("--verbose", action="store_true", help="Verbose")
+ args = parser.parse_args()
+ main(args)
diff --git a/examples/send_benchmark.py b/examples/send_benchmark.py
new file mode 100644
index 00000000..14f08a0e
--- /dev/null
+++ b/examples/send_benchmark.py
@@ -0,0 +1,63 @@
+import argparse
+import os
+import time
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+
+NUM_WARMUP_TRIALS = 25
+NUM_TRIALS = 10
+
+
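+# Both processes call dist.broadcast on a two-rank group to emulate a point-to-point
+# transfer, matching how the DistIR PyTorch backend implements Send/Recv.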
+def send(rank, src_rank, world_size, group_ranks):
+ os.environ["MASTER_ADDR"] = "127.0.0.1" # TODO make these configurable
+ os.environ["MASTER_PORT"] = "29500"
+ dist.init_process_group("nccl", rank=rank, world_size=world_size)
+ group = dist.new_group(group_ranks)
+ runtimes = []
+ x = torch.randn(size=(8192, 8192), dtype=torch.float32).to(f"cuda:{rank}")
+ for i in range(NUM_WARMUP_TRIALS + NUM_TRIALS):
+ # torch.distributed.barrier(group=group)
+ start = time.time()
+ dist.broadcast(x, src_rank, group=group)
+ torch.cuda.synchronize(device=rank)
+ end = time.time()
+ runtimes.append(end - start)
+ dist.destroy_process_group()
+ print(f"Send latencies: {runtimes[NUM_WARMUP_TRIALS:]}")
+
+
+def recv(rank, src_rank, world_size, group_ranks):
+ os.environ["MASTER_ADDR"] = "127.0.0.1" # TODO make these configurable
+ os.environ["MASTER_PORT"] = "29500"
+ dist.init_process_group("nccl", rank=rank, world_size=world_size)
+ group = dist.new_group(group_ranks)
+ runtimes = []
+ x = torch.zeros(size=(8192, 8192), dtype=torch.float32).to(f"cuda:{rank}")
+ for i in range(NUM_WARMUP_TRIALS + NUM_TRIALS):
+ # torch.distributed.barrier(group=group)
+ start = time.time()
+ dist.broadcast(x, src_rank, group=group)
+ torch.cuda.synchronize(device=rank)
+ end = time.time()
+ runtimes.append(end - start)
+ dist.destroy_process_group()
+ print(f"Recv latencies: {runtimes[NUM_WARMUP_TRIALS:]}")
+
+
+def main(args):
+ p_src = mp.Process(target=send, args=(args.src_rank, args.src_rank, 2, [0, 1]))
+ p_dst = mp.Process(target=recv, args=(1 - args.src_rank, args.src_rank, 2, [0, 1]))
+
+ p_src.start()
+ p_dst.start()
+ p_src.join()
+ p_dst.join()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--src_rank", choices=[0, 1], type=int, required=True)
+ args = parser.parse_args()
+ main(args)
diff --git a/notebooks/mlp_training_grid_search_simulator_accuracy.ipynb b/notebooks/mlp_training_grid_search_simulator_accuracy.ipynb
new file mode 100644
index 00000000..109979f3
--- /dev/null
+++ b/notebooks/mlp_training_grid_search_simulator_accuracy.ipynb
@@ -0,0 +1,4146 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "plain-variance",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from collections import defaultdict\n",
+ "import csv\n",
+ "from itertools import product\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import plotly.express as px\n",
+ "from plotly.validators.scatter.marker import SymbolValidator"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "preceding-firewall",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "SIMULATION_FILENAME = \"~/Downloads/mlp_grid_search_results_simulation_v100.csv\"\n",
+ "PYTORCH_FILENAME = \"~/Downloads/mlp_grid_search_results_pytorch_v100.csv\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "derived-socket",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plt.rcParams[\"font.size\"] = 12\n",
+ "plt.rcParams[\"figure.max_open_warning\"] = False\n",
+ "pd.options.mode.chained_assignment = None"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "palestinian-franklin",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "raw_symbols = SymbolValidator().values\n",
+ "symbols = []\n",
+ "for i in range(0, len(raw_symbols), 3):\n",
+ " symbols.append(raw_symbols[i])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "adjusted-natural",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_parallelism_style(dp, hp, pp):\n",
+ " if dp == 1 and hp == 1 and pp == 1:\n",
+ " return \"Sequential\"\n",
+ " elif dp > 1 and hp == 1 and pp == 1:\n",
+ " return \"D\"\n",
+ " elif dp == 1 and hp > 1 and pp == 1:\n",
+ " return \"T\"\n",
+ " elif dp == 1 and hp == 1 and pp > 1:\n",
+ " return \"P\"\n",
+ " elif dp > 1 and hp > 1 and pp == 1:\n",
+ " return \"D/T\"\n",
+ " elif dp == 1 and hp > 1 and pp > 1:\n",
+ " return \"T/P\"\n",
+ " elif dp > 1 and hp == 1 and pp > 1:\n",
+ " return \"D/P\"\n",
+ " elif dp > 1 and hp > 1 and pp > 1:\n",
+ " return \"D/T/P\"\n",
+ " else:\n",
+ " raise ValueError(f\"Invalid degree combination dp={dp}, hp={hp}, pp={pp}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "professional-nation",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plot(x, y, xlabel, ylabel):\n",
+ " df = pd.read_csv(PYTORCH_FILENAME)\n",
+ " df_simulation = pd.read_csv(SIMULATION_FILENAME)\n",
+ " df = df.rename(\n",
+ " columns={\"latency\": \"pytorch_latency\", \"throughput\": \"pytorch_throughput\"}\n",
+ " )\n",
+ " df[\"simulated_latency\"] = df_simulation[\"latency\"]\n",
+ " df[\"simulated_throughput\"] = df_simulation[\"throughput\"]\n",
+ " df[\"dummy_column_for_size\"] = 1.0\n",
+ " parallelism_styles = [\n",
+ " get_parallelism_style(dp, hp, pp)\n",
+ " for (dp, hp, pp) in df[[\"dp_degree\", \"hp_degree\", \"pp_degree\"]].values\n",
+ " ]\n",
+ " df[\"parallelism_style\"] = parallelism_styles\n",
+ " fig = px.scatter(\n",
+ " df,\n",
+ " x=x,\n",
+ " y=y,\n",
+ " color=\"parallelism_style\",\n",
+ " labels={\n",
+ " x: xlabel,\n",
+ " y: ylabel,\n",
+ " \"parallelism_style\": \"Parallelism style\",\n",
+ " },\n",
+ " color_discrete_sequence=[\n",
+ " \"#1f77b4\", # muted blue\n",
+ " \"#ff7f0e\", # safety orange\n",
+ " \"#2ca02c\", # cooked asparagus green\n",
+ " \"#d62728\", # brick red\n",
+ " \"#9467bd\", # muted purple\n",
+ " \"#8c564b\", # chestnut brown\n",
+ " \"#e377c2\", # raspberry yogurt pink\n",
+ " \"#7f7f7f\", # middle gray\n",
+ " \"#bcbd22\", # curry yellow-green\n",
+ " \"#17becf\", # blue-teal\n",
+ " ],\n",
+ " category_orders={\n",
+ " \"parallelism_style\": [\n",
+ " \"D\",\n",
+ " \"T\",\n",
+ " \"P\",\n",
+ " \"D/T\",\n",
+ " \"T/P\",\n",
+ " \"D/P\",\n",
+ " \"D/T/P\",\n",
+ " \"Sequential\",\n",
+ " ],\n",
+ " },\n",
+ " hover_data=[\"dp_degree\", \"hp_degree\", \"pp_degree\", \"num_microbatches\"],\n",
+ " symbol=\"parallelism_style\",\n",
+ " size=\"dummy_column_for_size\",\n",
+ " size_max=10,\n",
+ " )\n",
+ " min_val = min(min(df[x]), min(df[y]))\n",
+ " max_val = max(max(df[x]), max(df[y]))\n",
+ " fig.add_shape(\n",
+ " type=\"line\",\n",
+ " x0=min_val,\n",
+ " y0=min_val,\n",
+ " x1=max_val,\n",
+ " y1=max_val,\n",
+ " line=dict(color=\"green\", width=3, dash=\"dot\"),\n",
+ " )\n",
+ " fig.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "interpreted-trailer",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ " \n",
+ " "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "customdata": [
+ [
+ 2,
+ 1,
+ 1,
+ 1
+ ],
+ [
+ 2,
+ 1,
+ 1,
+ 1
+ ],
+ [
+ 2,
+ 1,
+ 1,
+ 1
+ ],
+ [
+ 2,
+ 1,
+ 1,
+ 1
+ ],
+ [
+ 4,
+ 1,
+ 1,
+ 1
+ ],
+ [
+ 4,
+ 1,
+ 1,
+ 1
+ ],
+ [
+ 4,
+ 1,
+ 1,
+ 1
+ ],
+ [
+ 4,
+ 1,
+ 1,
+ 1
+ ]
+ ],
+       "hovertemplate": "Parallelism style=D<br>PyTorch latency (seconds)=%{x}<br>Simulated latency (seconds)=%{y}<br>dummy_column_for_size=%{marker.size}<br>dp_degree=%{customdata[0]}<br>hp_degree=%{customdata[1]}<br>pp_degree=%{customdata[2]}<br>num_microbatches=%{customdata[3]}",