mxr to onnx #3682

Merged
35 commits merged on Jan 13, 2025

Commits
65a99e0
migraphx_py
richagadgil Dec 4, 2024
a0e379a
mxr to onnx
richagadgil Dec 6, 2024
f929698
add args
richagadgil Dec 6, 2024
87a72fa
foramt
richagadgil Dec 13, 2024
d3ef19d
Merge branch 'develop' into mxr_to_onnx
richagadgil Dec 13, 2024
d79e5c9
fix hash err
richagadgil Dec 13, 2024
4792688
fix eq
richagadgil Dec 13, 2024
2a49d89
fixed
richagadgil Dec 16, 2024
99714b9
format
richagadgil Dec 16, 2024
e1e1ea5
check ci
richagadgil Dec 19, 2024
acfa267
Update migraphx_py.cpp
richagadgil Dec 19, 2024
699fc80
check licnse
richagadgil Dec 19, 2024
be6756f
add os
richagadgil Dec 19, 2024
d687682
check
richagadgil Dec 19, 2024
f7cccb7
check
richagadgil Dec 19, 2024
0427003
check
richagadgil Dec 19, 2024
32b2b45
check
richagadgil Dec 19, 2024
e0c3acc
check
richagadgil Dec 19, 2024
5259a9e
tools
richagadgil Dec 19, 2024
107248f
tools
richagadgil Dec 19, 2024
b3127f9
check
richagadgil Dec 19, 2024
0593f30
fetch first
richagadgil Dec 19, 2024
8b26986
remove origin
richagadgil Dec 19, 2024
cff16f0
add origin
richagadgil Dec 19, 2024
932c1ec
mb
richagadgil Dec 19, 2024
454bfab
add origin
richagadgil Dec 20, 2024
ff0490d
change filename
richagadgil Dec 20, 2024
2eac640
last check
richagadgil Dec 20, 2024
e6e99b3
revert changes
richagadgil Dec 24, 2024
56feecb
restore
richagadgil Dec 24, 2024
3e55e2a
fix is_same
richagadgil Jan 6, 2025
adb06b5
change ::value to {}
richagadgil Jan 7, 2025
258b7d7
Merge branch 'develop' into mxr_to_onnx
richagadgil Jan 9, 2025
f33b1fd
Update migraphx_py.cpp license
richagadgil Jan 9, 2025
a72564f
Merge branch 'develop' into mxr_to_onnx
kahmed10 Jan 10, 2025
63 changes: 59 additions & 4 deletions src/py/migraphx_py.cpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -44,6 +44,7 @@
#include <migraphx/float8.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/version.h>
#include <migraphx/iterator_for.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
#endif
@@ -259,6 +260,41 @@ py::buffer_info to_buffer_info(T& x)
return b;
}

py::object to_py_object(const migraphx::value& val)
{
py::object result;

val.visit_value([&](const auto& x) {
if constexpr(std::is_same<std::decay_t<decltype(x)>, std::vector<migraphx::value>>{})
{
if(val.is_object())
{
py::dict py_dict;
for(const auto& item : x)
{
py_dict[py::str(item.get_key())] = to_py_object(item.without_key());
}
result = py_dict;
}
else
{
py::list py_list;
for(const auto& item : x)
{
py_list.append(to_py_object(item));
}
result = py_list;
}
}
else
{
result = py::cast(x);
}
});

return result;
}

migraphx::shape to_shape(const py::buffer_info& info)
{
migraphx::shape::type_t t;
@@ -380,7 +416,16 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)

py::class_<migraphx::instruction_ref>(m, "instruction_ref")
.def("shape", [](migraphx::instruction_ref i) { return i->get_shape(); })
.def("op", [](migraphx::instruction_ref i) { return i->get_operator(); });
.def("op", [](migraphx::instruction_ref i) { return i->get_operator(); })
.def("inputs", [](migraphx::instruction_ref i) { return i->inputs(); })
.def("name", [](migraphx::instruction_ref i) { return i->name(); })
.def("__hash__",
[](const migraphx::instruction_ref& i) {
return std::hash<migraphx::instruction_ref>()(i);
})
.def("__eq__", [](const migraphx::instruction_ref& i, const migraphx::instruction_ref& j) {
return std::equal_to<migraphx::instruction_ref>()(i, j);
});

py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module")
.def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
@@ -422,7 +467,14 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
return mm.add_return(args);
},
py::arg("args"))
.def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); });
.def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); })
.def(
"__iter__",
[](const migraphx::module& mm) {
auto r = migraphx::iterator_for(mm);
return py::make_iterator(r.begin(), r.end());
},
py::keep_alive<0, 1>());

py::class_<migraphx::program>(m, "program")
.def(py::init([]() { return migraphx::program(); }))
@@ -502,7 +554,10 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
}
return migraphx::make_op(name, v);
}))
.def("name", &migraphx::operation::name);
.def("name", &migraphx::operation::name)
.def("values", [](const migraphx::operation& operation) -> py::object {
return to_py_object(operation.to_value());
});

py::enum_<migraphx::op::pooling_mode>(op, "pooling_mode")
.value("average", migraphx::op::pooling_mode::average)
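A minimal usage sketch (not part of this diff) of the new Python bindings added above, assuming an MXR file previously saved with MIGraphX; the file name model.mxr is illustrative:

import migraphx

program = migraphx.load("model.mxr")  # illustrative path to a saved MXR file
module = program.get_main_module()

# The module is now iterable (__iter__), yielding instruction_ref objects.
for ins in module:
    print(ins.name(),                       # instruction name, e.g. "@param" or "dot"
          ins.op().values(),                # operator attributes as plain Python dicts/lists
          [hash(i) for i in ins.inputs()])  # inputs are usable as dict keys via __hash__/__eq__

This is the pattern the converter script below relies on: iterate the module, inspect each instruction's operator and attributes, and key intermediate tensors by the instruction's hash.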
173 changes: 173 additions & 0 deletions tools/converters/mxr_to_onnx.py
@@ -0,0 +1,173 @@
import migraphx
import onnx
from onnx import helper, TensorProto, checker
import numpy as np
import os
import argparse


# Utility function to map MIGraphX types to ONNX data types
def get_dtype(instruction):
type_mapping = {
'float_type': TensorProto.FLOAT,
'bf16_type': TensorProto.BFLOAT16
}
return type_mapping[instruction.shape().type_string()]


# Utility function to get the shape of an instruction
def get_shape(instruction):
if isinstance(instruction, list):
raise ValueError("Expected instruction, got a list.")
return instruction.shape().lens()


# Utility function to map MIGraphX operations to ONNX operations
def map_operation(operation):
mxr_to_onnx_op = {
"dot": "MatMul",
"mul": "MatMul",
"add": "Add",
"multibroadcast": "Expand",
"erf": "Erf",
"tanh": "Tanh",
"exp": "Exp",
"div": "Div",
"relu": "Relu"
}

if operation not in mxr_to_onnx_op:
raise NotImplementedError(f"Operation '{operation}' is not supported.")
return mxr_to_onnx_op[operation]


# Helper function to create ONNX nodes for specific operations
def create_node(instruction, parameters, node_name, n, initializers):
if node_name == "multibroadcast" or node_name == "reshape":
shape_key = "out_lens" if node_name == "multibroadcast" else "dims"
shape_array = np.array(parameters[shape_key], dtype=np.int64)
initializer_name = f"{node_name}_shape_{n}"

initializers.append(
helper.make_tensor(name=initializer_name,
data_type=TensorProto.INT64,
dims=shape_array.shape,
vals=shape_array.flatten().tolist()))
return helper.make_node(
map_operation(node_name),
inputs=[str(hash(i))
for i in instruction.inputs()] + [initializer_name],
outputs=[str(hash(instruction))])

elif node_name == "transpose":
return helper.make_node(
"Transpose",
inputs=[str(hash(i)) for i in instruction.inputs()],
outputs=[str(hash(instruction))],
perm=parameters["permutation"])

elif node_name == "convolution":
return helper.make_node(
"Conv",
inputs=[str(hash(i)) for i in instruction.inputs()],
outputs=[str(hash(instruction))
], #[str(hash(i)) for i in instruction.outputs()],
dilations=parameters["dilation"],
group=parameters["group"],
pads=parameters["padding"],
strides=parameters["stride"])

return helper.make_node(
map_operation(node_name),
inputs=[str(hash(i)) for i in instruction.inputs()],
outputs=[str(hash(instruction))])


# Main function to convert MIGraphX module to ONNX model
def generate_onnx(module):
inputs = {}
operations = []
initializers = []
n = 0 # Node counter
output = None

for instruction in module:
op_name = instruction.op().name()

# Handle input nodes
if op_name in ["@literal", "@param"]:

inputs[str(hash(instruction))] = helper.make_tensor_value_info(
str(hash(instruction)), get_dtype(instruction),
get_shape(instruction))

# Handle computational nodes
elif "@" not in op_name:
n += 1
parameters = instruction.op().values()

operations.append(
create_node(instruction, parameters, op_name, n, initializers))

# Handle return node
elif op_name == "@return":

output = [
helper.make_tensor_value_info(str(hash(i)), get_dtype(i),
get_shape(i))
for i in instruction.inputs()
]

# Create the ONNX graph
graph = helper.make_graph(nodes=operations,
name="Graph",
inputs=list(inputs.values()),
initializer=initializers,
outputs=output if output else [])

return helper.make_model(graph, producer_name="onnx-dot-add-example")


# Main function to process MIGraphX files and generate ONNX models
def main(mxr_directory_path, onnx_directory_path):
for file_name in os.listdir(mxr_directory_path):
file_path = os.path.join(mxr_directory_path, file_name)
if ".mxr" in file_path:
try:
program = migraphx.load(file_path)
module = program.get_main_module()
model = generate_onnx(module)

# Validate the generated ONNX model
try:
checker.check_model(model)
print(f"ONNX model for {file_path} is valid.")
except onnx.checker.ValidationError as e:
print(f"Validation failed for {file_path}: {e}")
except Exception as e:
print(
f"Unexpected error during validation for {file_path}: {e}"
)

os.makedirs(onnx_directory_path, exist_ok=True)
onnx_file_path = os.path.join(onnx_directory_path,
file_name.replace("mxr", "onnx"))
onnx.save(model, onnx_file_path)

except Exception as e:
print(f"Error processing {file_path}: {e}")


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Process MXR files and generate ONNX models.")
parser.add_argument("mxr_directory_path",
type=str,
help="Path to the directory containing MXR files.")
parser.add_argument(
"onnx_directory_path",
type=str,
help="Path to the directory where ONNX models will be saved.")

args = parser.parse_args()
main(args.mxr_directory_path, args.onnx_directory_path)
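A typical invocation (directory names are illustrative): the script scans the input directory for .mxr files, converts each main module to an ONNX model, validates it with onnx.checker, and writes the result to the output directory.

python tools/converters/mxr_to_onnx.py ./mxr_models ./onnx_models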