Skip to content

Commit 5a75e37

Browse files
[PWGEM] MlResponse for electron PID (AliceO2Group#8278)
Co-authored-by: ALICE Action Bot <[email protected]>
1 parent 21292e9 commit 5a75e37

12 files changed

+567
-816
lines changed

PWGEM/Dilepton/Core/DielectronCut.h

+9-17
Original file line number | Diff line number | Diff line change
@@ -24,8 +24,7 @@
2424
#include "TNamed.h"
2525
#include "Math/Vector4D.h"
2626

27-
#include "Tools/ML/MlResponse.h"
28-
#include "Tools/ML/model.h"
27+
#include "PWGEM/Dilepton/Utils/MlResponseDielectronSingleTrack.h"
2928

3029
#include "Framework/Logger.h"
3130
#include "Framework/DataTypes.h"
@@ -241,19 +240,12 @@ class DielectronCut : public TNamed
241240
template <typename TTrack, typename TCollision>
242241
bool PassPIDML(TTrack const& track, TCollision const& collision) const
243242
{
244-
std::vector<float> inputFeatures{static_cast<float>(collision.numContrib()), track.p(), track.tgl(),
245-
track.tpcNSigmaEl(), /*track.tpcNSigmaMu(),*/ track.tpcNSigmaPi(), track.tpcNSigmaKa(), track.tpcNSigmaPr(),
246-
track.tofNSigmaEl(), /*track.tofNSigmaMu(),*/ track.tofNSigmaPi(), track.tofNSigmaKa(), track.tofNSigmaPr(),
247-
track.meanClusterSizeITS() * std::cos(std::atan(track.tgl()))};
248-
249-
// calculate classifier
250-
float prob_ele = mPIDModel->evalModel(inputFeatures)[0];
251-
// LOGF(info, "prob_ele = %f", prob_ele);
252-
if (prob_ele < 0.95) {
243+
/*if (!PassTOFif(track)) { // Allows for pre-selection. But potentially dangerous if analyzers are not aware of it
253244
return false;
254-
} else {
255-
return true;
256-
}
245+
}*/
246+
std::vector<float> inputFeatures = mPIDMlResponse->getInputFeatures(track, collision);
247+
float binningFeature = mPIDMlResponse->getBinningFeature(track, collision);
248+
return mPIDMlResponse->isSelectedMl(inputFeatures, binningFeature);
257249
}
258250

259251
template <typename T>
@@ -426,9 +418,9 @@ class DielectronCut : public TNamed
426418
void ApplyPrefilter(bool flag);
427419
void ApplyPhiV(bool flag);
428420

429-
void SetPIDModel(o2::ml::OnnxModel* model)
421+
void SetPIDMlResponse(o2::analysis::MlResponseDielectronSingleTrack<float>* mlResponse)
430422
{
431-
mPIDModel = model;
423+
mPIDMlResponse = mlResponse;
432424
}
433425

434426
// Getters
@@ -494,7 +486,7 @@ class DielectronCut : public TNamed
494486
float mMinTOFNsigmaPi{-1e+10}, mMaxTOFNsigmaPi{+1e+10};
495487
float mMinTOFNsigmaKa{-1e+10}, mMaxTOFNsigmaKa{+1e+10};
496488
float mMinTOFNsigmaPr{-1e+10}, mMaxTOFNsigmaPr{+1e+10};
497-
o2::ml::OnnxModel* mPIDModel{nullptr};
489+
o2::analysis::MlResponseDielectronSingleTrack<float>* mPIDMlResponse{nullptr};
498490

499491
ClassDef(DielectronCut, 1);
500492
};

PWGEM/Dilepton/Core/Dilepton.h

+29-19
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,6 @@
4545
#include "DataFormatsParameters/GRPMagField.h"
4646
#include "CCDB/BasicCCDBManager.h"
4747
#include "Tools/ML/MlResponse.h"
48-
#include "Tools/ML/model.h"
4948

5049
#include "PWGEM/Dilepton/DataModel/dileptonTables.h"
5150
#include "PWGEM/Dilepton/Core/DielectronCut.h"
@@ -57,6 +56,7 @@
5756
#include "PWGEM/Dilepton/Utils/EventHistograms.h"
5857
#include "PWGEM/Dilepton/Utils/EMTrackUtilities.h"
5958
#include "PWGEM/Dilepton/Utils/PairUtilities.h"
59+
#include "PWGEM/Dilepton/Utils/MlResponseDielectronSingleTrack.h"
6060

6161
using namespace o2;
6262
using namespace o2::aod;
@@ -186,7 +186,7 @@ struct Dilepton {
186186
Configurable<float> cfg_min_p_its_cluster_size{"cfg_min_p_its_cluster_size", 0.0, "min p to apply ITS cluster size cut"};
187187
Configurable<float> cfg_max_p_its_cluster_size{"cfg_max_p_its_cluster_size", 0.0, "max p to apply ITS cluster size cut"};
188188

189-
Configurable<int> cfg_pid_scheme{"cfg_pid_scheme", static_cast<int>(DielectronCut::PIDSchemes::kTPChadrejORTOFreq), "pid scheme [kTOFreq : 0, kTPChadrej : 1, kTPChadrejORTOFreq : 2, kTPConly : 3]"};
189+
Configurable<int> cfg_pid_scheme{"cfg_pid_scheme", static_cast<int>(DielectronCut::PIDSchemes::kTPChadrejORTOFreq), "pid scheme [kTOFreq : 0, kTPChadrej : 1, kTPChadrejORTOFreq : 2, kTPConly : 3, kTOFif = 4, kPIDML = 5]"};
190190
Configurable<float> cfg_min_TPCNsigmaEl{"cfg_min_TPCNsigmaEl", -2.0, "min. TPC n sigma for electron inclusion"};
191191
Configurable<float> cfg_max_TPCNsigmaEl{"cfg_max_TPCNsigmaEl", +3.0, "max. TPC n sigma for electron inclusion"};
192192
Configurable<float> cfg_min_TPCNsigmaMu{"cfg_min_TPCNsigmaMu", -0.0, "min. TPC n sigma for muon exclusion"};
@@ -201,9 +201,13 @@ struct Dilepton {
201201
Configurable<float> cfg_max_TOFNsigmaEl{"cfg_max_TOFNsigmaEl", +3.0, "max. TOF n sigma for electron inclusion"};
202202
Configurable<bool> enableTTCA{"enableTTCA", true, "Flag to enable or disable TTCA"};
203203

204-
// CCDB configuration for PID ML
205-
Configurable<std::string> BDTLocalPathGamma{"BDTLocalPathGamma", "pid_ml_xgboost.onnx", "Path to the local .onnx file"};
206-
Configurable<std::string> BDTPathCCDB{"BDTPathCCDB", "Users/d/dsekihat/pwgem/pidml/", "Path on CCDB"};
204+
// configuration for PID ML
205+
Configurable<std::vector<std::string>> onnxFileNames{"onnxFileNames", std::vector<std::string>{"filename"}, "ONNX file names for each bin (if not from CCDB full path)"};
206+
Configurable<std::vector<std::string>> onnxPathsCCDB{"onnxPathsCCDB", std::vector<std::string>{"path"}, "Paths of models on CCDB"};
207+
Configurable<std::vector<double>> binsMl{"binsMl", std::vector<double>{-999999., 999999.}, "Bin limits for ML application"};
208+
Configurable<std::vector<double>> cutsMl{"cutsMl", std::vector<double>{0.95}, "ML cuts per bin"};
209+
Configurable<std::vector<std::string>> namesInputFeatures{"namesInputFeatures", std::vector<std::string>{"feature"}, "Names of ML model input features"};
210+
Configurable<std::string> nameBinningFeature{"nameBinningFeature", "pt", "Names of ML model binning feature"};
207211
Configurable<int64_t> timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB. Exceptions: > 0 for the specific timestamp, 0 gets the run dependent timestamp"};
208212
Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};
209213
Configurable<bool> enableOptimizations{"enableOptimizations", false, "Enables the ONNX extended model-optimization: sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED)"};
@@ -465,9 +469,6 @@ struct Dilepton {
465469
used_trackIds.clear();
466470
used_trackIds.shrink_to_fit();
467471

468-
if (eid_bdt) {
469-
delete eid_bdt;
470-
}
471472
delete h2sp_resolution;
472473
}
473474

@@ -603,7 +604,7 @@ struct Dilepton {
603604
fEMEventCut.SetRequireNoCollInTimeRangeStandard(eventcuts.cfgRequireNoCollInTimeRangeStandard);
604605
}
605606

606-
o2::ml::OnnxModel* eid_bdt = nullptr;
607+
o2::analysis::MlResponseDielectronSingleTrack<float> mlResponseSingleTrack;
607608
void DefineDielectronCut()
608609
{
609610
fDielectronCut = DielectronCut("fDielectronCut", "fDielectronCut");
@@ -646,21 +647,30 @@ struct Dilepton {
646647
fDielectronCut.SetTOFNsigmaElRange(dielectroncuts.cfg_min_TOFNsigmaEl, dielectroncuts.cfg_max_TOFNsigmaEl);
647648

648649
if (dielectroncuts.cfg_pid_scheme == static_cast<int>(DielectronCut::PIDSchemes::kPIDML)) { // please call this at the end of DefineDileptonCut
649-
eid_bdt = new o2::ml::OnnxModel();
650+
static constexpr int nClassesMl = 2;
651+
const std::vector<int> cutDirMl = {o2::cuts_ml::CutSmaller, o2::cuts_ml::CutNot};
652+
const std::vector<std::string> labelsClasses = {"Signal", "Background"};
653+
const uint32_t nBinsMl = dielectroncuts.binsMl.value.size() - 1;
654+
const std::vector<std::string> labelsBins(nBinsMl, "bin");
655+
double cutsMlArr[nBinsMl][nClassesMl];
656+
for (uint32_t i = 0; i < nBinsMl; i++) {
657+
cutsMlArr[i][0] = dielectroncuts.cutsMl.value[i];
658+
cutsMlArr[i][1] = 0.;
659+
}
660+
o2::framework::LabeledArray<double> cutsMl = {cutsMlArr[0], nBinsMl, nClassesMl, labelsBins, labelsClasses};
661+
662+
mlResponseSingleTrack.configure(dielectroncuts.binsMl.value, cutsMl, cutDirMl, nClassesMl);
650663
if (dielectroncuts.loadModelsFromCCDB) {
651664
ccdbApi.init(ccdburl);
652-
std::map<std::string, std::string> metadata;
653-
bool retrieveSuccessGamma = ccdbApi.retrieveBlob(dielectroncuts.BDTPathCCDB.value, ".", metadata, dielectroncuts.timestampCCDB.value, false, dielectroncuts.BDTLocalPathGamma.value);
654-
if (retrieveSuccessGamma) {
655-
eid_bdt->initModel(dielectroncuts.BDTLocalPathGamma.value, dielectroncuts.enableOptimizations.value);
656-
} else {
657-
LOG(fatal) << "Error encountered while fetching/loading the Gamma model from CCDB! Maybe the model doesn't exist yet for this runnumber/timestamp?";
658-
}
665+
mlResponseSingleTrack.setModelPathsCCDB(dielectroncuts.onnxFileNames.value, ccdbApi, dielectroncuts.onnxPathsCCDB.value, dielectroncuts.timestampCCDB.value);
659666
} else {
660-
eid_bdt->initModel(dielectroncuts.BDTLocalPathGamma.value, dielectroncuts.enableOptimizations.value);
667+
mlResponseSingleTrack.setModelPathsLocal(dielectroncuts.onnxFileNames.value);
661668
}
669+
mlResponseSingleTrack.cacheInputFeaturesIndices(dielectroncuts.namesInputFeatures);
670+
mlResponseSingleTrack.cacheBinningIndex(dielectroncuts.nameBinningFeature);
671+
mlResponseSingleTrack.init(dielectroncuts.enableOptimizations.value);
662672

663-
fDielectronCut.SetPIDModel(eid_bdt);
673+
fDielectronCut.SetPIDMlResponse(&mlResponseSingleTrack);
664674
} // end of PID ML
665675
}
666676

PWGEM/Dilepton/Core/DileptonMC.h

+31-24
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,6 @@
3636
#include "DataFormatsParameters/GRPMagField.h"
3737
#include "CCDB/BasicCCDBManager.h"
3838
#include "Tools/ML/MlResponse.h"
39-
#include "Tools/ML/model.h"
4039

4140
#include "Common/Core/RecoDecay.h"
4241
#include "Common/Core/trackUtilities.h"
@@ -51,6 +50,7 @@
5150
#include "PWGEM/Dilepton/Utils/EventHistograms.h"
5251
#include "PWGEM/Dilepton/Utils/EMTrackUtilities.h"
5352
#include "PWGEM/Dilepton/Utils/PairUtilities.h"
53+
#include "PWGEM/Dilepton/Utils/MlResponseDielectronSingleTrack.h"
5454

5555
using namespace o2;
5656
using namespace o2::aod;
@@ -167,7 +167,7 @@ struct DileptonMC {
167167
Configurable<float> cfg_min_p_its_cluster_size{"cfg_min_p_its_cluster_size", 0.0, "min p to apply ITS cluster size cut"};
168168
Configurable<float> cfg_max_p_its_cluster_size{"cfg_max_p_its_cluster_size", 0.0, "max p to apply ITS cluster size cut"};
169169

170-
Configurable<int> cfg_pid_scheme{"cfg_pid_scheme", static_cast<int>(DielectronCut::PIDSchemes::kTPChadrejORTOFreq), "pid scheme [kTOFreq : 0, kTPChadrej : 1, kTPChadrejORTOFreq : 2, kTPConly : 3]"};
170+
Configurable<int> cfg_pid_scheme{"cfg_pid_scheme", static_cast<int>(DielectronCut::PIDSchemes::kTPChadrejORTOFreq), "pid scheme [kTOFreq : 0, kTPChadrej : 1, kTPChadrejORTOFreq : 2, kTPConly : 3, kTOFif = 4, kPIDML = 5]"};
171171
Configurable<float> cfg_min_TPCNsigmaEl{"cfg_min_TPCNsigmaEl", -2.0, "min. TPC n sigma for electron inclusion"};
172172
Configurable<float> cfg_max_TPCNsigmaEl{"cfg_max_TPCNsigmaEl", +3.0, "max. TPC n sigma for electron inclusion"};
173173
Configurable<float> cfg_min_TPCNsigmaMu{"cfg_min_TPCNsigmaMu", -0.0, "min. TPC n sigma for muon exclusion"};
@@ -182,9 +182,13 @@ struct DileptonMC {
182182
Configurable<float> cfg_max_TOFNsigmaEl{"cfg_max_TOFNsigmaEl", +3.0, "max. TOF n sigma for electron inclusion"};
183183
Configurable<bool> enableTTCA{"enableTTCA", true, "Flag to enable or disable TTCA"};
184184

185-
// CCDB configuration for PID ML
186-
Configurable<std::string> BDTLocalPathGamma{"BDTLocalPathGamma", "pid_ml_xgboost.onnx", "Path to the local .onnx file"};
187-
Configurable<std::string> BDTPathCCDB{"BDTPathCCDB", "Users/d/dsekihat/pwgem/pidml/", "Path on CCDB"};
185+
// configuration for PID ML
186+
Configurable<std::vector<std::string>> onnxFileNames{"onnxFileNames", std::vector<std::string>{"filename"}, "ONNX file names for each bin (if not from CCDB full path)"};
187+
Configurable<std::vector<std::string>> onnxPathsCCDB{"onnxPathsCCDB", std::vector<std::string>{"path"}, "Paths of models on CCDB"};
188+
Configurable<std::vector<double>> binsMl{"binsMl", std::vector<double>{-999999., 999999.}, "Bin limits for ML application"};
189+
Configurable<std::vector<double>> cutsMl{"cutsMl", std::vector<double>{0.95}, "ML cuts per bin"};
190+
Configurable<std::vector<std::string>> namesInputFeatures{"namesInputFeatures", std::vector<std::string>{"feature"}, "Names of ML model input features"};
191+
Configurable<std::string> nameBinningFeature{"nameBinningFeature", "pt", "Names of ML model binning feature"};
188192
Configurable<int64_t> timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB. Exceptions: > 0 for the specific timestamp, 0 gets the run dependent timestamp"};
189193
Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};
190194
Configurable<bool> enableOptimizations{"enableOptimizations", false, "Enables the ONNX extended model-optimization: sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED)"};
@@ -241,12 +245,7 @@ struct DileptonMC {
241245
HistogramRegistry fRegistry{"output", {}, OutputObjHandlingPolicy::AnalysisObject, false, false};
242246
static constexpr std::string_view event_cut_types[2] = {"before/", "after/"};
243247

244-
~DileptonMC()
245-
{
246-
if (eid_bdt) {
247-
delete eid_bdt;
248-
}
249-
}
248+
~DileptonMC() {}
250249

251250
void addhistograms()
252251
{
@@ -494,7 +493,7 @@ struct DileptonMC {
494493
fEMEventCut.SetRequireNoCollInTimeRangeStandard(eventcuts.cfgRequireNoCollInTimeRangeStandard);
495494
}
496495

497-
o2::ml::OnnxModel* eid_bdt = nullptr;
496+
o2::analysis::MlResponseDielectronSingleTrack<float> mlResponseSingleTrack;
498497
void DefineDielectronCut()
499498
{
500499
fDielectronCut = DielectronCut("fDielectronCut", "fDielectronCut");
@@ -536,23 +535,31 @@ struct DileptonMC {
536535
fDielectronCut.SetTPCNsigmaPrRange(dielectroncuts.cfg_min_TPCNsigmaPr, dielectroncuts.cfg_max_TPCNsigmaPr);
537536
fDielectronCut.SetTOFNsigmaElRange(dielectroncuts.cfg_min_TOFNsigmaEl, dielectroncuts.cfg_max_TOFNsigmaEl);
538537

539-
if (dielectroncuts.cfg_pid_scheme == static_cast<int>(DielectronCut::PIDSchemes::kPIDML)) { // please call this at the end of DefineDielectronCut
540-
// o2::ml::OnnxModel* eid_bdt = new o2::ml::OnnxModel();
541-
eid_bdt = new o2::ml::OnnxModel();
538+
if (dielectroncuts.cfg_pid_scheme == static_cast<int>(DielectronCut::PIDSchemes::kPIDML)) { // please call this at the end of DefineDileptonCut
539+
static constexpr int nClassesMl = 2;
540+
const std::vector<int> cutDirMl = {o2::cuts_ml::CutSmaller, o2::cuts_ml::CutNot};
541+
const std::vector<std::string> labelsClasses = {"Signal", "Background"};
542+
const uint32_t nBinsMl = dielectroncuts.binsMl.value.size() - 1;
543+
const std::vector<std::string> labelsBins(nBinsMl, "bin");
544+
double cutsMlArr[nBinsMl][nClassesMl];
545+
for (uint32_t i = 0; i < nBinsMl; i++) {
546+
cutsMlArr[i][0] = dielectroncuts.cutsMl.value[i];
547+
cutsMlArr[i][1] = 0.;
548+
}
549+
o2::framework::LabeledArray<double> cutsMl = {cutsMlArr[0], nBinsMl, nClassesMl, labelsBins, labelsClasses};
550+
551+
mlResponseSingleTrack.configure(dielectroncuts.binsMl.value, cutsMl, cutDirMl, nClassesMl);
542552
if (dielectroncuts.loadModelsFromCCDB) {
543553
ccdbApi.init(ccdburl);
544-
std::map<std::string, std::string> metadata;
545-
bool retrieveSuccessGamma = ccdbApi.retrieveBlob(dielectroncuts.BDTPathCCDB.value, ".", metadata, dielectroncuts.timestampCCDB.value, false, dielectroncuts.BDTLocalPathGamma.value);
546-
if (retrieveSuccessGamma) {
547-
eid_bdt->initModel(dielectroncuts.BDTLocalPathGamma.value, dielectroncuts.enableOptimizations.value);
548-
} else {
549-
LOG(fatal) << "Error encountered while fetching/loading the Gamma model from CCDB! Maybe the model doesn't exist yet for this runnumber/timestamp?";
550-
}
554+
mlResponseSingleTrack.setModelPathsCCDB(dielectroncuts.onnxFileNames.value, ccdbApi, dielectroncuts.onnxPathsCCDB.value, dielectroncuts.timestampCCDB.value);
551555
} else {
552-
eid_bdt->initModel(dielectroncuts.BDTLocalPathGamma.value, dielectroncuts.enableOptimizations.value);
556+
mlResponseSingleTrack.setModelPathsLocal(dielectroncuts.onnxFileNames.value);
553557
}
558+
mlResponseSingleTrack.cacheInputFeaturesIndices(dielectroncuts.namesInputFeatures);
559+
mlResponseSingleTrack.cacheBinningIndex(dielectroncuts.nameBinningFeature);
560+
mlResponseSingleTrack.init(dielectroncuts.enableOptimizations.value);
554561

555-
fDielectronCut.SetPIDModel(eid_bdt);
562+
fDielectronCut.SetPIDMlResponse(&mlResponseSingleTrack);
556563
} // end of PID ML
557564
}
558565

0 commit comments

Comments
 (0)