Skip to content

Commit

Permalink
Add two filter method
Browse files Browse the repository at this point in the history
  • Loading branch information
beomki-yeo committed Dec 9, 2024
1 parent 08ed780 commit 133fe57
Show file tree
Hide file tree
Showing 6 changed files with 253 additions and 74 deletions.
2 changes: 1 addition & 1 deletion core/include/traccc/finding/details/find_tracks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ track_candidate_container_types::host find_tracks(
// Run the Kalman update on a copy of the track parameters
const bool res =
sf.template visit_mask<gain_matrix_updater<algebra_type>>(
trk_state, in_param, false);
trk_state, in_param);

// The chi2 from Kalman update should be less than chi2_max
if (res && trk_state.filtered_chi2() < config.chi2_max) {
Expand Down
41 changes: 41 additions & 0 deletions core/include/traccc/fitting/actors/surface_id_aborter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2024 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

#pragma once

// Project include(s)
#include "detray/definitions/detail/qualifiers.hpp"
#include "detray/propagator/base_actor.hpp"

// System include(s)
#include <limits>

namespace traccc {

/// Aborter triggered when the next surface is reached
struct surface_id_aborter : detray::actor {
struct state {
detray::geometry::barcode m_abort_id;
};

template <typename propagator_state_t>
DETRAY_HOST_DEVICE void operator()(state &abrt_state,
propagator_state_t &prop_state) const {

auto &navigation = prop_state._navigation;

// Abort if the propagator is on the surface with abort ID
if ((navigation.is_on_sensitive() ||
navigation.encountered_sf_material())) {
if (navigation.barcode() == abrt_state.m_abort_id) {
prop_state._heartbeat &= navigation.abort();
}
}
}
};

} // namespace traccc
75 changes: 11 additions & 64 deletions core/include/traccc/fitting/kalman_filter/gain_matrix_updater.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,16 @@ struct gain_matrix_updater {
TRACCC_HOST_DEVICE inline bool operator()(
const mask_group_t& /*mask_group*/, const index_t& /*index*/,
track_state<algebra_t>& trk_state,
const bound_track_parameters& bound_params,
const bool backward_mode) const {
const bound_track_parameters& bound_params) const {

using shape_type = typename mask_group_t::value_type::shape;

const auto D = trk_state.get_measurement().meas_dim;
assert(D == 1u || D == 2u);
if (D == 1u) {
return update<1u, shape_type>(trk_state, bound_params,
backward_mode);
return update<1u, shape_type>(trk_state, bound_params);
} else if (D == 2u) {
return update<2u, shape_type>(trk_state, bound_params,
backward_mode);
return update<2u, shape_type>(trk_state, bound_params);
}

return false;
Expand All @@ -61,8 +58,7 @@ struct gain_matrix_updater {
template <size_type D, typename shape_t>
TRACCC_HOST_DEVICE inline bool update(
track_state<algebra_t>& trk_state,
const bound_track_parameters& bound_params,
const bool backward_mode) const {
const bound_track_parameters& bound_params) const {

static_assert(((D == 1u) || (D == 2u)),
"The measurement dimension should be 1 or 2");
Expand Down Expand Up @@ -92,10 +88,8 @@ struct gain_matrix_updater {
bound_params.covariance();

// Set track state parameters
if (!backward_mode) {
trk_state.predicted().set_vector(predicted_vec);
trk_state.predicted().set_covariance(predicted_cov);
}
trk_state.predicted().set_vector(predicted_vec);
trk_state.predicted().set_covariance(predicted_cov);

if constexpr (std::is_same_v<shape_t, detray::line<true>> ||
std::is_same_v<shape_t, detray::line<false>>) {
Expand Down Expand Up @@ -138,59 +132,12 @@ struct gain_matrix_updater {
}

// Set the track state parameters
if (!backward_mode) {
trk_state.filtered().set_vector(filtered_vec);
trk_state.filtered().set_covariance(filtered_cov);
trk_state.filtered_chi2() = matrix_operator().element(chi2, 0, 0);

// Wrap the phi in the range of [-pi, pi]
wrap_phi(trk_state.filtered());
}
trk_state.filtered().set_vector(filtered_vec);
trk_state.filtered().set_covariance(filtered_cov);
trk_state.filtered_chi2() = matrix_operator().element(chi2, 0, 0);

if (backward_mode) {
assert(trk_state.filtered().surface_link() ==
bound_params.surface_link());

const matrix_type<e_bound_size, e_bound_size> predicted_cov_inv =
matrix_operator().inverse(predicted_cov);
const matrix_type<e_bound_size, e_bound_size> filtered_cov_inv =
matrix_operator().inverse(trk_state.filtered().covariance());

// Eq (3.38) of "Pattern Recognition, Tracking and Vertex
// Reconstruction in Particle Detectors"
const matrix_type<e_bound_size, e_bound_size> smoothed_cov_inv =
predicted_cov_inv + filtered_cov_inv;

const matrix_type<e_bound_size, e_bound_size> smoothed_cov =
matrix_operator().inverse(smoothed_cov_inv);

// Eq (3.38) of "Pattern Recognition, Tracking and Vertex
// Reconstruction in Particle Detectors"
const matrix_type<e_bound_size, 1u> smoothed_vec =
smoothed_cov *
(filtered_cov_inv * trk_state.filtered().vector() +
predicted_cov_inv * predicted_vec);

trk_state.smoothed().set_vector(smoothed_vec);
trk_state.smoothed().set_covariance(smoothed_cov);

const matrix_type<D, 1> residual_smt =
meas_local - H * smoothed_vec;

// Eq (3.39) of "Pattern Recognition, Tracking and Vertex
// Reconstruction in Particle Detectors"
const matrix_type<D, D> R_smt =
V - H * smoothed_cov * matrix_operator().transpose(H);

// Eq (3.40) of "Pattern Recognition, Tracking and Vertex
// Reconstruction in Particle Detectors"
const matrix_type<1, 1> chi2_smt =
matrix_operator().transpose(residual_smt) *
matrix_operator().inverse(R_smt) * residual_smt;

trk_state.smoothed_chi2() =
matrix_operator().element(chi2_smt, 0, 0);
}
// Wrap the phi in the range of [-pi, pi]
wrap_phi(trk_state.filtered());

return true;
}
Expand Down
3 changes: 1 addition & 2 deletions core/include/traccc/fitting/kalman_filter/kalman_actor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,7 @@ struct kalman_actor : detray::actor {

const bool res =
sf.template visit_mask<gain_matrix_updater<algebra_t>>(
trk_state, propagation._stepping.bound_params(),
actor_state.backward_mode);
trk_state, propagation._stepping.bound_params());

// Abort if the Kalman update fails
if (!res) {
Expand Down
82 changes: 75 additions & 7 deletions core/include/traccc/fitting/kalman_filter/kalman_fitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
#include "traccc/edm/track_candidate.hpp"
#include "traccc/edm/track_parameters.hpp"
#include "traccc/edm/track_state.hpp"
#include "traccc/finding/actors/interaction_register.hpp"
#include "traccc/fitting/actors/surface_id_aborter.hpp"
#include "traccc/fitting/fitting_config.hpp"
#include "traccc/fitting/kalman_filter/gain_matrix_smoother.hpp"
#include "traccc/fitting/kalman_filter/kalman_actor.hpp"
#include "traccc/fitting/kalman_filter/kalman_step_aborter.hpp"
#include "traccc/fitting/kalman_filter/statistics_updater.hpp"
#include "traccc/fitting/kalman_filter/two_filters_smoother.hpp"
#include "traccc/utils/particle.hpp"

// detray include(s).
Expand Down Expand Up @@ -63,11 +66,11 @@ class kalman_fitter {
using fit_actor = traccc::kalman_actor<algebra_type, vector_type>;
using resetter = detray::parameter_resetter<algebra_type>;

// Forward Propagator type
using actor_chain_type =
detray::actor_chain<detray::dtuple, aborter, transporter, interactor,
fit_actor, resetter, kalman_step_aborter>;

// Propagator type
using propagator_type =
detray::propagator<stepper_t, navigator_t, actor_chain_type>;

Expand Down Expand Up @@ -207,23 +210,85 @@ class kalman_fitter {
/// track and vertex fitting", R.Frühwirth, NIM A.
///
/// @param fitter_state the state of kalman fitter
TRACCC_HOST_DEVICE
void smooth(state& fitter_state, const bound_covariance& /*cov*/) {
TRACCC_HOST_DEVICE void smooth(state& fitter_state,
const bound_covariance& /*cov*/) {

auto& track_states = fitter_state.m_fit_actor_state.m_track_states;

// Since the smoothed track parameter of the last surface can be
// considered to be the filtered one, we can reversly iterate the
// algorithm to obtain the smoothed parameter of other surfaces
auto& last = track_states.back();
last.smoothed().set_parameter_vector(last.filtered());
last.smoothed().set_covariance(last.filtered().covariance());
last.smoothed_chi2() = last.filtered_chi2();

if (m_cfg.use_backward_filter) {
/*
// Create propagator
propagator_type propagator(m_cfg.propagation);
backward_propagator_type propagator(m_cfg.propagation);
for (typename vector_type<
track_state<algebra_type>>::reverse_iterator it =
track_states.rbegin();
it != track_states.rend() - 1; ++it) {
// Get surface corresponding to bound params
const detray::tracking_surface sf{m_detector,
it->surface_link()};
const typename detector_type::geometry_context ctx{};
auto bound_params = it->smoothed();
// inflate_covariance(bound_params, m_cfg.)
// Apply material interaction
typename interactor::state interactor_state;
interactor{}.update(
ctx,
detail::correct_particle_hypothesis(m_cfg.ptc_hypothesis,
bound_params),
bound_params, interactor_state,
static_cast<int>(detray::navigation::direction::e_backward),
sf);
// Two filters (forward & backward) method
typename backward_propagator_type::state propagation(
bound_params, m_field, m_detector);
propagation._navigation.set_volume(it->surface_link().volume());
propagation._navigation.set_direction(
detray::navigation::direction::e_backward);
typename aborter::state path_abrt_state;
typename transporter::state trp_state;
surface_id_aborter::state id_abrt_state{
(it + 1)->surface_link()};
typename interactor::state int_state;
typename interaction_register<interactor>::state reg_state{
int_state};
typename resetter::state rst_state;
propagator.propagate(
propagation,
detray::tie(path_abrt_state, trp_state, id_abrt_state,
reg_state, int_state, rst_state));
const detray::tracking_surface sf_next{
m_detector, (it + 1)->surface_link()};
sf_next.template visit_mask<two_filters_smoother<algebra_type>>(
*(it + 1), propagation._stepping.bound_params());
}
*/

// Set path limit
fitter_state.m_aborter_state.set_path_limit(
m_cfg.propagation.stepping.path_limit);

// Seed param for backward seed = last state of forward filter
auto bw_seed_params =
fitter_state.m_fit_actor_state.m_track_states.back().filtered();
inflate_covariance(bw_seed_params,
m_cfg.covariance_inflation_factor);
//bw_seed_params.set_covariance(cov);

// Two filters (forward & backward) method
typename propagator_type::state propagation(bw_seed_params, m_field,
Expand All @@ -239,7 +304,9 @@ class kalman_fitter {

// Reset the backward mode to false
fitter_state.m_fit_actor_state.backward_mode = false;

} else {
/*
auto& track_states = fitter_state.m_fit_actor_state.m_track_states;
// The Rauch-Tung-Striebel(RTS) smoother requires the following:
Expand All @@ -253,6 +320,7 @@ class kalman_fitter {
last.smoothed().set_parameter_vector(last.filtered());
last.smoothed().set_covariance(last.filtered().covariance());
last.smoothed_chi2() = last.filtered_chi2();
*/

for (typename vector_type<
track_state<algebra_type>>::reverse_iterator it =
Expand Down
Loading

0 comments on commit 133fe57

Please sign in to comment.