Skip to content

Commit

Permalink
mxr to onnx (#3682)
Browse files Browse the repository at this point in the history
Changes to migraphx_py to expose information about instruction_ref to engineer ONNX file.
  • Loading branch information
richagadgil authored Jan 13, 2025
1 parent eb1717d commit c87d983
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 4 deletions.
63 changes: 59 additions & 4 deletions src/py/migraphx_py.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -44,6 +44,7 @@
#include <migraphx/float8.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/version.h>
#include <migraphx/iterator_for.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/hip.hpp>
#endif
Expand Down Expand Up @@ -259,6 +260,41 @@ py::buffer_info to_buffer_info(T& x)
return b;
}

py::object to_py_object(const migraphx::value& val)
{
py::object result;

val.visit_value([&](const auto& x) {
if constexpr(std::is_same<std::decay_t<decltype(x)>, std::vector<migraphx::value>>{})
{
if(val.is_object())
{
py::dict py_dict;
for(const auto& item : x)
{
py_dict[py::str(item.get_key())] = to_py_object(item.without_key());
}
result = py_dict;
}
else
{
py::list py_list;
for(const auto& item : x)
{
py_list.append(to_py_object(item));
}
result = py_list;
}
}
else
{
result = py::cast(x);
}
});

return result;
}

migraphx::shape to_shape(const py::buffer_info& info)
{
migraphx::shape::type_t t;
Expand Down Expand Up @@ -380,7 +416,16 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)

py::class_<migraphx::instruction_ref>(m, "instruction_ref")
.def("shape", [](migraphx::instruction_ref i) { return i->get_shape(); })
.def("op", [](migraphx::instruction_ref i) { return i->get_operator(); });
.def("op", [](migraphx::instruction_ref i) { return i->get_operator(); })
.def("inputs", [](migraphx::instruction_ref i) { return i->inputs(); })
.def("name", [](migraphx::instruction_ref i) { return i->name(); })
.def("__hash__",
[](const migraphx::instruction_ref& i) {
return std::hash<migraphx::instruction_ref>()(i);
})
.def("__eq__", [](const migraphx::instruction_ref& i, const migraphx::instruction_ref& j) {
return std::equal_to<migraphx::instruction_ref>()(i, j);
});

py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module")
.def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
Expand Down Expand Up @@ -422,7 +467,14 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
return mm.add_return(args);
},
py::arg("args"))
.def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); });
.def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); })
.def(
"__iter__",
[](const migraphx::module& mm) {
auto r = migraphx::iterator_for(mm);
return py::make_iterator(r.begin(), r.end());
},
py::keep_alive<0, 1>());

py::class_<migraphx::program>(m, "program")
.def(py::init([]() { return migraphx::program(); }))
Expand Down Expand Up @@ -502,7 +554,10 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
}
return migraphx::make_op(name, v);
}))
.def("name", &migraphx::operation::name);
.def("name", &migraphx::operation::name)
.def("values", [](const migraphx::operation& operation) -> py::object {
return to_py_object(operation.to_value());
});

py::enum_<migraphx::op::pooling_mode>(op, "pooling_mode")
.value("average", migraphx::op::pooling_mode::average)
Expand Down
173 changes: 173 additions & 0 deletions tools/converters/mxr_to_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import migraphx
import onnx
from onnx import helper, TensorProto, checker
import numpy as np
import os
import argparse


# Utility function to map MIGraphX types to ONNX data types
def get_dtype(instruction):
type_mapping = {
'float_type': TensorProto.FLOAT,
'bf16_type': TensorProto.BFLOAT16
}
return type_mapping[instruction.shape().type_string()]


# Utility function to get the shape of an instruction
def get_shape(instruction):
if isinstance(instruction, list):
raise ValueError("Expected instruction, got a list.")
return instruction.shape().lens()


# Utility function to map MIGraphX operations to ONNX operations
def map_operation(operation):
mxr_to_onnx_op = {
"dot": "MatMul",
"mul": "MatMul",
"add": "Add",
"multibroadcast": "Expand",
"erf": "Erf",
"tanh": "Tanh",
"exp": "Exp",
"div": "Div",
"relu": "Relu"
}

if operation not in mxr_to_onnx_op:
raise NotImplementedError(f"Operation '{operation}' is not supported.")
return mxr_to_onnx_op[operation]


# Helper function to create ONNX nodes for specific operations
def create_node(instruction, parameters, node_name, n, initializers):
if node_name == "multibroadcast" or node_name == "reshape":
shape_key = "out_lens" if node_name == "multibroadcast" else "dims"
shape_array = np.array(parameters[shape_key], dtype=np.int64)
initializer_name = f"{node_name}_shape_{n}"

initializers.append(
helper.make_tensor(name=initializer_name,
data_type=TensorProto.INT64,
dims=shape_array.shape,
vals=shape_array.flatten().tolist()))
return helper.make_node(
map_operation(node_name),
inputs=[str(hash(i))
for i in instruction.inputs()] + [initializer_name],
outputs=[str(hash(instruction))])

elif node_name == "transpose":
return helper.make_node(
"Transpose",
inputs=[str(hash(i)) for i in instruction.inputs()],
outputs=[str(hash(instruction))],
perm=parameters["permutation"])

elif node_name == "convolution":
return helper.make_node(
"Conv",
inputs=[str(hash(i)) for i in instruction.inputs()],
outputs=[str(hash(instruction))
], #[str(hash(i)) for i in instruction.outputs()],
dilations=parameters["dilation"],
group=parameters["group"],
pads=parameters["padding"],
strides=parameters["stride"])

return helper.make_node(
map_operation(node_name),
inputs=[str(hash(i)) for i in instruction.inputs()],
outputs=[str(hash(instruction))])


# Main function to convert MIGraphX module to ONNX model
def generate_onnx(module):
inputs = {}
operations = []
initializers = []
n = 0 # Node counter
output = None

for instruction in module:
op_name = instruction.op().name()

# Handle input nodes
if op_name in ["@literal", "@param"]:

inputs[str(hash(instruction))] = helper.make_tensor_value_info(
str(hash(instruction)), get_dtype(instruction),
get_shape(instruction))

# Handle computational nodes
elif "@" not in op_name:
n += 1
parameters = instruction.op().values()

operations.append(
create_node(instruction, parameters, op_name, n, initializers))

# Handle return node
elif op_name == "@return":

output = [
helper.make_tensor_value_info(str(hash(i)), get_dtype(i),
get_shape(i))
for i in instruction.inputs()
]

# Create the ONNX graph
graph = helper.make_graph(nodes=operations,
name="Graph",
inputs=list(inputs.values()),
initializer=initializers,
outputs=output if output else [])

return helper.make_model(graph, producer_name="onnx-dot-add-example")


# Main function to process MIGraphX files and generate ONNX models
def main(mxr_directory_path, onnx_directory_path):
for file_name in os.listdir(mxr_directory_path):
file_path = os.path.join(mxr_directory_path, file_name)
if ".mxr" in file_path:
try:
program = migraphx.load(file_path)
module = program.get_main_module()
model = generate_onnx(module)

# Validate the generated ONNX model
try:
checker.check_model(model)
print(f"ONNX model for {file_path} is valid.")
except onnx.checker.ValidationError as e:
print(f"Validation failed for {file_path}: {e}")
except Exception as e:
print(
f"Unexpected error during validation for {file_path}: {e}"
)

os.makedirs(onnx_directory_path, exist_ok=True)
onnx_file_path = os.path.join(onnx_directory_path,
file_name.replace("mxr", "onnx"))
onnx.save(model, onnx_file_path)

except Exception as e:
print(f"Error processing {file_path}: {e}")


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Process MXR files and generate ONNX models.")
parser.add_argument("mxr_directory_path",
type=str,
help="Path to the directory containing MXR files.")
parser.add_argument(
"onnx_directory_path",
type=str,
help="Path to the directory where ONNX models will be saved.")

args = parser.parse_args()
main(args.mxr_directory_path, args.onnx_directory_path)

0 comments on commit c87d983

Please sign in to comment.