Skip to content

Commit

Permalink
add lower pass test to compare graph and loop
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Jun 5, 2024
1 parent 00222ac commit 6edf613
Show file tree
Hide file tree
Showing 6 changed files with 385 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/common/snippets/src/lowered/loop_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ void order(const std::vector<size_t>& new_order, std::vector<T>& values) {
"Failed to sort values: `new_order` must contain new indexes for ALL values");
std::vector<T> ordered_values(values.size());
for (size_t i = 0; i < values.size(); ++i) {
ordered_values[new_order[i]] = values[i];
ordered_values[i] = values[new_order[i]];
}
values = std::move(ordered_values);
}
Expand Down
3 changes: 3 additions & 0 deletions src/common/snippets/tests/include/lir_comparator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ class LIRComparator {
static Result compare_loop_managers(const ov::snippets::lowered::LoopManagerPtr& loop_manager,
const ov::snippets::lowered::LoopManagerPtr& loop_manager_ref);

static Result compare_loop_managers_without_index(const ov::snippets::lowered::LoopManagerPtr& loop_manager,
const ov::snippets::lowered::LoopManagerPtr& loop_manager_ref);

static Result compare_loop_info(const ov::snippets::lowered::LoopInfoPtr& loop_info,
const ov::snippets::lowered::LoopInfoPtr& loop_info_ref);

Expand Down
18 changes: 18 additions & 0 deletions src/common/snippets/tests/include/lir_test_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,24 @@ void create_and_add_unified_loop_info(const std::shared_ptr<ov::snippets::lowere
const std::vector<ov::snippets::lowered::LoopPort>& entries,
const std::vector<ov::snippets::lowered::LoopPort>& exits,
bool add_default_handlers = true);
/**
* @brief Creates unified loop info based on provided entry and exit points, and adds it to the linear_ir's loops map.
* Meanwhile set loop id to expr range [loop_begin_pos, loop_end_pos).
* @attention This helper wraps LoopManager::mark_loop method, which also marks expressions with the corresponding loop info
* @param linear_ir linear_ir in which loop info should be added
* @param loop_begin_pos begin expr postion in this loop
* @param loop_end_pos end expr postion in this loop
* @param entries entry points of loop
* @param exits exit points of loop
*/
void create_and_add_unified_loop_info(const std::shared_ptr<ov::snippets::lowered::LinearIR>& linear_ir,
ov::snippets::lowered::LinearIR::constExprIt loop_begin_pos,
ov::snippets::lowered::LinearIR::constExprIt loop_end_pos,
size_t work_amount,
size_t increment,
const std::vector<ov::snippets::lowered::LoopPort>& entries,
const std::vector<ov::snippets::lowered::LoopPort>& exits,
bool add_default_handlers = true);
} // namespace snippets
} // namespace test
} // namespace ov
24 changes: 22 additions & 2 deletions src/common/snippets/tests/src/lir_comparator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,13 @@ LIRComparator::Result LIRComparator::compare(const LinearIRPtr& linear_ir,
for (auto result_it = results.begin(), result_it_ref = results_ref.begin(); result_it != results.end(); ++result_it, ++result_it_ref)
PROPAGATE_ERROR("", run_comparison(result_it, result_it_ref));

if (should_compare(LIRCmpValues::LOOP_MANAGER))
PROPAGATE_ERROR("Loop managers", compare_loop_managers(linear_ir->get_loop_manager(), linear_ir_ref->get_loop_manager()));
if (should_compare(LIRCmpValues::LOOP_MANAGER)) {
if (should_compare(LIRCmpValues::LOOP_INDICES)) {
PROPAGATE_ERROR("Loop managers", compare_loop_managers(linear_ir->get_loop_manager(), linear_ir_ref->get_loop_manager()));
} else {
PROPAGATE_ERROR("Loop managers", compare_loop_managers_without_index(linear_ir->get_loop_manager(), linear_ir_ref->get_loop_manager()));
}
}
return Result::ok();
}

Expand Down Expand Up @@ -127,6 +132,21 @@ LIRComparator::Result LIRComparator::compare_loop_managers(const LoopManagerPtr&
return Result::ok();
}

LIRComparator::Result LIRComparator::compare_loop_managers_without_index(const LoopManagerPtr& loop_manager,
const LoopManagerPtr& loop_manager_ref) {
const auto& map = loop_manager->get_map();
const auto& map_ref = loop_manager_ref->get_map();
COMPARE("Loops map size", map.size(), map_ref.size());

for (auto map_iter = map.cbegin(), map_ref_iter = map_ref.cbegin(); map_iter != map.cend(); map_iter++, map_ref_iter++) {
const auto& loop_info = map_iter->second;
const auto& loop_info_ref = map_ref_iter->second;
PROPAGATE_ERROR("Loop with id " + std::to_string(map_iter->first) + "and Ref Loop with id" + std::to_string(map_ref_iter->first),
compare_loop_info(loop_info, loop_info_ref));
}
return Result::ok();
}

LIRComparator::Result LIRComparator::compare_loop_info(const LoopInfoPtr& loop_info, const LoopInfoPtr& loop_info_ref) {
COMPARE("Loop info type", loop_info->get_type_info(), loop_info_ref->get_type_info());
COMPARE("Work amounts", loop_info->get_work_amount(), loop_info_ref->get_work_amount());
Expand Down
12 changes: 12 additions & 0 deletions src/common/snippets/tests/src/lir_test_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,18 @@ void create_and_add_unified_loop_info(const LinearIRPtr& linear_ir,
loop_manager->mark_loop(linear_ir->begin(), linear_ir->begin(), work_amount, increment, entries, exits, set_default_handlers);
}

void create_and_add_unified_loop_info(const LinearIRPtr& linear_ir,
ov::snippets::lowered::LinearIR::constExprIt loop_begin_pos,
ov::snippets::lowered::LinearIR::constExprIt loop_end_pos,
size_t work_amount,
size_t increment,
const std::vector<LoopPort>& entries,
const std::vector<LoopPort>& exits,
bool set_default_handlers) {
const auto& loop_manager = linear_ir->get_loop_manager();
loop_manager->mark_loop(loop_begin_pos, loop_end_pos, work_amount, increment, entries, exits, set_default_handlers);
}

} // namespace snippets
} // namespace test
} // namespace ov
Loading

0 comments on commit 6edf613

Please sign in to comment.