From a5d54aa1b02ea5eaeadf391330d7c906f2932b5c Mon Sep 17 00:00:00 2001
From: Keshav Santhanam
Date: Sun, 19 Sep 2021 21:30:50 -0700
Subject: [PATCH] Add fp16 support

---
 dist_ir/backend/torch.py       |   4 +-
 examples/gpt2.py               |   2 +
 examples/gpt2_grid_search.py   |   2 +
 examples/grid_search.py        |   3 +
 examples/mlp.py                | 106 ++++++++++++++++++++++-----------
 examples/mlp_grid_search.py    |  25 ++++++--
 examples/parser.py             |   5 +-
 test/test_grid_search.py       |  48 ++++++++++++---
 test/test_mlp_dhp_transform.py |  25 ++++----
 test/test_pytorch_backend.py   |  62 ++++++++++++-----
 10 files changed, 204 insertions(+), 78 deletions(-)

diff --git a/dist_ir/backend/torch.py b/dist_ir/backend/torch.py
index ccba2ab7..c0fb78f3 100644
--- a/dist_ir/backend/torch.py
+++ b/dist_ir/backend/torch.py
@@ -14,7 +14,7 @@
 from ..executor.rank_projector import project
 from ..ir import Function, cpprint
 from ..ir.device import Device
-from ..ir.type import Int32, Int64, Float32, Type
+from ..ir.type import Int32, Int64, Float16, Float32, Type
 
 # NOTE: The code currently suffers from this issue, more investigation needed:
 # https://github.com/pytorch/pytorch/issues/11201
@@ -166,6 +166,8 @@ def _recv(shape=None, from_d=None, group=None, dtype=None, ctx=None):
         x = torch.zeros(shape).int()
     elif isinstance(dtype, Int64):
         x = torch.zeros(shape).long()
+    elif isinstance(dtype, Float16):
+        x = torch.zeros(shape).half()
     elif isinstance(dtype, Float32):
         x = torch.zeros(shape).float()
     else:
diff --git a/examples/gpt2.py b/examples/gpt2.py
index 9cffee2f..d2d2d949 100644
--- a/examples/gpt2.py
+++ b/examples/gpt2.py
@@ -561,6 +561,8 @@ def _resolve_dtype(dtype):
         return torch.int32
     elif dtype == np.int64:
         return torch.int64
+    elif dtype == np.float16:
+        return torch.float16
     elif dtype == np.float32:
         return torch.float32
     else:
diff --git a/examples/gpt2_grid_search.py b/examples/gpt2_grid_search.py
index 8f085d12..c88ccf9d 100644
--- a/examples/gpt2_grid_search.py
+++ b/examples/gpt2_grid_search.py
@@ -10,6 +10,7 @@ class GPTGridSearch(GridSearch):
     def __init__(
         self,
         backend,
+        dtype,
         use_gpu,
         output_file,
         device_throughput,
@@ -38,6 +39,7 @@ def __init__(
         super().__init__(
             model_params,
             backend,
+            dtype,
             use_gpu,
             output_file,
             device_throughput,
diff --git a/examples/grid_search.py b/examples/grid_search.py
index 47b0dd75..e475658b 100644
--- a/examples/grid_search.py
+++ b/examples/grid_search.py
@@ -43,6 +43,7 @@ def __init__(
         self,
         model_params,
         backend,
+        dtype,
         use_gpu,
         output_file,
         device_throughput,
@@ -55,6 +56,7 @@ def __init__(
     ):
         self.model_params = model_params
         self.backend = backend
+        self.dtype = dtype
         self.use_gpu = use_gpu
         self.output_file = output_file
         self.device_throughput = device_throughput
@@ -272,6 +274,7 @@ def run_grid_search(args, grid_search_cls):
     }
     grid_search = grid_search_cls(
         args.backend,
+        args.dtype,
         args.use_gpu,
         args.output_file,
         args.device_throughput,
diff --git a/examples/mlp.py b/examples/mlp.py
index 949c2b9a..0346d59e 100644
--- a/examples/mlp.py
+++ b/examples/mlp.py
@@ -5,8 +5,14 @@
 import torch
 
 from dist_ir.ir import FunctionMaker, Topology, get_uniform_topology, Value
-from dist_ir.ir.type import Int32, Float32, Tensor, abstract_values
-from dist_ir.executor import CostModel, Simulator, infer_types
+from dist_ir.ir.type import Int32, Float16, Float32, Tensor, abstract_values
+from dist_ir.executor import (
+    CostModel,
+    Simulator,
+    ConcreteValue,
+    infer_types,
+    sequentially_execute,
+)
 from dist_ir.transforms import mlp_dhp_transform
 from .parser import Parser
 import dist_ir.backend.torch as torch_backend
@@ -39,42 +45,40 @@ def get_typed_input_values(inputs, batch_size, input_dim, output_dim):
     return tuple(typed_inputs)
 
 
-def get_input_data(batch_size, dim, num_layers):
-    x = np.random.normal(size=(batch_size, dim))
-    z = np.random.normal(size=(batch_size, dim))
-    n = batch_size
-    weights = [np.random.normal(size=(dim, dim))]
-    for i in range(1, num_layers - 1):
-        weights.append(np.random.normal(size=(dim, dim)))
-    weights.append(np.random.normal(size=(dim, dim)))
+def get_input_data(inputs, batch_size, input_dim, output_dim, device, dtype):
+    input_data = []
+    x = np.random.normal(0, 0.02, size=(batch_size, input_dim))
+    z = np.random.normal(0, 0.02, size=(batch_size, output_dim))
+    n = np.int64(batch_size)
+    weights = [np.random.normal(0, 0.02, size=inp.type.shape) for inp in inputs[3:]]
     input_data = [x, z, n] + weights
-    input_data = [
-        v.astype(np.float32) if i != 2 else v for i, v in enumerate(input_data)
-    ]
+    input_data = [v.astype(dtype) if i != 2 else v for i, v in enumerate(input_data)]
+    input_data = [ConcreteValue(v, device) for v in input_data]
+    assert len(input_data) == len(inputs)
     return input_data
 
 
-def mlp(input_dim, hidden_dim, output_dim, num_hidden_layers, device):
+def mlp(input_dim, hidden_dim, output_dim, num_hidden_layers, device, dtype):
     function = FunctionMaker(name="mlp")
     x = function.add_input_value(
         "x",
-        Tensor(dtype=Float32(), shape=None, device=device),
+        Tensor(dtype=dtype(), shape=None, device=device),
     )
     z = function.add_input_value(
         "z",
-        Tensor(dtype=Float32(), shape=None, device=device),
+        Tensor(dtype=dtype(), shape=None, device=device),
     )
     n = function.add_input_value("n", Int32(device=device))
     weights = []
     for i in range(num_hidden_layers - 1):
         w = function.add_input_value(
             f"w{chr(ord('A')+i)}",
-            Tensor(dtype=Float32(), shape=(input_dim, hidden_dim), device=device),
+            Tensor(dtype=dtype(), shape=(input_dim, hidden_dim), device=device),
         )
         weights.append(w)
     w = function.add_input_value(
         f"w{chr(ord('A')+i+1)}",
-        Tensor(dtype=Float32(), shape=(hidden_dim, output_dim), device=device),
+        Tensor(dtype=dtype(), shape=(hidden_dim, output_dim), device=device),
     )
     weights.append(w)
@@ -107,24 +111,24 @@ def mlp(input_dim, hidden_dim, output_dim, num_hidden_layers, device):
 
 
 def mlp_inference(
-    batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, device
+    batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, device, dtype
 ):
     function = FunctionMaker(name="mlp")
     weights = []
     for i in range(num_hidden_layers - 1):
         w = function.add_input_value(
             f"w{chr(ord('A')+i)}",
-            Tensor(dtype=Float32(), shape=(input_dim, hidden_dim), device=device),
+            Tensor(dtype=dtype(), shape=(input_dim, hidden_dim), device=device),
         )
         weights.append(w)
     w = function.add_input_value(
         f"w{chr(ord('A')+i+1)}",
-        Tensor(dtype=Float32(), shape=(hidden_dim, output_dim), device=device),
+        Tensor(dtype=dtype(), shape=(hidden_dim, output_dim), device=device),
     )
     weights.append(w)
     x = function.add_input_value(
         "x",
-        Tensor(dtype=Float32(), shape=(batch_size, input_dim), device=device),
+        Tensor(dtype=dtype(), shape=(batch_size, input_dim), device=device),
     )
 
     a = x
@@ -136,7 +140,7 @@ def mlp_inference(
 
 
 def mlp_inference_dp(
-    batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, devices
+    batch_size, input_dim, hidden_dim, output_dim, num_hidden_layers, devices, dtype
 ):
     num_devices = len(devices)
     assert batch_size % num_devices == 0
@@ -147,16 +151,16 @@
         for i in range(num_hidden_layers - 1):
             weights[i, d] = function.add_input_value(
                f"w{chr(ord('A')+i)}_{d.device_id}",
-                Tensor(dtype=Float32(), shape=(input_dim, hidden_dim), device=d),
+                Tensor(dtype=dtype(), shape=(input_dim, hidden_dim), device=d),
             )
         weights[num_hidden_layers - 1, d] = function.add_input_value(
             f"w{chr(ord('A')+i+1)}_{d.device_id}",
-            Tensor(dtype=Float32(), shape=(hidden_dim, output_dim), device=d),
+            Tensor(dtype=dtype(), shape=(hidden_dim, output_dim), device=d),
         )
         x[d] = function.add_input_value(
             f"x_{d.device_id}",
             Tensor(
-                dtype=Float32(), shape=(batch_size // num_devices, input_dim), device=d
+                dtype=dtype(), shape=(batch_size // num_devices, input_dim), device=d
             ),
         )
@@ -274,12 +278,31 @@ def simulate(function, input_types, topology, allreduce_parameters=None):
     return simulation
 
 
-def run_pytorch(function, input_data, world_size, use_gpu=True):
+def run_pytorch(function, input_data, world_size, use_gpu=torch.cuda.is_available()):
+    # TODO: Move this to a utils file
+    def _resolve_dtype(dtype):
+        if dtype == np.int32:
+            return torch.int32
+        elif dtype == np.int64:
+            return torch.int64
+        elif dtype == np.float16:
+            return torch.float16
+        elif dtype == np.float32:
+            return torch.float32
+        else:
+            raise NotImplementedError(dtype)
+
     if use_gpu and world_size > torch.cuda.device_count():
         raise ValueError(
             f"Specified world size is {world_size}, but only "
             f"{torch.cuda.device_count()} GPUs available"
         )
+    pytorch_input_data = [
+        torch.tensor(x.val, dtype=_resolve_dtype(x.val.dtype))
+        if isinstance(x.val, np.ndarray)
+        else torch.tensor(x.val, dtype=torch.int32)
+        for x in input_data
+    ]
     input_types = abstract_values(
         input_data,
         tuple(
@@ -287,7 +310,6 @@ def run_pytorch(function, input_data, world_size, use_gpu=True):
             for i in range(len(input_data))
         ),
     )
-    pytorch_input_data = [torch.tensor(x.val, dtype=torch.float32) for x in input_data]
     per_rank_outputs, runtimes = torch_backend.run_pytorch(
         function,
         pytorch_input_data,
@@ -302,6 +324,7 @@
 def run_mlp(
     phase,
     backend,
+    dtype,
     use_gpu,
     batch_size,
     input_dim,
@@ -319,6 +342,8 @@ def run_mlp(
     trace_file,
     verbose=False,
 ):
+    dist_ir_dtype = Float32 if dtype == "fp32" else Float16
+    numpy_dtype = np.float32 if dtype == "fp32" else np.float16
     world_size = dp_degree * hp_degree * pp_degree
     topology = get_uniform_topology(
         world_size,
@@ -335,6 +360,7 @@ def run_mlp(
             output_dim,
             num_hidden_layers,
             topology.devices[0],
+            dist_ir_dtype,
         )
     elif phase == "inference":
         fn = mlp_inference(
@@ -343,6 +369,7 @@ def run_mlp(
             output_dim,
             num_hidden_layers,
             topology.devices[0],
+            dist_ir_dtype,
         )
 
     if verbose:
@@ -350,6 +377,16 @@ def run_mlp(
         print("Parameter count:", parameter_count_str)
         print("Model size:", model_size_str)
 
+    if backend == "pytorch":
+        input_data = get_input_data(
+            fn.inputs,
+            batch_size,
+            input_dim,
+            output_dim,
+            topology.devices[0],
+            numpy_dtype,
+        )
+
     if world_size > 1:
         init_fn, transformed_fn = mlp_dhp_transform(
@@ -365,6 +402,8 @@ def run_mlp(
         init_fn = infer_types(init_fn, typed_inputs)
         transformed_fn = infer_types(transformed_fn, init_fn.outputs)
         input_types = tuple(output.type for output in init_fn.outputs)
+        if backend == "pytorch":
+            transformed_input_data = sequentially_execute(init_fn, input_data)
     else:
         typed_inputs = get_typed_input_values(
             fn.inputs, batch_size, input_dim, output_dim
@@ -372,6 +411,8 @@ def run_mlp(
         fn = infer_types(fn, typed_inputs)
         transformed_fn = fn
         input_types = tuple(inp.type for inp in fn.inputs)
+        if backend == "pytorch":
+            transformed_input_data = input_data
     transformed_fn = add_optimizer_ops(transformed_fn)
     if backend == "simulate":
         simulation = simulate(transformed_fn, input_types, topology)
@@ -381,19 +422,14 @@ def run_mlp(
             simulation.dump_chrome_trace(trace_file)
         return simulation
     elif backend == "pytorch":
-        input_data = [
-            ConcreteValue(
-                np.random.normal(size=typ.size).astype(np.float32), device=typ.device
-            )
-            for typ in input_types
-        ]
-        return run_pytorch(fn, input_data, world_size, use_gpu)
+        return run_pytorch(transformed_fn, transformed_input_data, world_size, use_gpu)
 
 
 def main(args):
     run_mlp(
         args.phase,
         args.backend,
+        args.dtype,
         args.use_gpu,
         args.batch_size,
         args.input_dim,
diff --git a/examples/mlp_grid_search.py b/examples/mlp_grid_search.py
index c138f188..9668c2d5 100644
--- a/examples/mlp_grid_search.py
+++ b/examples/mlp_grid_search.py
@@ -1,5 +1,7 @@
+import numpy as np
+
 from dist_ir.ir import Value
-from dist_ir.ir.type import Tensor
+from dist_ir.ir.type import Tensor, Float32, Float16
 from dist_ir.executor import infer_types, sequentially_execute, ConcreteValue
 from dist_ir.transforms import mlp_dhp_transform
 from . import mlp
@@ -11,6 +13,7 @@ class MLPGridSearch(GridSearch):
     def __init__(
         self,
         backend,
+        dtype,
         use_gpu,
         output_file,
         device_throughput,
@@ -30,6 +33,7 @@ def __init__(
         super().__init__(
             model_params,
             backend,
+            dtype,
             use_gpu,
             output_file,
             device_throughput,
@@ -46,16 +50,25 @@ def get_model_and_input_data(self, batch_size, model_size):
         if model_size not in self.models:
             num_layers, dim = self.model_params[model_size]
             self.models[model_size] = mlp.mlp(
-                dim, dim, dim, num_layers, self.topology.devices[0]
+                dim,
+                dim,
+                dim,
+                num_layers,
+                self.topology.devices[0],
+                Float32 if self.dtype == "fp32" else Float16,
             )
         fn = self.models[model_size]
         num_layers, dim = self.model_params[model_size]
         if self.backend == "pytorch":
-            input_data = mlp.get_input_data(batch_size, dim, num_layers)
-            input_data = tuple(
-                ConcreteValue(t, inp.type.device)
-                for t, inp in zip(input_data, fn.inputs)
+            dtype = np.float32 if self.dtype == "fp32" else np.float16
+            input_data = mlp.get_input_data(
+                fn.inputs,
+                batch_size,
+                dim,
+                dim,
+                self.topology.devices[0],
+                dtype,
             )
         else:
             input_data = mlp.get_typed_input_values(fn.inputs, batch_size, dim, dim)
diff --git a/examples/parser.py b/examples/parser.py
index f0567d90..a3e036f5 100644
--- a/examples/parser.py
+++ b/examples/parser.py
@@ -75,6 +75,7 @@ def add_simulation_config_arguments(self):
 
     def add_execution_mode_config_arguments(self):
         self.add_argument("--backend", choices=["simulate", "pytorch"], required=True)
+        self.add_argument("--dtype", choices=["fp32", "fp16"], default="fp32")
 
     def add_simulation_output_config_arguments(self):
         self.add_argument("--trace_file", type=str, default=None, help="Trace file")
@@ -186,7 +187,3 @@ def add_gpt2_model_path_config_arguments(self):
             "text/machine_comprehension/gpt-2/model/gpt2-10.onnx?raw=True)"
         ),
     )
-
-    def add_calibration_arguments(self):
-        # TODO: Add for simulator accuracy
-        pass
diff --git a/test/test_grid_search.py b/test/test_grid_search.py
index 0ae951c8..50fb8633 100644
--- a/test/test_grid_search.py
+++ b/test/test_grid_search.py
@@ -17,11 +17,22 @@
 
 
 @pytest.mark.parametrize(
-    ("backend"),
-    ["simulate", "pytorch"],
+    "backend, dtype",
+    [
+        ("simulate", "fp32"),
+        ("simulate", "fp16"),
+        ("pytorch", "fp32"),
+        pytest.param(
+            "pytorch",
+            "fp16",
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="fp16 only available on GPU"
+            ),
+        ),
+    ],
 )
-def test_mlp_grid_search(backend):
-    all_world_sizes = [1, 2, 4]
+def test_mlp_grid_search(backend, dtype):
+    all_world_sizes = [1, 2]
     all_batch_sizes = [256]
     all_model_sizes = ["mlp-xs"]
     with tempfile.NamedTemporaryFile() as tf:
@@ -30,6 +41,7 @@ def test_mlp_grid_search(backend):
         writer.writeheader()
         grid_search = MLPGridSearch(
             backend=backend,
+            dtype=dtype,
             use_gpu=torch.cuda.is_available(),
             output_file=tf.name,
             device_throughput=constants.DEFAULT_DEVICE_THROUGHPUT,
@@ -56,6 +68,7 @@ def test_mlp_grid_search(backend):
             simulation = mlp.run_mlp(
                 phase="training",
                 backend="simulate",
+                dtype=dtype,
                 use_gpu=False,
                 batch_size=all_batch_sizes[0],
                 input_dim=dim,
@@ -88,11 +101,22 @@ def test_mlp_grid_search(backend):
 
 
 @pytest.mark.parametrize(
-    ("backend"),
-    ["simulate", "pytorch"],
+    "backend, dtype",
+    [
+        ("simulate", "fp32"),
+        ("simulate", "fp16"),
+        ("pytorch", "fp32"),
+        pytest.param(
+            "pytorch",
+            "fp16",
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="fp16 only available on GPU"
+            ),
+        ),
+    ],
 )
-def test_gpt_grid_search(backend):
-    all_world_sizes = [1, 2, 4]
+def test_gpt_grid_search(backend, dtype):
+    all_world_sizes = [1, 2]
     all_batch_sizes = [64]
     all_model_sizes = ["gpt2-xs"]
     with tempfile.NamedTemporaryFile() as tf:
@@ -101,6 +125,7 @@ def test_gpt_grid_search(backend):
         writer.writeheader()
         grid_search = GPTGridSearch(
             backend=backend,
+            dtype=dtype,
             use_gpu=torch.cuda.is_available(),
             output_file=tf.name,
             device_throughput=constants.DEFAULT_DEVICE_THROUGHPUT,
@@ -159,3 +184,10 @@ def test_gpt_grid_search(backend):
                 & (df["num_microbatches"] == p)
             ]["latency"].values[0]
             assert math.isclose(latency, grid_search_latency, abs_tol=10 ** -8)
+
+
+if __name__ == "__main__":
+    print("MLP fp32")
+    test_mlp_grid_search("pytorch", "fp32")
+    print("MLP fp16")
+    test_mlp_grid_search("pytorch", "fp16")
diff --git a/test/test_mlp_dhp_transform.py b/test/test_mlp_dhp_transform.py
index bc9ebfc1..1a969a74 100644
--- a/test/test_mlp_dhp_transform.py
+++ b/test/test_mlp_dhp_transform.py
@@ -6,8 +6,8 @@
 from examples import mlp
 
 from dist_ir.ir import FunctionMaker, get_uniform_topology
+from dist_ir.ir.type import Float32, Float16
 from dist_ir.executor import infer_types, sequentially_execute, ConcreteValue
-from dist_ir.ir.type import Float32, Tensor
 from dist_ir.transforms import mlp_dhp_transform
 
 BATCH_SIZE = 64
@@ -51,19 +51,22 @@ def _verify_hp(function, transformed_function, outputs, transformed_outputs, dp=
 
 
 @pytest.mark.parametrize(
-    ("dp_degree", "hp_degree", "pp_degree"),
-    list(itertools.product([1, 2], [1, 2], [1, 2])),
+    ("dp_degree", "hp_degree", "pp_degree", "dtype"),
+    list(itertools.product([1, 2], [1, 2], [1, 2], ["fp32", "fp16"])),
 )
 def test_mlp_dhp_transform(
     dp_degree,
     hp_degree,
     pp_degree,
+    dtype,
     batch_size=BATCH_SIZE,
     num_hidden_layers=8,
     input_dim=INPUT_DIM,
 ):
     num_microbatches = pp_degree
     world_size = dp_degree * hp_degree * pp_degree
+    dist_ir_dtype = Float32 if dtype == "fp32" else Float16
+    numpy_dtype = np.float32 if dtype == "fp32" else np.float16
     topology = get_uniform_topology(world_size)
     function = mlp.mlp(
         input_dim,
@@ -71,6 +74,7 @@ def test_mlp_dhp_transform(
         input_dim,
         num_hidden_layers,
         topology.devices[0],
+        dist_ir_dtype,
     )
     typed_inputs = mlp.get_typed_input_values(
         function.inputs, batch_size, input_dim, input_dim
@@ -90,13 +94,14 @@ def test_mlp_dhp_transform(
     transformed_function = infer_types(transformed_function, init_function.outputs)
     transformed_function = mlp.add_optimizer_ops(transformed_function)
 
-    input_data = [
-        ConcreteValue(
-            np.random.normal(size=inp.type.shape) if i != 2 else batch_size,
-            topology.devices[0],
-        )
-        for i, inp in enumerate(typed_inputs)
-    ]
+    input_data = mlp.get_input_data(
+        init_function.inputs,
+        batch_size,
+        input_dim,
+        input_dim,
+        topology.devices[0],
+        numpy_dtype,
+    )
     outputs = sequentially_execute(function, input_data)
     dist_input_data = sequentially_execute(init_function, input_data)
     transformed_outputs = sequentially_execute(transformed_function, dist_input_data)
diff --git a/test/test_pytorch_backend.py b/test/test_pytorch_backend.py
index 9e01b98c..cc8578e7 100644
--- a/test/test_pytorch_backend.py
+++ b/test/test_pytorch_backend.py
@@ -9,7 +9,7 @@
 from dist_ir.executor.simulator import Simulator
 from dist_ir.executor.type_inference import infer_types
 from dist_ir.ir import Device, FunctionMaker, cpprint, Value
-from dist_ir.ir.type import Float32, Tensor
+from dist_ir.ir.type import Float16, Float32, Tensor
 from dist_ir.ir.topology import Topology, get_uniform_topology
 
 # TODO make examples submodule of dist_ir?
@@ -79,20 +79,34 @@ def create_owt_model(num_devices, num_layers):
 
 
 @pytest.mark.parametrize(
-    "num_devices, num_layers, use_gpu",
+    "num_devices, num_layers, use_gpu, dtype",
     [
-        (2, 4, False),
+        (2, 4, False, "fp32"),
         pytest.param(
             2,
             4,
             True,
+            "fp32",
             marks=pytest.mark.skipif(
                 torch.cuda.device_count() < 2, reason="Not enough available GPUs"
             ),
         ),
+        pytest.param(
+            2,
+            4,
+            True,
+            "fp16",
+            marks=pytest.mark.skipif(
+                torch.cuda.device_count() < 2, reason="Not enough available GPUs"
+            ),
+        ),
     ],
 )
-def test_owt(num_devices, num_layers, use_gpu):
+def test_owt(num_devices, num_layers, use_gpu, dtype):
+    dist_ir_dtype = Float32 if dtype == "fp32" else Float16
+    numpy_dtype = np.float32 if dtype == "fp32" else np.float16
+    torch_dtype = torch.float32 if dtype == "fp32" else torch.float16
+
     fn = create_owt_model(num_devices, num_layers)
 
     devices = [Device(0, "cpu")]
@@ -110,11 +124,11 @@ def test_owt(num_devices, num_layers, use_gpu):
         else:
             shape = (hidden_dim, hidden_dim // num_devices)
         # w{l}_{d}:
-        input_vals.append(Value("", Tensor(Float32(), shape, devices[d])))
+        input_vals.append(Value("", Tensor(dist_ir_dtype(), shape, devices[d])))
     for d in range(1, num_devices + 1):
         # x_{d}:
         shape = (batch_size // num_devices, hidden_dim)
-        input_vals.append(Value("", Tensor(Float32(), shape, devices[d])))
+        input_vals.append(Value("", Tensor(dist_ir_dtype(), shape, devices[d])))
 
     # Test type inference:
     fn = infer_types(fn, input_vals)
@@ -125,8 +139,11 @@ def test_owt(num_devices, num_layers, use_gpu):
 
     # Test with sequential executor:
     np.random.seed(0)
-    weights = [np.random.randn(hidden_dim, hidden_dim) for l in range(num_layers)]
-    x = np.random.randn(batch_size, hidden_dim)
+    weights = [
+        np.random.normal(0, 0.02, size=(hidden_dim, hidden_dim)).astype(numpy_dtype)
+        for l in range(num_layers)
+    ]
+    x = np.random.normal(0, 0.02, size=(batch_size, hidden_dim)).astype(numpy_dtype)
 
     # Split inputs for distributed function
     input_arrays = []
@@ -150,7 +167,7 @@ def test_owt(num_devices, num_layers, use_gpu):
 
     # Run per-rank modules using PyTorch backend:
     per_rank_outputs, _ = run_pytorch(
-        fn, [torch.tensor(a) for a in input_arrays], use_gpu=use_gpu
+        fn, [torch.tensor(a).to(torch_dtype) for a in input_arrays], use_gpu=use_gpu
     )
 
     # Check outputs:
@@ -249,18 +266,27 @@ def test_send_recv(use_gpu):
 
 
 @pytest.mark.parametrize(
-    "use_gpu",
+    "use_gpu, dtype",
     [
-        False,
+        (False, "fp32"),
+        (False, "fp16"),
+        pytest.param(
+            True,
+            "fp16",
+            marks=pytest.mark.skipif(
+                torch.cuda.device_count() < 2, reason="Not enough available GPUs"
+            ),
+        ),
         pytest.param(
             True,
+            "fp32",
             marks=pytest.mark.skipif(
                 torch.cuda.device_count() < 2, reason="Not enough available GPUs"
             ),
         ),
     ],
 )
-def test_dp_mlp(use_gpu):
+def test_dp_mlp(use_gpu, dtype):
     num_devices = 2
     num_layers = 4
     batch_size = 4
@@ -268,7 +294,13 @@ def test_dp_mlp(use_gpu):
     devices = [Device(d, "gpu") for d in range(num_devices + 1)]
 
     fn = mlp_inference_dp(
-        batch_size, hidden_dim, hidden_dim, hidden_dim, num_layers, devices[1:]
+        batch_size,
+        hidden_dim,
+        hidden_dim,
+        hidden_dim,
+        num_layers,
+        devices[1:],
+        Float32 if dtype == "fp32" else Float16,
     )
     fn = infer_types(fn, fn.inputs)
     cpprint(fn)
@@ -293,7 +325,9 @@ def new_inputs():
 
     # Project and run on backend:
     per_rank_outputs, runtimes = run_pytorch(
-        fn, convert_inputs_dp(weights, x), use_gpu=use_gpu
+        fn,
+        convert_inputs_dp(weights, x),
+        use_gpu=use_gpu,
    )

    # Check outputs:
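
Usage note: the fp16 path is driven end to end by the new --dtype flag registered
in examples/parser.py (choices "fp32"/"fp16", default "fp32"). A minimal sketch of
how to exercise it, assuming the examples run as modules and keep their existing
flag names (the exact set of additional required flags, e.g. model sizes and
parallelism degrees, is not shown in this diff):

    # Simulator backend in fp16 (no GPU needed):
    python -m examples.mlp --phase training --backend simulate --dtype fp16

    # PyTorch backend in fp16; per the new test skip marks, fp16 is expected
    # to run on GPU:
    python -m examples.mlp --phase training --backend pytorch --dtype fp16

    # Run only the new fp16 test parametrizations; on a machine without CUDA
    # the GPU-marked cases are skipped rather than failed:
    pytest test/test_mlp_dhp_transform.py test/test_grid_search.py -k fp16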
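For programmatic use, the patch threads a DistIR dtype class (not an instance)
through the model builders, which instantiate it internally as dtype(). A minimal
sketch with illustrative sizes (the builder signature is taken from this diff;
the dimensions are arbitrary assumptions):

    from dist_ir.ir import get_uniform_topology
    from dist_ir.ir.type import Float16
    from examples import mlp

    topology = get_uniform_topology(1)  # world_size=1, as in the tests
    # mlp.mlp(input_dim, hidden_dim, output_dim, num_hidden_layers, device, dtype)
    fn = mlp.mlp(16, 16, 16, 4, topology.devices[0], Float16)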