Added ensemble support for RingCounting tool. (#278)
* add RingCounting tool

* RingCounting.py: added model loading and documentation

* RingCounting.py: fix bug

* RingCounting.py: rm unimplemented method call

* RingCounting.py: fixed pmt masking, fixed file loading, added files_to_load path line commenting support

* RingCounting.py:
+ docStrings
+ saving controlled by [[save_to]]

* added RingCountingConfig, fixed loading error in RingCountingTool

* RingCounting.py:
+ logging of successful file loading
+ pep8 formatting

* Filled RingCountingTool README.md

* Created RingCounting Tool (#1)

Created RingCounting Tool

---------

Co-authored-by: Daniel T. Schmid <[email protected]>

* Update files_to_load.txt

noop to trigger CI test

* Fixed a fatal error in LoadRawData.cpp that was thrown when mrdtotalentries is smaller than tanktotalentries and the build type is set to "MRD", since in this case only MRD data is expected to be loaded.

* Improved clarity of the load and save operation of the RingCounting tool:
- Renamed load_from_file -> load_from_csv
- Introduced variable save_to_csv
- Updated documentation at beginning of RingCounting.py to reflect changes
- Updated README.md to reflect changes
- Updated UserTools/RingCounting/RingCountingConfig to reflect changes

Users using the RingCounting tool need to update their RingCountingConfig files following this change (an illustrative excerpt follows below).
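
For orientation, a hypothetical excerpt of an updated RingCountingConfig (the key names come from the tool; the values shown are illustrative, not defaults from this commit):

```
load_from_csv 0    # renamed from load_from_file; if 1, load CNNImage-formatted csv files
save_to_csv 0      # new variable; if 1, save predictions as csv (MR prediction, SR prediction)
```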

* Added ensemble support for RingCounting tool.
- Moved logic for processing predictions to method process_predictions
- Added new configuration variables model_is_ensemble, ensemble_model_count, ensemble_prediction_combination_mode
- Added support for loading ensemble models and processing predictions
- Added support for average and majority-voting ensembles (a brief sketch of the two modes follows after this list)
- Added storing of the majority-voting ensemble class prediction in the new variables RingCountingVoting{SR,MR}Prediction in the RecoEvent BoostStore
- Updated documentation within RingCounting.py
- Updated README.md
- Updated configfiles/RingCounting/RingCountingConfig

Users using the RingCounting tool need to update their RingCountingConfig file(s) following this change.
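
A minimal standalone sketch of the two combination modes (not code from this commit; the array name and example scores are made up, but each model's output is an [MR, SR] pair as in the diff below, and a tie with an even model count selects neither class):

```
import numpy as np

# Hypothetical per-model outputs for a single event; each row is one model's [MR, SR] scores.
ensemble_predictions = np.array([
    [0.45, 0.55],
    [0.40, 0.60],
    [0.70, 0.30],
])
n_models = len(ensemble_predictions)

# "average" mode: mean score per class across all models.
predicted_mr, predicted_sr = ensemble_predictions.mean(axis=0)

# "voting" mode: each model votes for its argmax class (index 1 == SR);
# the majority wins, and a tie (possible with an even model count) selects neither class.
votes_for_sr = int(np.sum(np.argmax(ensemble_predictions, axis=1)))
votes_for_mr = n_models - votes_for_sr
voted_sr = 1 if votes_for_sr > votes_for_mr else 0
voted_mr = 1 if votes_for_mr > votes_for_sr else 0

print(predicted_mr, predicted_sr)  # ~0.517, ~0.483 -> averaging slightly favours MR
print(voted_sr, voted_mr)          # 1, 0           -> voting favours SR
```

In this toy case averaging and voting disagree, which is exactly the situation the comments in process_predictions describe.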

---------

Co-authored-by: Daniel T. Schmid <[email protected]>
Co-authored-by: marc1uk <[email protected]>
3 people authored Nov 26, 2024
1 parent db4358b commit 0d9aca0
Showing 3 changed files with 123 additions and 18 deletions.
11 changes: 11 additions & 0 deletions UserTools/RingCounting/README.md
@@ -31,6 +31,12 @@ reco_event_bs.Set("RingCountingSRPrediction", predicted_sr)
reco_event_bs.Set("RingCountingMRPrediction", predicted_mr)
```

and, when using an ensemble with majority voting, the following variables are also set:
```
reco_event_bs.Set("RingCountingVotingSRPrediction", predicted_sr)
reco_event_bs.Set("RingCountingVotingMRPrediction", predicted_mr)
```

---
## Configuration

@@ -55,4 +61,9 @@ files_to_load configfiles/RingCounting/files_to_load.txt # txt file c
version 1_0_0 # Model version
model_path /exp/annie/app/users/dschmid/RingCountingStore/models/ # Model path
pmt_mask november_22 # Masked PMTs (name of hard-coded set of PMTs to ignore)
model_is_ensemble 1 # If set to 1, treat as ensemble
ensemble_model_count 13 # Number of models in ensemble
ensemble_prediction_combination_mode average # average/voting
```
126 changes: 108 additions & 18 deletions UserTools/RingCounting/RingCounting.py
@@ -35,9 +35,19 @@
# 6. Which PMT mask to use (some PMTs have been turned off in the training); check documentation for which model
# requires what mask.
# -> defined by setting [[pmt_mask]]
# 7. Where to save the predictions, when save_to_csv == true
# 7. Where to save the predictions, when save_to_csv == 1
# -> defined by setting [[save_to]]
# An example config file can also be found in the RingCountingStore/documentation/ folders mentioned below.
# 8. Whether a single model or ensemble should be used
# -> defined by setting [[model_is_ensemble]]
# 9. How many models make up the ensemble. If model count is N, "sub"-models are labeled 0, 1, ..., N-1.
# -> defined by setting [[ensemble_model_count]]
# 10. How the model predictions should be combined when using an ensemble. Supported:
# - "None" (the type), only the first model's predictions are used. (blank line in config file)
# - "average", average predictions of all models
# - "voting", average predictions and in addition a majority-voting prediction is also produced.
# -> defined by setting [[ensemble_prediction_combination_mode]]
#
# An example config file can be found in the configfiles/RingCounting/ ToolChain.
#
#
# When using on the grid, make sure to only use onsite computing resources. TensorFlow is not supported at all offsite
@@ -85,16 +95,22 @@ class RingCounting(Tool, RingCountingGlobals):
load_from_csv = std.string() # if 1, load 1 or more CNNImage formatted csv file instead of using toolchain
save_to_csv = std.string() # if 1, save as a csv file in format MR prediction, SR prediction
files_to_load = std.string() # List of files to be loaded (must be in CNNImage format,
# load_from_file has to be true)
# load_from_csv has to be true)
version = std.string() # Model version
model_path = std.string() # Path to model directory
pmt_mask = std.string() # See RingCountingGlobals
save_to = std.string() # Where to save the predictions to
model_is_ensemble = std.string() # Whether the model consists of multiple models acting as a mixture of experts
# (MOE)/ensemble
ensemble_model_count = std.string() # Count of models used in the ensemble
ensemble_prediction_combination_mode = std.string() # How predictions of models are combined: average, voting, ..

# ----------------------------------------------------------------------------------------------------
# Model variables
model = None
predicted = None
model = None # Union[TF.model/Keras.model, None]
ensemble_models = None # Union[List[TF.model/Keras.model], None]
predicted = None # np.array()
predicted_ensemble = None # List[np.array()]

def Initialise(self):
""" Initialise RingCounting tool object in following these steps:
@@ -123,6 +139,21 @@ def Initialise(self):
self.m_variables.Get("save_to", self.save_to)
self.save_to = str(self.save_to) # cast to str since std.string =/= str
self.pmt_mask = self.PMT_MASKS[self.pmt_mask]
self.m_variables.Get("model_is_ensemble", self.model_is_ensemble)
self.model_is_ensemble = "1" == self.model_is_ensemble
self.m_variables.Get("ensemble_model_count", self.ensemble_model_count)
self.ensemble_model_count = int(self.ensemble_model_count)
if self.ensemble_model_count % 2 == 0:
self.m_log.Log(__file__ + f" WARNING: Number of models in ensemble is even"
f" ({self.ensemble_model_count}). Can lead to unexpected classification when"
f" using voting to determine ensemble predictions.",
self.v_warning, self.m_verbosity)
self.m_variables.Get("ensemble_prediction_combination_mode", self.ensemble_prediction_combination_mode)
if self.ensemble_prediction_combination_mode not in [None, "average", "voting"]:
self.m_log.Log(__file__ + f" WARNING: Unsupported prediction combination mode selected"
f" ({self.ensemble_prediction_combination_mode}). Defaulting to 'average'.",
self.v_warning, self.m_verbosity)
self.ensemble_prediction_combination_mode = "average"

# ----------------------------------------------------------------------------------------------------
# Loading data
Expand All @@ -146,14 +177,7 @@ def Execute(self):
self.mask_pmts()
self.predict()

if not self.load_from_csv:
predicted_sr = float(self.predicted[0][1])
predicted_mr = float(self.predicted[0][0])

reco_event_bs = self.m_data.Stores.at("RecoEvent")

reco_event_bs.Set("RingCountingSRPrediction", predicted_sr)
reco_event_bs.Set("RingCountingMRPrediction", predicted_mr)
self.process_predictions()

return 1

@@ -210,8 +234,13 @@ def load_data(self):
self.v_debug, self.m_verbosity)

def save_data(self):
""" Save the data to the specified [[save_to]]-file. """
np.savetxt(self.save_to, self.predicted, delimiter=",")
""" Save the data to the specified [[save_to]]-file. When using an ensemble, each line contains all of the
individual model's predictions for that event (ordered as MR1,SR1,MR2,SR2,...).
"""
if self.model_is_ensemble:
np.savetxt(self.save_to, np.array(self.predicted_ensemble).flatten(), delimiter=",")
else:
np.savetxt(self.save_to, self.predicted, delimiter=",")

def mask_pmts(self):
""" Mask PMTs to 0. The PMTs to be masked is given as a list of indices, defined by setting [[pmt_mask]].
@@ -224,8 +253,18 @@ def mask_pmts(self):
np.put(self.cnn_image_pmt, self.pmt_mask, 0, mode='raise')

def load_model(self):
""" Load the specified model [[version]]."""
self.model = tf.keras.models.load_model(self.model_path + f"RC_model_v{self.version}.model")
""" Load the specified model [[version]]. If [[model_is_ensemble]], load all models in ensemble.
Model files are expected to be named as 'model_path + RC_model_v[[version]].model' for single models, and
'model_path + RC_model_ENS_v[[version]].i.model', where i in {0, 1, ..., [[ensemble_model_count]] - 1} for
ensemble models.
"""
if self.model_is_ensemble:
self.ensemble_models = [
tf.keras.models.load_model(self.model_path + f"RC_model_ENS_v{self.version}.{i}.model")
for i in range(0, self.ensemble_model_count)
]
else:
self.model = tf.keras.models.load_model(self.model_path + f"RC_model_v{self.version}.model")

def get_next_event(self):
""" Get the next event from the BoostStore. """
@@ -256,8 +295,59 @@ def predict(self):
"""

self.m_log.Log(__file__ + " PREDICTING", self.v_message, self.m_verbosity)
self.predicted = self.model.predict(np.reshape(self.cnn_image_pmt, newshape=(-1, 10, 16, 1)))
if self.model_is_ensemble:
self.predicted_ensemble = [
m.predict(np.reshape(self.cnn_image_pmt, newshape=(-1, 10, 16, 1))) for m in self.ensemble_models
]
else:
self.predicted = self.model.predict(np.reshape(self.cnn_image_pmt, newshape=(-1, 10, 16, 1)))

def process_predictions(self):
""" Process the model predictions. If an ensemble is used, calculate final predictions based on the selected
ensemble mode. Finally, store predictions in the RecoEvent BoostStore.
Store the output of the averaging ensemble and single model within the RecoEvent BoostStore under
RingCountingSRPrediction,
RingCountingMRPrediction.
Store the voted-for class prediction of the voting ensemble within the RecoEvent BoostStore under
RingCountingVotingSRPrediction,
RingCountingVotingMRPrediction.
"""
predicted_sr = -1
predicted_mr = -1
reco_event_bs = self.m_data.Stores.at("RecoEvent")

if self.model_is_ensemble:
if self.ensemble_prediction_combination_mode is None:
predicted_sr = float(self.predicted_ensemble[0][0][1])
predicted_mr = float(self.predicted_ensemble[0][0][0])

elif self.ensemble_prediction_combination_mode in ["average", "voting"]:
# Voting mode also gets predicted_sr and predicted_mr calculated by averaging, since the averaged
# predictions can still be useful in that case. For example, four models could each output 0.51 for a
# class while a fifth outputs 0.1; the average would then be < 0.5, leading to a different predicted
# class from averaging than from voting.
predicted_sr = np.average([float(i[0][1]) for i in self.predicted_ensemble])
predicted_mr = np.average([float(i[0][0]) for i in self.predicted_ensemble])

if self.ensemble_prediction_combination_mode == "voting":
# Index will be 1 for argmax in case of SR prediction, hence the sum of the argmaxes gives votes in
# favour of SR.
# In case of having an even number of models and an equal number of votes for both classes, the class
# will be set as neither SR *nor* MR.

votes = [np.argmax(i[0]) for i in self.predicted_ensemble]
pred_category_sr = 1 if np.sum(votes) > self.ensemble_model_count // 2 else 0
pred_category_mr = 1 if np.sum(votes) < self.ensemble_model_count // 2 else 0

reco_event_bs.Set("RingCountingVotingSRPrediction", pred_category_sr)
reco_event_bs.Set("RingCountingVotingMRPrediction", pred_category_mr)
else:
predicted_sr = float(self.predicted[0][1])
predicted_mr = float(self.predicted[0][0])

reco_event_bs.Set("RingCountingSRPrediction", predicted_sr)
reco_event_bs.Set("RingCountingMRPrediction", predicted_mr)

###################
# ↓ Boilerplate ↓ #
4 changes: 4 additions & 0 deletions configfiles/RingCounting/RingCountingConfig
@@ -20,3 +20,7 @@ model_path /exp/annie/app/users/dschmid/RingCountingStore/models/
pmt_mask november_22
# Output file
save_to RC_output.csv

model_is_ensemble 1 # If set to 1, treat as ensemble
ensemble_model_count 13 # Number of models in ensemble
ensemble_prediction_combination_mode average # average/voting

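As a closing aside (not part of this commit's diff): a minimal standalone sketch of writing the per-event CSV row layout that the save_data docstring describes for ensembles (MR1,SR1,MR2,SR2,... per event). The array shapes are assumptions for illustration, and the output file name follows the example config above:

```
import numpy as np

# Assume an ensemble of 3 models, each predicting 2 events, with one [MR, SR] pair per event.
predicted_ensemble = [np.random.rand(2, 2) for _ in range(3)]  # list of (n_events, 2) arrays

# Stack to (n_models, n_events, 2), put the event axis first, then flatten each event's
# per-model pairs into a single row: MR1, SR1, MR2, SR2, MR3, SR3.
stacked = np.stack(predicted_ensemble)                                  # (3, 2, 2)
rows = np.transpose(stacked, (1, 0, 2)).reshape(stacked.shape[1], -1)   # (2, 6)

np.savetxt("RC_output.csv", rows, delimiter=",")
```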