forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Main Commit] Label Tracking API made public
- Loading branch information
1 parent
30cc6bb
commit 023b149
Showing
6 changed files
with
153 additions
and
166 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// Copyright (C) 2018-2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include <memory> | ||
#include <unordered_map> | ||
#include <unordered_set> | ||
|
||
#include "openvino/core/dimension.hpp" | ||
#include "openvino/core/partial_shape.hpp" | ||
#include "openvino/core/type/element_type.hpp" | ||
|
||
namespace ov { | ||
using EqualitySoup = std::shared_ptr<std::set<label_t>>; | ||
using EqTable = std::unordered_map<label_t, EqualitySoup>; | ||
using ValTable = std::unordered_map<label_t, ov::Dimension>; | ||
|
||
class OPENVINO_API LabelTable : public std::enable_shared_from_this<LabelTable> { | ||
public: | ||
explicit LabelTable(label_t label = 1) : current_label(label){}; | ||
void set_as_equal(const ov::Dimension& lhs, const ov::Dimension& rhs); | ||
bool are_equal(const ov::Dimension& lhs, const ov::Dimension& rhs); | ||
|
||
const EqTable& get_equivalence_table() const; | ||
const ValTable& get_value_equivalence_table() const; | ||
label_t get_next_label(); | ||
|
||
void set_up_for_tracking(ov::PartialShape& shape); | ||
void set_up_for_tracking(ov::Dimension& d); | ||
void set_up_for_tracking(ov::Dimension& d, label_t label); | ||
static void reset_tracking_info(ov::Dimension& d); | ||
|
||
private: | ||
label_t current_label; | ||
EqTable dimension_table_of_equivalence; | ||
ValTable value_table_of_equivalence; | ||
}; | ||
|
||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
// Copyright (C) 2018-2021 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/core/label_table.hpp" | ||
|
||
using namespace ov; | ||
|
||
void LabelTable::set_as_equal(const Dimension& lhs, const Dimension& rhs) { | ||
const auto &l_label = lhs.get_label(), r_label = rhs.get_label(); | ||
if (l_label == ov::no_label || r_label == ov::no_label) | ||
// TODO after value restriction enabling: non labeled dim propagates restriction (if any) to labeled dim | ||
return; | ||
|
||
auto get_soup = [](const label_t& label, EqTable& table) -> EqualitySoup { | ||
if (!table.count(label) || !table.at(label)) | ||
table[label] = std::make_shared<std::set<label_t>>(std::set<label_t>{label}); | ||
return table.at(label); | ||
}; | ||
|
||
auto l_soup = get_soup(l_label, dimension_table_of_equivalence); | ||
auto r_soup = get_soup(r_label, dimension_table_of_equivalence); | ||
if (r_soup->size() > l_soup->size()) // we would like to minimize number of iterations in the following for-loop | ||
std::swap(l_soup, r_soup); | ||
l_soup->insert(r_soup->begin(), r_soup->end()); | ||
for (const auto& label : *r_soup) | ||
dimension_table_of_equivalence[label] = l_soup; | ||
} | ||
|
||
const ValTable& LabelTable::get_value_equivalence_table() const { | ||
return value_table_of_equivalence; | ||
} | ||
|
||
const EqTable& LabelTable::get_equivalence_table() const { | ||
return dimension_table_of_equivalence; | ||
} | ||
|
||
label_t LabelTable::get_next_label() { | ||
return current_label++; | ||
} | ||
|
||
bool LabelTable::are_equal(const Dimension& lhs, const Dimension& rhs) { | ||
if (!lhs.has_label() || !rhs.has_label()) | ||
return false; | ||
const auto &l_label = lhs.get_label(), &r_label = rhs.get_label(); | ||
if (l_label == r_label) | ||
return true; | ||
if (dimension_table_of_equivalence.count(l_label) && dimension_table_of_equivalence[l_label]) | ||
return dimension_table_of_equivalence[l_label]->count(r_label); | ||
return false; | ||
} | ||
|
||
void LabelTable::reset_tracking_info(Dimension& d) { | ||
d.set_label(no_label); | ||
d.set_label_table(nullptr); | ||
} | ||
|
||
void LabelTable::set_up_for_tracking(Dimension& d) { | ||
set_up_for_tracking(d, get_next_label()); | ||
} | ||
|
||
void LabelTable::set_up_for_tracking(Dimension& d, label_t label) { | ||
d.set_label(label); // TODO: should we update current label if user uses larger label? | ||
d.set_label_table(this->shared_from_this()); | ||
} | ||
|
||
void LabelTable::set_up_for_tracking(ov::PartialShape& shape) { | ||
for (auto& d : shape) | ||
set_up_for_tracking(d); | ||
} |