Skip to content

Commit

Permalink
[Main Commit] Label Tracking API made public
Browse files Browse the repository at this point in the history
  • Loading branch information
jane-intel committed Feb 28, 2024
1 parent 30cc6bb commit 023b149
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 166 deletions.
64 changes: 0 additions & 64 deletions src/core/dev_api/openvino/core/dimension_tracker.hpp

This file was deleted.

25 changes: 19 additions & 6 deletions src/core/include/openvino/core/dimension.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
#include "openvino/core/interval.hpp"

namespace ov {
class TableOfEquivalence;
class LabelTable;
/// \brief Alias for dimension label type.
using label_t = uint32_t;
/// \brief Special label value indicate no label set.
constexpr label_t no_label = 0;

/// \brief Class representing a dimension, which may be dynamic (undetermined until runtime),
/// in a shape or shape-like object.
Expand Down Expand Up @@ -180,22 +182,33 @@ class OPENVINO_API Dimension {
using std::swap;
swap(a.m_dimension, b.m_dimension);
swap(a.m_label, b.m_label);
swap(a.m_table_of_equivalence, b.m_table_of_equivalence);
swap(a.m_label_table, b.m_label_table);
}

/// \brief String representation of Dimension
std::string to_string() const;

/// Label-related methods of ov::Dimension class

/// \brief Indicates if meaningful label was set to the Dimension
bool has_label() const;
/// \brief Returns label of the Dimension
ov::label_t get_label() const;
/// \brief Sets label value to the Dimension
void set_label(const ov::label_t& label);
/// \brief Sets Label Table to the Dimension
void set_label_table(const std::shared_ptr<LabelTable>& table);
/// \brief Returns Label Table
std::shared_ptr<LabelTable> get_label_table() const;

private:
Dimension(const Interval& interval) : m_dimension(interval) {}

// The actual numerical value of the dimension.
Interval m_dimension{};

// private fields for dimension tracking
friend class DimensionTracker;
label_t m_label{0};
std::shared_ptr<TableOfEquivalence> m_table_of_equivalence = nullptr;
label_t m_label{ov::no_label};
std::shared_ptr<LabelTable> m_label_table = nullptr;
};

/// \brief Insert a human-readable representation of a dimension into an output stream.
Expand Down
41 changes: 41 additions & 0 deletions src/core/include/openvino/core/label_table.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <unordered_map>
#include <unordered_set>

#include "openvino/core/dimension.hpp"
#include "openvino/core/partial_shape.hpp"
#include "openvino/core/type/element_type.hpp"

namespace ov {
using EqualitySoup = std::shared_ptr<std::set<label_t>>;
using EqTable = std::unordered_map<label_t, EqualitySoup>;
using ValTable = std::unordered_map<label_t, ov::Dimension>;

class OPENVINO_API LabelTable : public std::enable_shared_from_this<LabelTable> {
public:
explicit LabelTable(label_t label = 1) : current_label(label){};
void set_as_equal(const ov::Dimension& lhs, const ov::Dimension& rhs);
bool are_equal(const ov::Dimension& lhs, const ov::Dimension& rhs);

const EqTable& get_equivalence_table() const;
const ValTable& get_value_equivalence_table() const;
label_t get_next_label();

void set_up_for_tracking(ov::PartialShape& shape);
void set_up_for_tracking(ov::Dimension& d);
void set_up_for_tracking(ov::Dimension& d, label_t label);
static void reset_tracking_info(ov::Dimension& d);

private:
label_t current_label;
EqTable dimension_table_of_equivalence;
ValTable value_table_of_equivalence;
};

} // namespace ov
26 changes: 23 additions & 3 deletions src/core/src/dimension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include <limits>
#include <sstream>

#include "openvino/core/dimension_tracker.hpp"
#include "openvino/core/label_table.hpp"
#include "openvino/util/common_util.hpp"

using namespace ov;
Expand Down Expand Up @@ -197,9 +197,9 @@ bool Dimension::merge(Dimension& dst, const Dimension& d1, const Dimension& d2)
dst = Dimension(result_interval);
}

if (auto& t = d1.m_table_of_equivalence)
if (auto& t = d1.m_label_table)
t->set_as_equal(d1, d2);
else if (auto& t = d2.m_table_of_equivalence)
else if (auto& t = d2.m_label_table)
t->set_as_equal(d1, d2);

dst.m_label = merge_labels(d1.m_label, d2.m_label);
Expand Down Expand Up @@ -242,3 +242,23 @@ Dimension::value_type Dimension::get_max_length() const {
Dimension::value_type Dimension::get_min_length() const {
return dimension_length(m_dimension.get_min_val());
}

bool Dimension::has_label() const {
return m_label != ov::no_label;
}

ov::label_t Dimension::get_label() const {
return m_label;
}

void Dimension::set_label(const label_t& label) {
m_label = label;
}

void Dimension::set_label_table(const std::shared_ptr<LabelTable>& table) {
m_label_table = table;
}

std::shared_ptr<LabelTable> Dimension::get_label_table() const {
return m_label_table;
}
93 changes: 0 additions & 93 deletions src/core/src/dimension_tracker.cpp

This file was deleted.

70 changes: 70 additions & 0 deletions src/core/src/label_table.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/core/label_table.hpp"

using namespace ov;

void LabelTable::set_as_equal(const Dimension& lhs, const Dimension& rhs) {
const auto &l_label = lhs.get_label(), r_label = rhs.get_label();
if (l_label == ov::no_label || r_label == ov::no_label)
// TODO after value restriction enabling: non labeled dim propagates restriction (if any) to labeled dim
return;

auto get_soup = [](const label_t& label, EqTable& table) -> EqualitySoup {
if (!table.count(label) || !table.at(label))
table[label] = std::make_shared<std::set<label_t>>(std::set<label_t>{label});
return table.at(label);
};

auto l_soup = get_soup(l_label, dimension_table_of_equivalence);
auto r_soup = get_soup(r_label, dimension_table_of_equivalence);
if (r_soup->size() > l_soup->size()) // we would like to minimize number of iterations in the following for-loop
std::swap(l_soup, r_soup);
l_soup->insert(r_soup->begin(), r_soup->end());
for (const auto& label : *r_soup)
dimension_table_of_equivalence[label] = l_soup;
}

const ValTable& LabelTable::get_value_equivalence_table() const {
return value_table_of_equivalence;
}

const EqTable& LabelTable::get_equivalence_table() const {
return dimension_table_of_equivalence;
}

label_t LabelTable::get_next_label() {
return current_label++;
}

bool LabelTable::are_equal(const Dimension& lhs, const Dimension& rhs) {
if (!lhs.has_label() || !rhs.has_label())
return false;
const auto &l_label = lhs.get_label(), &r_label = rhs.get_label();
if (l_label == r_label)
return true;
if (dimension_table_of_equivalence.count(l_label) && dimension_table_of_equivalence[l_label])
return dimension_table_of_equivalence[l_label]->count(r_label);
return false;
}

void LabelTable::reset_tracking_info(Dimension& d) {
d.set_label(no_label);
d.set_label_table(nullptr);
}

void LabelTable::set_up_for_tracking(Dimension& d) {
set_up_for_tracking(d, get_next_label());
}

void LabelTable::set_up_for_tracking(Dimension& d, label_t label) {
d.set_label(label); // TODO: should we update current label if user uses larger label?
d.set_label_table(this->shared_from_this());
}

void LabelTable::set_up_for_tracking(ov::PartialShape& shape) {
for (auto& d : shape)
set_up_for_tracking(d);
}

0 comments on commit 023b149

Please sign in to comment.