Skip to content

Commit

Permalink
Adding additional check, if the graph is empty (#2630)
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarusic-amd authored Jan 18, 2024
1 parent 3fd8abf commit e474a04
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 24 deletions.
41 changes: 20 additions & 21 deletions src/tf/tf_parser.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -298,28 +298,27 @@ void tf_parser::parse_graph(const tensorflow::GraphDef& graph)
{
this->parse_node(p.first);
}
auto last_ins = std::prev(mm->end());
if(last_ins != mm->end())
{
// Needs to add a ret instruction at the end of
// the program
if(output_node_names.empty())
{
output_node_names = find_outputs();
}
if(mm->size() == 0)
return;

std::vector<instruction_ref> output_ins;
std::transform(output_node_names.begin(),
output_node_names.end(),
std::back_inserter(output_ins),
[&](auto output_name) {
if(not contains(instructions, output_name))
MIGRAPHX_THROW("PARSE_TF: output name " + output_name +
" not found in graph!");
return this->to_nchw(instructions[output_name]);
});
mm->add_return(output_ins);
// Needs to add a ret instruction at the end of
// the program
if(output_node_names.empty())
{
output_node_names = find_outputs();
}

std::vector<instruction_ref> output_ins;
std::transform(output_node_names.begin(),
output_node_names.end(),
std::back_inserter(output_ins),
[&](auto output_name) {
if(not contains(instructions, output_name))
MIGRAPHX_THROW("PARSE_TF: output name " + output_name +
" not found in graph!");
return this->to_nchw(instructions[output_name]);
});
mm->add_return(output_ins);
}

void tf_parser::parse_node(const std::string& name)
Expand Down
7 changes: 4 additions & 3 deletions test/tf/tf_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -63,9 +63,10 @@ migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
migraphx::eliminate_identity{}});

// remove the last return instruction
auto last_ins = std::prev(mm->end());
if(last_ins != mm->end())

if(mm->size() > 0)
{
auto last_ins = std::prev(mm->end());
if(last_ins->name() == "@return")
{
mm->remove_instruction(last_ins);
Expand Down

0 comments on commit e474a04

Please sign in to comment.