Skip to content

Commit

Permalink
Add fp16 support
Browse files Browse the repository at this point in the history
  • Loading branch information
santhnm2 committed Sep 20, 2021
1 parent 3c5e828 commit a5d54aa
Show file tree
Hide file tree
Showing 10 changed files with 204 additions and 78 deletions.
4 changes: 3 additions & 1 deletion dist_ir/backend/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ..executor.rank_projector import project
from ..ir import Function, cpprint
from ..ir.device import Device
from ..ir.type import Int32, Int64, Float32, Type
from ..ir.type import Int32, Int64, Float16, Float32, Type

# NOTE: The code currently suffers from this issue, more investigation needed:
# https://github.com/pytorch/pytorch/issues/11201
Expand Down Expand Up @@ -166,6 +166,8 @@ def _recv(shape=None, from_d=None, group=None, dtype=None, ctx=None):
x = torch.zeros(shape).int()
elif isinstance(dtype, Int64):
x = torch.zeros(shape).long()
elif isinstance(dtype, Float16):
x = torch.zeros(shape).half()
elif isinstance(dtype, Float32):
x = torch.zeros(shape).float()
else:
Expand Down
2 changes: 2 additions & 0 deletions examples/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,8 @@ def _resolve_dtype(dtype):
return torch.int32
elif dtype == np.int64:
return torch.int64
elif dtype == np.float16:
return torch.float16
elif dtype == np.float32:
return torch.float32
else:
Expand Down
2 changes: 2 additions & 0 deletions examples/gpt2_grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class GPTGridSearch(GridSearch):
def __init__(
self,
backend,
dtype,
use_gpu,
output_file,
device_throughput,
Expand Down Expand Up @@ -38,6 +39,7 @@ def __init__(
super().__init__(
model_params,
backend,
dtype,
use_gpu,
output_file,
device_throughput,
Expand Down
3 changes: 3 additions & 0 deletions examples/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
self,
model_params,
backend,
dtype,
use_gpu,
output_file,
device_throughput,
Expand All @@ -55,6 +56,7 @@ def __init__(
):
self.model_params = model_params
self.backend = backend
self.dtype = dtype
self.use_gpu = use_gpu
self.output_file = output_file
self.device_throughput = device_throughput
Expand Down Expand Up @@ -272,6 +274,7 @@ def run_grid_search(args, grid_search_cls):
}
grid_search = grid_search_cls(
args.backend,
args.dtype,
args.use_gpu,
args.output_file,
args.device_throughput,
Expand Down
106 changes: 71 additions & 35 deletions examples/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@
import torch

from dist_ir.ir import FunctionMaker, Topology, get_uniform_topology, Value
from dist_ir.ir.type import Int32, Float32, Tensor, abstract_values
from dist_ir.executor import CostModel, Simulator, infer_types
from dist_ir.ir.type import Int32, Float16, Float32, Tensor, abstract_values
from dist_ir.executor import (
CostModel,
Simulator,
ConcreteValue,
infer_types,
sequentially_execute,
)
from dist_ir.transforms import mlp_dhp_transform
from .parser import Parser
import dist_ir.backend.torch as torch_backend
Expand Down Expand Up @@ -39,42 +45,40 @@ def get_typed_input_values(inputs, batch_size, input_dim, output_dim):
return tuple(typed_inputs)


def get_input_data(batch_size, dim, num_layers):
x = np.random.normal(size=(batch_size, dim))
z = np.random.normal(size=(batch_size, dim))
n = batch_size
weights = [np.random.normal(size=(dim, dim))]
for i in range(1, num_layers - 1):
weights.append(np.random.normal(size=(dim, dim)))
weights.append(np.random.normal(size=(dim, dim)))
def get_input_data(inputs, batch_size, input_dim, output_dim, device, dtype):
input_data = []
x = np.random.normal(0, 0.02, size=(batch_size, input_dim))
z = np.random.normal(0, 0.02, size=(batch_size, output_dim))
n = np.int64(batch_size)
weights = [np.random.normal(0, 0.02, size=inp.type.shape) for inp in inputs[3:]]
input_data = [x, z, n] + weights
input_data = [
v.astype(np.float32) if i != 2 else v for i, v in enumerate(input_data)
]
input_data = [v.astype(dtype) if i != 2 else v for i, v in enumerate(input_data)]
input_data = [ConcreteValue(v, device) for v in input_data]
assert len(input_data) == len(inputs)
return input_data


def mlp(input_dim, hidden_dim, output_dim, num_hidden_layers, device):
def mlp(input_dim, hidden_dim, output_dim, num_hidden_layers, device, dtype):
function = FunctionMaker(name="mlp")
x = function.add_input_value(
"x",
Tensor(dtype=Float32(), shape=None, device=device),
Tensor(dtype=dtype(), shape=None, device=device),
)
z = function.add_input_value(
"z",
Tensor(dtype=Float32(), shape=None, device=device),
Tensor(dtype=dtype(), shape=None, device=device),
)
n = function.add_input_value("n", Int32(device=device))
weights = []
for i in range(num_hidden_layers - 1):
w = function.add_input_value(
f"w{chr(ord('A')+i)}",
Tensor(dtype=Float32(), shape=(input_dim, hidden_dim), device=device),
Tensor(dtype=dtype(), shape=(input_dim, hidden_dim), device=device),
)
weights.append(w)
w = function.add_input_value(
f"w{chr(ord('A')+i+1)}",
Tensor(dtype=Float32(), shape=(hidden_dim, output_dim), device=device),
Tensor(dtype=dtype(), shape=(hidden_dim, output_dim), device=device),
)
weights.append(w)

Expand Down Expand Up @@ -107,24 +111,24 @@ def mlp(input_dim, hidden_dim, output_dim, num_hidden_layers, device):


def mlp_inference(
batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, device
batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, device, dtype
):
function = FunctionMaker(name="mlp")
weights = []
for i in range(num_hidden_layers - 1):
w = function.add_input_value(
f"w{chr(ord('A')+i)}",
Tensor(dtype=Float32(), shape=(input_dim, hidden_dim), device=device),
Tensor(dtype=dtype(), shape=(input_dim, hidden_dim), device=device),
)
weights.append(w)
w = function.add_input_value(
f"w{chr(ord('A')+i+1)}",
Tensor(dtype=Float32(), shape=(hidden_dim, output_dim), device=device),
Tensor(dtype=dtype(), shape=(hidden_dim, output_dim), device=device),
)
weights.append(w)
x = function.add_input_value(
"x",
Tensor(dtype=Float32(), shape=(batch_size, input_dim), device=device),
Tensor(dtype=dtype(), shape=(batch_size, input_dim), device=device),
)

a = x
Expand All @@ -136,7 +140,7 @@ def mlp_inference(


def mlp_inference_dp(
batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, devices
batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, devices, dtype
):
num_devices = len(devices)
assert batch_size % num_devices == 0
Expand All @@ -147,16 +151,16 @@ def mlp_inference_dp(
for i in range(num_hidden_layers - 1):
weights[i, d] = function.add_input_value(
f"w{chr(ord('A')+i)}_{d.device_id}",
Tensor(dtype=Float32(), shape=(input_dim, hidden_dim), device=d),
Tensor(dtype=dtype(), shape=(input_dim, hidden_dim), device=d),
)
weights[num_hidden_layers - 1, d] = function.add_input_value(
f"w{chr(ord('A')+i+1)}_{d.device_id}",
Tensor(dtype=Float32(), shape=(hidden_dim, output_dim), device=d),
Tensor(dtype=dtype(), shape=(hidden_dim, output_dim), device=d),
)
x[d] = function.add_input_value(
f"x_{d.device_id}",
Tensor(
dtype=Float32(), shape=(batch_size // num_devices, input_dim), device=d
dtype=dtype(), shape=(batch_size // num_devices, input_dim), device=d
),
)

Expand Down Expand Up @@ -274,20 +278,38 @@ def simulate(function, input_types, topology, allreduce_parameters=None):
return simulation


def run_pytorch(function, input_data, world_size, use_gpu=True):
def run_pytorch(function, input_data, world_size, use_gpu=torch.cuda.is_available()):
# TODO: Move this to a utils file
def _resolve_dtype(dtype):
if dtype == np.int32:
return torch.int32
elif dtype == np.int64:
return torch.int64
elif dtype == np.float16:
return torch.float16
elif dtype == np.float32:
return torch.float32
else:
raise NotImplementedError(dtype)

if use_gpu and world_size > torch.cuda.device_count():
raise ValueError(
f"Specified world size is {world_size}, but only "
f"{torch.cuda.device_count()} GPUs available"
)
pytorch_input_data = [
torch.tensor(x.val, dtype=_resolve_dtype(x.val.dtype))
if isinstance(x.val, np.ndarray)
else torch.tensor(x.val, dtype=torch.int32)
for x in input_data
]
input_types = abstract_values(
input_data,
tuple(
Tensor if isinstance(input_data[i].val, np.ndarray) else Int32
for i in range(len(input_data))
),
)
pytorch_input_data = [torch.tensor(x.val, dtype=torch.float32) for x in input_data]
per_rank_outputs, runtimes = torch_backend.run_pytorch(
function,
pytorch_input_data,
Expand All @@ -302,6 +324,7 @@ def run_pytorch(function, input_data, world_size, use_gpu=True):
def run_mlp(
phase,
backend,
dtype,
use_gpu,
batch_size,
input_dim,
Expand All @@ -319,6 +342,8 @@ def run_mlp(
trace_file,
verbose=False,
):
dist_ir_dtype = Float32 if dtype == "fp32" else Float16
numpy_dtype = np.float32 if dtype == "fp32" else np.float16
world_size = dp_degree * hp_degree * pp_degree
topology = get_uniform_topology(
world_size,
Expand All @@ -335,6 +360,7 @@ def run_mlp(
output_dim,
num_hidden_layers,
topology.devices[0],
dist_ir_dtype,
)
elif phase == "inference":
fn = mlp_inference(
Expand All @@ -343,13 +369,24 @@ def run_mlp(
output_dim,
num_hidden_layers,
topology.devices[0],
dist_ir_dtype,
)

if verbose:
parameter_count, model_size, parameter_count_str, model_size_str = get_stats(fn)
print("Parameter count:", parameter_count_str)
print("Model size:", model_size_str)

if backend == "pytorch":
input_data = get_input_data(
fn.inputs,
batch_size,
input_dim,
output_dim,
topology.devices[0],
numpy_dtype,
)

if world_size > 1:
init_fn, transformed_fn = mlp_dhp_transform(
fn,
Expand All @@ -365,13 +402,17 @@ def run_mlp(
init_fn = infer_types(init_fn, typed_inputs)
transformed_fn = infer_types(transformed_fn, init_fn.outputs)
input_types = tuple(output.type for output in init_fn.outputs)
if backend == "pytorch":
transformed_input_data = sequentially_execute(init_fn, input_data)
else:
typed_inputs = get_typed_input_values(
fn.inputs, batch_size, input_dim, output_dim
)
fn = infer_types(fn, typed_inputs)
transformed_fn = fn
input_types = tuple(inp.type for inp in fn.inputs)
if backend == "pytorch":
transformed_input_data = input_data
transformed_fn = add_optimizer_ops(transformed_fn)
if backend == "simulate":
simulation = simulate(transformed_fn, input_types, topology)
Expand All @@ -381,19 +422,14 @@ def run_mlp(
simulation.dump_chrome_trace(trace_file)
return simulation
elif backend == "pytorch":
input_data = [
ConcreteValue(
np.random.normal(size=typ.size).astype(np.float32), device=typ.device
)
for typ in input_types
]
return run_pytorch(fn, input_data, world_size, use_gpu)
return run_pytorch(transformed_fn, transformed_input_data, world_size, use_gpu)


def main(args):
run_mlp(
args.phase,
args.backend,
args.dtype,
args.use_gpu,
args.batch_size,
args.input_dim,
Expand Down
25 changes: 19 additions & 6 deletions examples/mlp_grid_search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np

from dist_ir.ir import Value
from dist_ir.ir.type import Tensor
from dist_ir.ir.type import Tensor, Float32, Float16
from dist_ir.executor import infer_types, sequentially_execute, ConcreteValue
from dist_ir.transforms import mlp_dhp_transform
from . import mlp
Expand All @@ -11,6 +13,7 @@ class MLPGridSearch(GridSearch):
def __init__(
self,
backend,
dtype,
use_gpu,
output_file,
device_throughput,
Expand All @@ -30,6 +33,7 @@ def __init__(
super().__init__(
model_params,
backend,
dtype,
use_gpu,
output_file,
device_throughput,
Expand All @@ -46,16 +50,25 @@ def get_model_and_input_data(self, batch_size, model_size):
if model_size not in self.models:
num_layers, dim = self.model_params[model_size]
self.models[model_size] = mlp.mlp(
dim, dim, dim, num_layers, self.topology.devices[0]
dim,
dim,
dim,
num_layers,
self.topology.devices[0],
Float32 if self.dtype == "fp32" else Float16,
)

fn = self.models[model_size]
num_layers, dim = self.model_params[model_size]
if self.backend == "pytorch":
input_data = mlp.get_input_data(batch_size, dim, num_layers)
input_data = tuple(
ConcreteValue(t, inp.type.device)
for t, inp in zip(input_data, fn.inputs)
dtype = np.float32 if self.dtype == "fp32" else np.float16
input_data = mlp.get_input_data(
fn.inputs,
batch_size,
dim,
dim,
self.topology.devices[0],
dtype,
)
else:
input_data = mlp.get_typed_input_values(fn.inputs, batch_size, dim, dim)
Expand Down
Loading

0 comments on commit a5d54aa

Please sign in to comment.