|
| 1 | +// This file is part of the ACTS project. |
| 2 | +// |
| 3 | +// Copyright (C) 2016 CERN for the benefit of the ACTS project |
| 4 | +// |
| 5 | +// This Source Code Form is subject to the terms of the Mozilla Public |
| 6 | +// License, v. 2.0. If a copy of the MPL was not distributed with this |
| 7 | +// file, You can obtain one at https://mozilla.org/MPL/2.0/. |
| 8 | + |
| 9 | +#include "Acts/Definitions/Direction.hpp" |
| 10 | +#include "Acts/Definitions/TrackParametrization.hpp" |
| 11 | +#include "Acts/EventData/MultiTrajectory.hpp" |
| 12 | +#include "Acts/EventData/TrackContainer.hpp" |
| 13 | +#include "Acts/EventData/TrackStatePropMask.hpp" |
| 14 | +#include "Acts/EventData/VectorMultiTrajectory.hpp" |
| 15 | +#include "Acts/EventData/VectorTrackContainer.hpp" |
| 16 | +#include "Acts/EventData/detail/CorrectedTransformationFreeToBound.hpp" |
| 17 | +#include "Acts/Geometry/GeometryIdentifier.hpp" |
| 18 | +#include "Acts/Propagator/DirectNavigator.hpp" |
| 19 | +#include "Acts/Propagator/Navigator.hpp" |
| 20 | +#include "Acts/Propagator/Propagator.hpp" |
| 21 | +#include "Acts/Propagator/SympyStepper.hpp" |
| 22 | +#include "Acts/TrackFitting/GainMatrixSmoother.hpp" |
| 23 | +#include "Acts/TrackFitting/GainMatrixUpdater.hpp" |
| 24 | +#include "Acts/TrackFitting/KalmanFitter.hpp" |
| 25 | +#include "Acts/Utilities/Delegate.hpp" |
| 26 | +#include "Acts/Utilities/Logger.hpp" |
| 27 | +#include "ActsExamples/EventData/IndexSourceLink.hpp" |
| 28 | +#include "ActsExamples/EventData/MeasurementCalibration.hpp" |
| 29 | +#include "ActsExamples/EventData/Track.hpp" |
| 30 | +#include "ActsExamples/TrackFitting/RefittingCalibrator.hpp" |
| 31 | +#include "TrackFitterFunction.hpp" |
| 32 | + |
| 33 | +#include <algorithm> |
| 34 | +#include <cmath> |
| 35 | +#include <functional> |
| 36 | +#include <memory> |
| 37 | +#include <utility> |
| 38 | +#include <vector> |
| 39 | + |
| 40 | +namespace Acts { |
| 41 | +class MagneticFieldProvider; |
| 42 | +class SourceLink; |
| 43 | +class Surface; |
| 44 | +class TrackingGeometry; |
| 45 | +} // namespace Acts |
| 46 | + |
| 47 | +namespace { |
| 48 | + |
| 49 | +using Stepper = Acts::SympyStepper; |
| 50 | +using Propagator = Acts::Propagator<Stepper, Acts::Navigator>; |
| 51 | +using Fitter = Acts::KalmanFitter<Propagator, Acts::VectorMultiTrajectory>; |
| 52 | +using DirectPropagator = Acts::Propagator<Stepper, Acts::DirectNavigator>; |
| 53 | +using DirectFitter = |
| 54 | + Acts::KalmanFitter<DirectPropagator, Acts::VectorMultiTrajectory>; |
| 55 | + |
| 56 | +using TrackContainer = |
| 57 | + Acts::TrackContainer<Acts::VectorTrackContainer, |
| 58 | + Acts::VectorMultiTrajectory, std::shared_ptr>; |
| 59 | + |
| 60 | +struct SimpleReverseFilteringLogic { |
| 61 | + double momentumThreshold = 0; |
| 62 | + |
| 63 | + bool doBackwardFiltering( |
| 64 | + Acts::VectorMultiTrajectory::ConstTrackStateProxy trackState) const { |
| 65 | + auto momentum = std::abs(1 / trackState.filtered()[Acts::eBoundQOverP]); |
| 66 | + return (momentum <= momentumThreshold); |
| 67 | + } |
| 68 | +}; |
| 69 | + |
| 70 | +using namespace ActsExamples; |
| 71 | + |
| 72 | +struct KalmanFitterFunctionImpl final : public TrackFitterFunction { |
| 73 | + Fitter fitter; |
| 74 | + DirectFitter directFitter; |
| 75 | + |
| 76 | + Acts::GainMatrixUpdater kfUpdater; |
| 77 | + Acts::GainMatrixSmoother kfSmoother; |
| 78 | + SimpleReverseFilteringLogic reverseFilteringLogic; |
| 79 | + |
| 80 | + bool multipleScattering = false; |
| 81 | + bool energyLoss = false; |
| 82 | + Acts::FreeToBoundCorrection freeToBoundCorrection; |
| 83 | + |
| 84 | + IndexSourceLink::SurfaceAccessor slSurfaceAccessor; |
| 85 | + |
| 86 | + KalmanFitterFunctionImpl(Fitter&& f, DirectFitter&& df, |
| 87 | + const Acts::TrackingGeometry& trkGeo) |
| 88 | + : fitter(std::move(f)), |
| 89 | + directFitter(std::move(df)), |
| 90 | + slSurfaceAccessor{trkGeo} {} |
| 91 | + |
| 92 | + template <typename calibrator_t> |
| 93 | + auto makeKfOptions(const GeneralFitterOptions& options, |
| 94 | + const calibrator_t& calibrator) const { |
| 95 | + Acts::KalmanFitterExtensions<Acts::VectorMultiTrajectory> extensions; |
| 96 | + extensions.updater.connect< |
| 97 | + &Acts::GainMatrixUpdater::operator()<Acts::VectorMultiTrajectory>>( |
| 98 | + &kfUpdater); |
| 99 | + extensions.smoother.connect< |
| 100 | + &Acts::GainMatrixSmoother::operator()<Acts::VectorMultiTrajectory>>( |
| 101 | + &kfSmoother); |
| 102 | + extensions.reverseFilteringLogic |
| 103 | + .connect<&SimpleReverseFilteringLogic::doBackwardFiltering>( |
| 104 | + &reverseFilteringLogic); |
| 105 | + |
| 106 | + Acts::KalmanFitterOptions<Acts::VectorMultiTrajectory> kfOptions( |
| 107 | + options.geoContext, options.magFieldContext, options.calibrationContext, |
| 108 | + extensions, options.propOptions, &(*options.referenceSurface)); |
| 109 | + |
| 110 | + kfOptions.referenceSurfaceStrategy = |
| 111 | + Acts::KalmanFitterTargetSurfaceStrategy::first; |
| 112 | + kfOptions.multipleScattering = multipleScattering; |
| 113 | + kfOptions.energyLoss = energyLoss; |
| 114 | + kfOptions.freeToBoundCorrection = freeToBoundCorrection; |
| 115 | + kfOptions.extensions.calibrator.connect<&calibrator_t::calibrate>( |
| 116 | + &calibrator); |
| 117 | + |
| 118 | + if (options.doRefit) { |
| 119 | + kfOptions.extensions.surfaceAccessor.connect<&RefittingCalibrator::accessSurface>(); |
| 120 | + } else { |
| 121 | + kfOptions.extensions.surfaceAccessor.connect<&IndexSourceLink::SurfaceAccessor::operator()>( |
| 122 | + &slSurfaceAccessor); |
| 123 | + } |
| 124 | + |
| 125 | + return kfOptions; |
| 126 | + } |
| 127 | + |
| 128 | + TrackFitterResult operator()(const std::vector<Acts::SourceLink>& sourceLinks, |
| 129 | + const TrackParameters& initialParameters, |
| 130 | + const GeneralFitterOptions& options, |
| 131 | + const MeasurementCalibratorAdapter& calibrator, |
| 132 | + TrackContainer& tracks) const override { |
| 133 | + const auto kfOptions = makeKfOptions(options, calibrator); |
| 134 | + return fitter.fit(sourceLinks.begin(), sourceLinks.end(), initialParameters, |
| 135 | + kfOptions, tracks); |
| 136 | + } |
| 137 | + |
| 138 | + TrackFitterResult operator()( |
| 139 | + const std::vector<Acts::SourceLink>& sourceLinks, |
| 140 | + const TrackParameters& initialParameters, |
| 141 | + const GeneralFitterOptions& options, |
| 142 | + const RefittingCalibrator& calibrator, |
| 143 | + const std::vector<const Acts::Surface*>& surfaceSequence, |
| 144 | + TrackContainer& tracks) const override { |
| 145 | + const auto kfOptions = makeKfOptions(options, calibrator); |
| 146 | + return directFitter.fit(sourceLinks.begin(), sourceLinks.end(), |
| 147 | + initialParameters, kfOptions, surfaceSequence, |
| 148 | + tracks); |
| 149 | + } |
| 150 | +}; |
| 151 | + |
| 152 | +} // namespace |
| 153 | + |
| 154 | +std::shared_ptr<ActsExamples::TrackFitterFunction> |
| 155 | +ActsExamples::makeKalmanFitterFunction( |
| 156 | + std::shared_ptr<const Acts::TrackingGeometry> trackingGeometry, |
| 157 | + std::shared_ptr<const Acts::MagneticFieldProvider> magneticField, |
| 158 | + bool multipleScattering, bool energyLoss, |
| 159 | + double reverseFilteringMomThreshold, |
| 160 | + Acts::FreeToBoundCorrection freeToBoundCorrection, |
| 161 | + const Acts::Logger& logger) { |
| 162 | + // Stepper should be copied into the fitters |
| 163 | + const Stepper stepper(std::move(magneticField)); |
| 164 | + |
| 165 | + // Standard fitter |
| 166 | + const auto& geo = *trackingGeometry; |
| 167 | + Acts::Navigator::Config cfg{std::move(trackingGeometry)}; |
| 168 | + cfg.resolvePassive = false; |
| 169 | + cfg.resolveMaterial = true; |
| 170 | + cfg.resolveSensitive = true; |
| 171 | + Acts::Navigator navigator(cfg, logger.cloneWithSuffix("Navigator")); |
| 172 | + Propagator propagator(stepper, std::move(navigator), |
| 173 | + logger.cloneWithSuffix("Propagator")); |
| 174 | + Fitter trackFitter(std::move(propagator), logger.cloneWithSuffix("Fitter")); |
| 175 | + |
| 176 | + // Direct fitter |
| 177 | + Acts::DirectNavigator directNavigator{ |
| 178 | + logger.cloneWithSuffix("DirectNavigator")}; |
| 179 | + DirectPropagator directPropagator(stepper, std::move(directNavigator), |
| 180 | + logger.cloneWithSuffix("DirectPropagator")); |
| 181 | + DirectFitter directTrackFitter(std::move(directPropagator), |
| 182 | + logger.cloneWithSuffix("DirectFitter")); |
| 183 | + |
| 184 | + // build the fitter function. owns the fitter object. |
| 185 | + auto fitterFunction = std::make_shared<KalmanFitterFunctionImpl>( |
| 186 | + std::move(trackFitter), std::move(directTrackFitter), geo); |
| 187 | + fitterFunction->multipleScattering = multipleScattering; |
| 188 | + fitterFunction->energyLoss = energyLoss; |
| 189 | + fitterFunction->reverseFilteringLogic.momentumThreshold = |
| 190 | + reverseFilteringMomThreshold; |
| 191 | + fitterFunction->freeToBoundCorrection = freeToBoundCorrection; |
| 192 | + |
| 193 | + return fitterFunction; |
| 194 | +} |
0 commit comments