Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a LinkNavigator utility #646

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions include/podio/LinkNavigator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#ifndef PODIO_LINKNAVIGATOR_H
#define PODIO_LINKNAVIGATOR_H

#include "podio/detail/LinkFwd.h"

#include <map>
#include <tuple>
#include <utility>
#include <vector>

namespace podio {

namespace detail::associations {
/// A small struct that simply bundles an object and its weight for a more
/// convenient return value for the LinkNavigator
///
/// @note In most uses the names of the members should not really matter as it
/// is possible to us this via structured bindings
template <typename T>
struct WeightedObject {
WeightedObject(T obj, float w) : o(obj), weight(w) {
}
T o; ///< The object
float weight; ///< The weight in the association
};
} // namespace detail::associations

/// A helper class to more easily handle one-to-many associations.
///
/// Internally simply populates two maps in its constructor and then queries
/// them to retrieve objects that are associated with another.
///
/// @note There are no guarantees on the order of the objects in these maps.
/// Hence, there are also no guarantees on the order of the returned objects,
/// even if there inherintly is an order to them in the underlying associations
/// collection.
template <typename LinkCollT>
class LinkNavigator {
using FromT = LinkCollT::from_type;
using ToT = LinkCollT::to_type;

template <typename T>
using WeightedObject = detail::associations::WeightedObject<T>;

public:
/// Construct a navigator from an association collection
LinkNavigator(const LinkCollT& associations);

/// We do only construct from a collection
LinkNavigator() = delete;
LinkNavigator(const LinkNavigator&) = default;
LinkNavigator& operator=(const LinkNavigator&) = default;
LinkNavigator(LinkNavigator&&) = default;
LinkNavigator& operator=(LinkNavigator&&) = default;
~LinkNavigator() = default;

/// Get all the objects and weights that are associated to the passed object
///
/// @param object The object that is labeled *to* in the association
///
/// @returns A vector of all objects and their weights that are associated to
/// the passed object
std::vector<WeightedObject<FromT>> getAssociated(const ToT& object) const {
const auto& [begin, end] = m_to2from.equal_range(object);
std::vector<WeightedObject<FromT>> result;
result.reserve(std::distance(begin, end));

for (auto it = begin; it != end; ++it) {
result.emplace_back(it->second);
}
return result;
}

/// Get all the objects and weights that are associated to the passed object
///
/// @param object The object that is labeled *from* in the association
///
/// @returns A vector of all objects and their weights that are associated to
/// the passed object
std::vector<WeightedObject<ToT>> getAssociated(const FromT& object) const {
const auto& [begin, end] = m_from2to.equal_range(object);
std::vector<WeightedObject<ToT>> result;
result.reserve(std::distance(begin, end));

for (auto it = begin; it != end; ++it) {
result.emplace_back(it->second);
}
return result;
}

private:
std::multimap<FromT, WeightedObject<ToT>> m_from2to{}; ///< Map the from to the to objects
std::multimap<ToT, WeightedObject<FromT>> m_to2from{}; ///< Map the to to the from objects
};

template <typename LinkCollT>
LinkNavigator<LinkCollT>::LinkNavigator(const LinkCollT& associations) {
for (const auto& [from, to, weight] : associations) {
m_from2to.emplace(std::piecewise_construct, std::forward_as_tuple(from), std::forward_as_tuple(to, weight));
m_to2from.emplace(std::piecewise_construct, std::forward_as_tuple(to), std::forward_as_tuple(from, weight));
}
}

} // namespace podio

#endif // PODIO_LINKNAVIGATOR_H
2 changes: 2 additions & 0 deletions include/podio/detail/LinkCollectionImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class LinkCollection : public podio::CollectionBase {
using CollectionDataT = podio::LinkCollectionData<FromT, ToT>;

public:
using from_type = FromT;
using to_type = ToT;
using value_type = Link<FromT, ToT>;
using mutable_type = MutableLink<FromT, ToT>;
using const_iterator = LinkCollectionIterator<FromT, ToT>;
Expand Down
49 changes: 49 additions & 0 deletions tests/unittests/links.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "catch2/catch_test_macros.hpp"

#include "podio/LinkCollection.h"
#include "podio/LinkNavigator.h"

#include "datamodel/ExampleClusterCollection.h"
#include "datamodel/ExampleHitCollection.h"
Expand Down Expand Up @@ -473,3 +474,51 @@ TEST_CASE("Link JSON conversion", "[links][json]") {
}

#endif

TEST_CASE("LinkNavigator basics", "[links]") {
TestLColl coll{};
std::vector<ExampleHit> hits(11);
std::vector<ExampleCluster> clusters(3);

for (size_t i = 0; i < 10; ++i) {
auto a = coll.create();
a.set(hits[i]);
a.set(clusters[i % 3]);
a.setWeight(i * 0.1f);
}

auto a = coll.create();
a.set(hits[10]);

podio::LinkNavigator nav{coll};

for (size_t i = 0; i < 10; ++i) {
const auto& hit = hits[i];
const auto assocClusters = nav.getAssociated(hit);
REQUIRE(assocClusters.size() == 1);
const auto& [cluster, weight] = assocClusters[0];
REQUIRE(cluster == clusters[i % 3]);
REQUIRE(weight == i * 0.1f);
}

const auto& cluster1 = clusters[0];
auto assocHits = nav.getAssociated(cluster1);
REQUIRE(assocHits.size() == 4);
for (size_t i = 0; i < 4; ++i) {
const auto& [hit, weight] = assocHits[i];
REQUIRE(hit == hits[i * 3]);
REQUIRE(weight == i * 3 * 0.1f);
}

const auto& cluster2 = clusters[1];
assocHits = nav.getAssociated(cluster2);
REQUIRE(assocHits.size() == 3);
for (size_t i = 0; i < 3; ++i) {
const auto& [hit, weight] = assocHits[i];
REQUIRE(hit == hits[i * 3 + 1]);
REQUIRE(weight == (i * 3 + 1) * 0.1f);
}

const auto [noCluster, noWeight] = nav.getAssociated(hits[10])[0];
REQUIRE_FALSE(noCluster.isAvailable());
}
Loading