Add ONNX exporter #26

Draft · wants to merge 63 commits into base: main

Commits (63)
8ac2c50
Projector: DistIR function -> per-rank function
siddharth-krishna Apr 19, 2021
193dc84
Fix bug: constructing variadic ops when specifying output_values
siddharth-krishna Apr 19, 2021
c5bfbd9
Fix sequential executor docstring
siddharth-krishna Apr 19, 2021
d2f9e0a
A distributed PyTorch backend
siddharth-krishna Apr 25, 2021
d9cc787
Upgrade PyTorch version
siddharth-krishna Apr 25, 2021
a2a4cc2
Add test: one-weird-trick (matmul version)
siddharth-krishna Apr 25, 2021
7ab770d
Refactor run_multiprocess
siddharth-krishna Apr 25, 2021
4e37ce3
Make run_multiprocess take Functions not nn.Modules
siddharth-krishna Apr 25, 2021
9576811
End-of-file newlines
siddharth-krishna Apr 25, 2021
c4ca1a0
Black
siddharth-krishna Apr 27, 2021
a095665
Parametrize pytest test_owt
siddharth-krishna Apr 27, 2021
3279ca1
Revert "Fix bug: constructing variadic ops when specifying output_val…
siddharth-krishna Apr 27, 2021
c84c61d
Fix Op constructor: better handling of variadic ops/pre-created outputs
siddharth-krishna Apr 27, 2021
bb604f6
Add support for Relu
siddharth-krishna Apr 28, 2021
8d8bbda
Add DP test
siddharth-krishna Apr 29, 2021
42751da
Backend: run and time on GPU
siddharth-krishna Apr 30, 2021
f73d33b
Refactor grid_search for interactive use
siddharth-krishna Apr 30, 2021
20bae75
Timing code for CPUs
siddharth-krishna May 2, 2021
8e8513d
Handle ops with multiple outputs
siddharth-krishna May 2, 2021
6f3ae59
Add support for all MLP training ops
siddharth-krishna May 2, 2021
49d385c
Type inference: fix function name bug
siddharth-krishna May 2, 2021
2cce7c3
Prettyprint: support for printing FunctionMakers
siddharth-krishna May 2, 2021
91f5ecf
Fix backend op implementations
siddharth-krishna May 2, 2021
172a5db
Convert per-rank fns to Modules inside each thread
siddharth-krishna May 2, 2021
2ed3cf7
Per-rank projector: remove types
siddharth-krishna May 2, 2021
7eee9f9
Add some more tests
siddharth-krishna May 2, 2021
78da37d
Default number of repetitions = 1
siddharth-krishna May 2, 2021
56e05dc
DHP transform: return separate init_fn and transformed fn
siddharth-krishna May 2, 2021
78bc1e9
Run pytest with root dir included in PYTHONPATH
siddharth-krishna May 2, 2021
8a288b6
Grid search: remove unused imports
siddharth-krishna May 2, 2021
4b2b301
Revert unintended changes
siddharth-krishna May 2, 2021
75ec41a
Move run_pytorch to backend.torch
siddharth-krishna May 2, 2021
6dbf57c
Interpret Function instead of creating fx.Graph
siddharth-krishna May 5, 2021
4a77383
Prettyprint attributes as Python-style kwargs
siddharth-krishna May 5, 2021
e8735c3
Use broadcast with pairwise groups for send/recv on GPUs
siddharth-krishna May 5, 2021
fd2e7b1
Move new tensors to GPU in each op, outputs back to CPU
siddharth-krishna May 5, 2021
31031cc
Add a mock multiprocess backend for debugging
siddharth-krishna May 6, 2021
5d99a62
Use spawn start method for multiprocessing
siddharth-krishna May 6, 2021
2c8852a
Fix MLP DHP tests
siddharth-krishna May 6, 2021
1d54fea
Revert "Fix MLP DHP tests"
siddharth-krishna May 6, 2021
18eaa08
Fix MLP DHP tests for real
siddharth-krishna May 6, 2021
0ce2fdb
Add code to plot grid search results
siddharth-krishna May 7, 2021
5588b87
Don't use globals while multiprocessing
siddharth-krishna May 7, 2021
7d714d4
Fix mock backend, use_gpu=False by default
siddharth-krishna May 7, 2021
8646842
Partial grid search on 4 devices
siddharth-krishna May 7, 2021
1db0736
Support collectives between a subset of ranks
siddharth-krishna May 10, 2021
a79ab67
Bug fixes for distributed process groups (#24)
santhnm2 May 11, 2021
0c0f7f5
Debugging MLP deadlock
siddharth-krishna May 13, 2021
33cefce
Fix collective projector
santhnm2 May 14, 2021
a5c8b34
Enable grid search test again
siddharth-krishna May 13, 2021
ca68ec4
Some comments, debugging code, cuda sync earlier
siddharth-krishna May 20, 2021
b31209d
Projector for gather
siddharth-krishna May 20, 2021
54ce0d2
Don't save input/outputs to file in torch backend
siddharth-krishna May 25, 2021
5d1b63f
Remove unnecessary global
siddharth-krishna May 25, 2021
3466627
Free tensors after use
siddharth-krishna May 25, 2021
7d16c90
Map DistIR devices to pytorch backend ranks
siddharth-krishna May 26, 2021
8c4f5e1
Fix tests
siddharth-krishna May 26, 2021
79d4a65
Some documentation and cleanup
siddharth-krishna May 26, 2021
da7ff5d
Fix comment
siddharth-krishna May 26, 2021
72589b2
Remove experiment code and dead code
siddharth-krishna May 26, 2021
849ab4e
Add ONNX exporter
santhnm2 Jun 1, 2021
054ce0a
Remove singleton types
santhnm2 Jun 1, 2021
fab0020
Add test to verify that out-of-opset ops throw error
santhnm2 Jun 2, 2021
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -52,4 +52,4 @@ jobs:
         run: python setup.py install
 
       - name: Test with pytest
-        run: pytest
+        run: python -m pytest
2 changes: 2 additions & 0 deletions dist_ir/backend/__init__.py
@@ -0,0 +1,2 @@
from . import onnx
from . import torch
78 changes: 78 additions & 0 deletions dist_ir/backend/onnx.py
@@ -0,0 +1,78 @@
import onnx
from onnx import helper
from onnx import TensorProto

from ..executor.rank_projector import project
from ..ir.type import Bool, Int32, Int64, Float


def _get_onnx_dtype(dtype):
    # Map DistIR scalar dtypes to ONNX tensor element types
    # (make_tensor_value_info expects TensorProto values).
    if isinstance(dtype, type(Int32())):
        return TensorProto.INT32
    elif isinstance(dtype, type(Int64())):
        return TensorProto.INT64
    elif isinstance(dtype, type(Float())):
        # TODO: Split DistIR Float into Float16 and Float32
        return TensorProto.FLOAT
    elif isinstance(dtype, type(Bool())):
        return TensorProto.BOOL
    else:
        return TensorProto.UNDEFINED


def _export_onnx_helper(fn):
    value_map = {}
    nodes = []
    # Register a typed ValueInfoProto for every function input.
    for inp in fn.inputs:
        assert inp.name is not None and inp.name not in value_map
        assert inp.type is not None
        assert inp.type.dtype is not None
        assert inp.type.shape is not None
        value_map[inp.name] = helper.make_tensor_value_info(
            name=inp.name,
            elem_type=_get_onnx_dtype(inp.type.dtype),
            shape=inp.type.shape,
        )
    # Convert each DistIR op to an ONNX node, registering its outputs as we go.
    for op in fn.ops:
        inputs = [value_map[inp.name].name for inp in op.inputs]
        for output in op.outputs:
            assert output.name is not None and output.name not in value_map
            assert output.type is not None
            assert output.type.dtype is not None
            assert output.type.shape is not None
            value_map[output.name] = helper.make_tensor_value_info(
                name=output.name,
                elem_type=_get_onnx_dtype(output.type.dtype),
                shape=output.type.shape,
            )
        outputs = [value_map[output.name].name for output in op.outputs]
        node = helper.make_node(
            op_type=op.op_type,
            inputs=inputs,
            outputs=outputs,
            name=op.name,
            **op.attributes,
        )
        nodes.append(node)
    graph_def = helper.make_graph(
        nodes=nodes,
        name=fn.name,
        inputs=[value_map[inp.name] for inp in fn.inputs],
        outputs=[value_map[output.name] for output in fn.outputs],
    )
    model_def = helper.make_model(graph_def)
    return model_def


def export_onnx(fn):
    """Project fn onto its devices and export one ONNX model per device."""
    device_to_fns, groups = project(fn, tuple(v.type for v in fn.inputs))
    devices = sorted(device_to_fns.keys())
    device_to_onnx_fns = {}
    for device in devices:
        projected_fn = device_to_fns[device]
        onnx_fn = _export_onnx_helper(projected_fn)
        device_to_onnx_fns[device] = onnx_fn
    return device_to_onnx_fns
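
A minimal usage sketch of the exporter (not part of the diff above): it assumes a typed DistIR Function `fn`; the `save_per_rank_models` helper, the filename scheme, and the `onnx.checker` call are illustrative, not part of this PR.

import onnx

from dist_ir.backend.onnx import export_onnx


def save_per_rank_models(fn, prefix="model"):
    # `fn` is assumed to be a typed DistIR Function (e.g., the output of the
    # DHP transform plus type inference); export_onnx projects it per device.
    device_to_models = export_onnx(fn)
    for device, model in device_to_models.items():
        onnx.checker.check_model(model)  # basic structural validation
        onnx.save(model, f"{prefix}_{device}.onnx")  # illustrative filenames
    return device_to_models

Each per-device result is a standalone ONNX ModelProto, so a worker can load and run only its own rank's graph; per the "out-of-opset ops throw error" test added in this PR, exporting a projected function whose ops fall outside the ONNX opset is expected to fail.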