diff --git a/include/podio/LinkNavigator.h b/include/podio/LinkNavigator.h new file mode 100644 index 000000000..d525e6067 --- /dev/null +++ b/include/podio/LinkNavigator.h @@ -0,0 +1,106 @@ +#ifndef PODIO_LINKNAVIGATOR_H +#define PODIO_LINKNAVIGATOR_H + +#include "podio/detail/LinkFwd.h" + +#include +#include +#include +#include + +namespace podio { + +namespace detail::associations { + /// A small struct that simply bundles an object and its weight for a more + /// convenient return value for the LinkNavigator + /// + /// @note In most uses the names of the members should not really matter as it + /// is possible to us this via structured bindings + template + struct WeightedObject { + WeightedObject(T obj, float w) : o(obj), weight(w) { + } + T o; ///< The object + float weight; ///< The weight in the association + }; +} // namespace detail::associations + +/// A helper class to more easily handle one-to-many associations. +/// +/// Internally simply populates two maps in its constructor and then queries +/// them to retrieve objects that are associated with another. +/// +/// @note There are no guarantees on the order of the objects in these maps. +/// Hence, there are also no guarantees on the order of the returned objects, +/// even if there inherintly is an order to them in the underlying associations +/// collection. +template +class LinkNavigator { + using FromT = LinkCollT::from_type; + using ToT = LinkCollT::to_type; + + template + using WeightedObject = detail::associations::WeightedObject; + +public: + /// Construct a navigator from an association collection + LinkNavigator(const LinkCollT& associations); + + /// We do only construct from a collection + LinkNavigator() = delete; + LinkNavigator(const LinkNavigator&) = default; + LinkNavigator& operator=(const LinkNavigator&) = default; + LinkNavigator(LinkNavigator&&) = default; + LinkNavigator& operator=(LinkNavigator&&) = default; + ~LinkNavigator() = default; + + /// Get all the objects and weights that are associated to the passed object + /// + /// @param object The object that is labeled *to* in the association + /// + /// @returns A vector of all objects and their weights that are associated to + /// the passed object + std::vector> getAssociated(const ToT& object) const { + const auto& [begin, end] = m_to2from.equal_range(object); + std::vector> result; + result.reserve(std::distance(begin, end)); + + for (auto it = begin; it != end; ++it) { + result.emplace_back(it->second); + } + return result; + } + + /// Get all the objects and weights that are associated to the passed object + /// + /// @param object The object that is labeled *from* in the association + /// + /// @returns A vector of all objects and their weights that are associated to + /// the passed object + std::vector> getAssociated(const FromT& object) const { + const auto& [begin, end] = m_from2to.equal_range(object); + std::vector> result; + result.reserve(std::distance(begin, end)); + + for (auto it = begin; it != end; ++it) { + result.emplace_back(it->second); + } + return result; + } + +private: + std::multimap> m_from2to{}; ///< Map the from to the to objects + std::multimap> m_to2from{}; ///< Map the to to the from objects +}; + +template +LinkNavigator::LinkNavigator(const LinkCollT& associations) { + for (const auto& [from, to, weight] : associations) { + m_from2to.emplace(std::piecewise_construct, std::forward_as_tuple(from), std::forward_as_tuple(to, weight)); + m_to2from.emplace(std::piecewise_construct, std::forward_as_tuple(to), std::forward_as_tuple(from, weight)); + } +} + +} // namespace podio + +#endif // PODIO_LINKNAVIGATOR_H diff --git a/include/podio/detail/LinkCollectionImpl.h b/include/podio/detail/LinkCollectionImpl.h index fe1a61ace..701214c2f 100644 --- a/include/podio/detail/LinkCollectionImpl.h +++ b/include/podio/detail/LinkCollectionImpl.h @@ -40,6 +40,8 @@ class LinkCollection : public podio::CollectionBase { using CollectionDataT = podio::LinkCollectionData; public: + using from_type = FromT; + using to_type = ToT; using value_type = Link; using mutable_type = MutableLink; using const_iterator = LinkCollectionIterator; diff --git a/tests/unittests/links.cpp b/tests/unittests/links.cpp index a296d06f6..12dbd112b 100644 --- a/tests/unittests/links.cpp +++ b/tests/unittests/links.cpp @@ -1,6 +1,7 @@ #include "catch2/catch_test_macros.hpp" #include "podio/LinkCollection.h" +#include "podio/LinkNavigator.h" #include "datamodel/ExampleClusterCollection.h" #include "datamodel/ExampleHitCollection.h" @@ -473,3 +474,51 @@ TEST_CASE("Link JSON conversion", "[links][json]") { } #endif + +TEST_CASE("LinkNavigator basics", "[links]") { + TestLColl coll{}; + std::vector hits(11); + std::vector clusters(3); + + for (size_t i = 0; i < 10; ++i) { + auto a = coll.create(); + a.set(hits[i]); + a.set(clusters[i % 3]); + a.setWeight(i * 0.1f); + } + + auto a = coll.create(); + a.set(hits[10]); + + podio::LinkNavigator nav{coll}; + + for (size_t i = 0; i < 10; ++i) { + const auto& hit = hits[i]; + const auto assocClusters = nav.getAssociated(hit); + REQUIRE(assocClusters.size() == 1); + const auto& [cluster, weight] = assocClusters[0]; + REQUIRE(cluster == clusters[i % 3]); + REQUIRE(weight == i * 0.1f); + } + + const auto& cluster1 = clusters[0]; + auto assocHits = nav.getAssociated(cluster1); + REQUIRE(assocHits.size() == 4); + for (size_t i = 0; i < 4; ++i) { + const auto& [hit, weight] = assocHits[i]; + REQUIRE(hit == hits[i * 3]); + REQUIRE(weight == i * 3 * 0.1f); + } + + const auto& cluster2 = clusters[1]; + assocHits = nav.getAssociated(cluster2); + REQUIRE(assocHits.size() == 3); + for (size_t i = 0; i < 3; ++i) { + const auto& [hit, weight] = assocHits[i]; + REQUIRE(hit == hits[i * 3 + 1]); + REQUIRE(weight == (i * 3 + 1) * 0.1f); + } + + const auto [noCluster, noWeight] = nav.getAssociated(hits[10])[0]; + REQUIRE_FALSE(noCluster.isAvailable()); +}