diff --git a/projects/onnx_c_importer/import-onnx-main.cpp b/projects/onnx_c_importer/import-onnx-main.cpp index 58ebd98b6a70..4660cad286ac 100644 --- a/projects/onnx_c_importer/import-onnx-main.cpp +++ b/projects/onnx_c_importer/import-onnx-main.cpp @@ -17,15 +17,16 @@ #include "llvm/Support/raw_ostream.h" #include "OnnxImporter.h" - #include "onnx/onnx_pb.h" #include #include +#include using namespace llvm; using namespace torch_mlir_onnx; +// Encapsulates MLIR context and module management struct MlirState { MlirState() { context = mlirContextCreateWithThreading(false); @@ -42,62 +43,81 @@ struct MlirState { }; int main(int argc, char **argv) { + // Define command-line options static cl::opt inputFilename( cl::Positional, cl::desc(""), cl::init("-")); + static cl::opt outputFilename( + "o", cl::desc("Output filename"), cl::value_desc("filename"), cl::init("-")); - static cl::opt outputFilename("o", cl::desc("Output filename"), - cl::value_desc("filename"), - cl::init("-")); - + // Initialize LLVM and parse command-line options InitLLVM y(argc, argv); cl::ParseCommandLineOptions(argc, argv, "torch-mlir-onnx-import-c"); - // Open the input as an istream because that is what protobuf likes. - std::unique_ptr alloced_input_stream; - std::istream *input_stream = nullptr; + // Open the input file stream + std::unique_ptr allocedInputStream; + std::istream *inputStream = nullptr; if (inputFilename == "-") { - errs() << "(parsing from stdin)\n"; - input_stream = &std::cin; + errs() << "(Parsing from stdin)\n"; + inputStream = &std::cin; } else { - alloced_input_stream = std::make_unique( + allocedInputStream = std::make_unique( inputFilename, std::ios::in | std::ios::binary); - if (!*alloced_input_stream) { - errs() << "error: could not open input file " << inputFilename << "\n"; - return 1; + if (!allocedInputStream->is_open()) { + errs() << "Error: Could not open input file: " << inputFilename << "\n"; + return EXIT_FAILURE; } - input_stream = alloced_input_stream.get(); + inputStream = allocedInputStream.get(); } - // Parse the model proto. - ModelInfo model_info; - if (!model_info.model_proto().ParseFromIstream(input_stream)) { - errs() << "Failed to parse ONNX ModelProto from " << inputFilename << "\n"; - return 2; + // Parse the ONNX model proto + ModelInfo modelInfo; + if (!modelInfo.model_proto().ParseFromIstream(inputStream)) { + errs() << "Error: Failed to parse ONNX ModelProto from " << inputFilename << "\n"; + return EXIT_FAILURE; } - if (failed(model_info.Initialize())) { - errs() << "error: Import failure: " << model_info.error_message() << "\n"; - model_info.DebugDumpProto(); - return 3; + // Initialize model information + if (failed(modelInfo.Initialize())) { + errs() << "Error: Import failure: " << modelInfo.error_message() << "\n"; + modelInfo.DebugDumpProto(); + return EXIT_FAILURE; } - model_info.DebugDumpProto(); + modelInfo.DebugDumpProto(); + + // Create MLIR state and context cache + MlirState ownedState; + ContextCache contextCache(modelInfo, ownedState.context); - // Import. - MlirState owned_state; - ContextCache cc(model_info, owned_state.context); - NodeImporter importer(model_info.main_graph(), cc, - mlirModuleGetOperation(owned_state.module)); + // Import the ONNX graph into MLIR + NodeImporter importer( + modelInfo.main_graph(), contextCache, mlirModuleGetOperation(ownedState.module)); if (failed(importer.DefineFunction())) { - errs() << "error: Could not define MLIR function for graph: " - << model_info.error_message() << "\n"; - return 4; + errs() << "Error: Could not define MLIR function for graph: " + << modelInfo.error_message() << "\n"; + return EXIT_FAILURE; } if (failed(importer.ImportAll())) { - errs() << "error: Could not import one or more graph nodes: " - << model_info.error_message() << "\n"; - return 5; + errs() << "Error: Could not import one or more graph nodes: " + << modelInfo.error_message() << "\n"; + return EXIT_FAILURE; } + + // Dump the imported MLIR module importer.DebugDumpModule(); - return 0; + // Optional: Save the output MLIR module to a file + if (outputFilename != "-") { + std::ofstream outFile(outputFilename, std::ios::out); + if (!outFile.is_open()) { + errs() << "Error: Could not open output file: " << outputFilename << "\n"; + return EXIT_FAILURE; + } + mlirOperationPrint(mlirModuleGetOperation(ownedState.module), outFile); + outs() << "Successfully saved MLIR module to " << outputFilename << "\n"; + } else { + outs() << "MLIR module processing complete. Output not saved to a file.\n"; + } + + return EXIT_SUCCESS; } +