stuff #115

Open · wants to merge 35 commits into master

Commits (35)
7988ce1  some new features (jkiesele, May 10, 2023)
08ccd81  added activation (jkiesele, May 11, 2023)
a73dd02  fix for some machines (jkiesele, May 15, 2023)
13a2241  convenience functions (jkiesele, May 17, 2023)
6852bdc  changes (jkiesele, Jun 19, 2023)
86650f0  added pc pool model (jkiesele, Jun 19, 2023)
4068b93  just for the diffs (jkiesele, Jun 19, 2023)
03b4e3b  minimal cleanup (jkiesele, Jun 19, 2023)
04d3b3f  fixed major issue with batch norm layers. be careful, this will creat… (jkiesele, Jun 19, 2023)
6c71f2b  revert to make compatible (jkiesele, Jun 20, 2023)
65512c4  update (jkiesele, Jun 20, 2023)
d9d37ae  just some docu and perf improvement; no bugs (jkiesele, Jun 21, 2023)
9eb3459  testing script for push layer (jkiesele, Jun 21, 2023)
138bda0  fixed bug (jkiesele, Jun 21, 2023)
c2be21a  compat revert (jkiesele, Jun 21, 2023)
40b3a91  just a handy wrapper (jkiesele, Jun 21, 2023)
0f263df  some more debugging vis (jkiesele, Jun 21, 2023)
17f4f20  more bugfixes, now working well (jkiesele, Jun 21, 2023)
c04db64  pre model (jkiesele, Jun 23, 2023)
ae3d2e7  small tool to check batch norm and identify possible issues (jkiesele, Jun 23, 2023)
e8db0f9  resolved issues with weight loading while maintaining compatibility (jkiesele, Jun 23, 2023)
d480e5a  possible stability fix (jkiesele, Jun 23, 2023)
68f0b5a  important bugfix (jkiesele, Jun 23, 2023)
1728c5a  no nan callback (jkiesele, Jun 23, 2023)
ed2bfe4  snapshot of testing (jkiesele, Jun 23, 2023)
0d776a3  mostly config options (compat) (jkiesele, Jun 24, 2023)
1df9cb7  added baseline training (jkiesele, Jun 25, 2023)
027acaf  snapshot. needs revert on pz test (jkiesele, Jun 25, 2023)
caf4ae1  updated readme (jkiesele, Jun 25, 2023)
5ce3aa1  for fcc studies (jkiesele, Jun 25, 2023)
df7efc5  snapshot (jkiesele, Jun 27, 2023)
eecf5dd  removed whitspace (jkiesele, Jun 29, 2023)
560deb6  added predict part (jkiesele, Jul 3, 2023)
61d5f61  faster (jkiesele, Jul 3, 2023)
f8e7457  snapshot (jkiesele, Jul 3, 2023)
6 changes: 4 additions & 2 deletions README.md
@@ -51,10 +51,12 @@ The naming scheme should be obvious and must be followed. Compile with make.
Converting the data from ntuples
===========

``convertFromSource.py -i <text file listing all training input files> -o <output dir> -c TrainData_NanoML``
``convertFromSource.py -i <text file listing all training input files> -o <output dir> -c TrainData_NanoML``
The conversion rule itself is located here:
``modules/datastructures/TrainData_NanoML.py``

Other conversion rules / data structures can also be defined. For maximum compatibility, it is advised to follow the NanoML example w.r.t. the final outputs.
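A minimal sketch of what such a custom conversion rule could look like; the base-class import and the method name follow the usual DeepJetCore `TrainData` pattern and are assumptions here, so check `modules/datastructures/TrainData_NanoML.py` for the exact interface:

```
# hypothetical example, not part of this repository:
# modules/datastructures/TrainData_MyFormat.py
from DeepJetCore.TrainData import TrainData

class TrainData_MyFormat(TrainData):
    def convertFromSourceFile(self, filename, weighterobjects, istraining):
        # read the ntuple here, build the ragged per-hit feature and truth
        # arrays (with row splits), and return them in the same format as
        # TrainData_NanoML for maximum compatibility
        ...
```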

The training files (see next section) usually also contain a comment at the beginning pointing to the latest data set at CERN and Flatiron.

Standard training and inference
@@ -68,7 +70,7 @@ cd Train
Look at the first lines of the file `std_training.py` containing a short description and where to find the dataset compatible with that training file. Then execute the following command to run a training.

```
python3 std_training.py <path_to_dataset>/training_data.djcdc <training_output_path>
python3 baseline_training.py <path_to_dataset>/training_data.djcdc <training_output_path>
```
Please note that the standard configuration may or may not write the printout to a file in the training output directory.

338 changes: 338 additions & 0 deletions Train/baseline_training.py
@@ -0,0 +1,338 @@
'''

Compatible with the dataset here:
/eos/home-j/jkiesele/ML4Reco/Gun20Part_NewMerge/train

On flatiron:
/mnt/ceph/users/jkieseler/HGCalML_data/Gun20Part_NewMerge/train

not compatible with datasets before end of Jan 2022

'''

import tensorflow as tf

from tensorflow.keras.layers import Dense, Concatenate

from DeepJetCore.DJCLayers import StopGradient

from Layers import RaggedGlobalExchange, DistanceWeightedMessagePassing, DictModel
from Layers import RaggedGravNet, ScaledGooeyBatchNorm2
from Regularizers import AverageDistanceRegularizer
from LossLayers import LLFullObjectCondensation
from DebugLayers import PlotCoordinates

from model_blocks import condition_input, extent_coords_if_needed, create_outputs, re_integrate_to_full_hits

from callbacks import plotClusterSummary

from DeepJetCore.training.DeepJet_callbacks import simpleMetricsCallback


#loss options:
loss_options={
    # here and in the following energy = momentum
    'energy_loss_weight': 0.,
    'q_min': 1.,
    # addition to the original OC loss: adds the average position for clustering
    # usually 0.5 is a reasonable value to break degeneracies
    # and keep training smooth enough
    'use_average_cc_pos': 0.5,
    'classification_loss_weight':0.0,
    'position_loss_weight':0.,
    'timing_loss_weight':0.,
    'beta_loss_scale':1.,
    # these weights will downweight low energies; for a
    # training sample with a good energy distribution,
    # this won't be needed.
    'use_energy_weights': False,
    # this is the standard repulsive hinge loss from the paper
    'implementation': 'hinge'
    }
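# these options are unpacked below via **loss_options into the
# LLFullObjectCondensation loss layer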


# elu behaves well, likely fine
dense_activation='elu'

# record internal metrics every N batches
record_frequency=10
# plot every M times metrics were recorded; in other words,
# plotting will happen every M*N batches
plotfrequency=50

learningrate = 1e-4

# this is the maximum number of hits (points) per batch,
# not the number of events (samples). This is safer w.r.t.
# memory
nbatch = 10000
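# for illustration only (hypothetical numbers): with events of roughly 2000 hits,
# on the order of 5 events would be packed into one batch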

#iterations of gravnet blocks
n_neighbours=[64,64]
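# one gravnet block is built per entry in this list; each value is the number of
# neighbours used in that block (see the loop over n_neighbours below)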

# 3 is a bit low but nice in the beginning since it can be plotted
n_cluster_space_coordinates = 3
n_gravnet_dims = 3


def gravnet_model(Inputs,
                  td,
                  debug_outdir=None,
                  plot_debug_every=record_frequency*plotfrequency,
                  ):
    ####################################################################################
    ##################### Input processing, no need to change much here ################
    ####################################################################################

    input_list = td.interpretAllModelInputs(Inputs,returndict=True)
    input_list = condition_input(input_list, no_scaling=True)

    #just for info what's available, prints once
    print('available inputs',[k for k in input_list.keys()])

    rs = input_list['row_splits']
    t_idx = input_list['t_idx']
    energy = input_list['rechit_energy']
    c_coords = input_list['coords']

    ## build inputs

    x_in = Concatenate()([input_list['coords'],
                          input_list['features']])

    x_in = ScaledGooeyBatchNorm2(
        fluidity_decay=0.1 #freeze out quickly, just to get good input preprocessing
        )(x_in)

    x = x_in

    c_coords = ScaledGooeyBatchNorm2(
        fluidity_decay=0.1 #same here
        )(c_coords)


    ####################################################################################
    ##################### now the actual model goes below ##############################
    ####################################################################################

    # output of each iteration will be concatenated
    allfeat = []

    # extend coordinates already here if needed, just as a good starting point
    c_coords = extent_coords_if_needed(c_coords, x, n_gravnet_dims)

    for i in range(len(n_neighbours)):

        # derive new coordinates for clustering
        x = RaggedGlobalExchange()([x, rs])

        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        x = Concatenate()([c_coords,x]) #give a good starting point
        x = ScaledGooeyBatchNorm2()(x)

        xgn, gncoords, gnnidx, gndist = RaggedGravNet(n_neighbours=n_neighbours[i],
                                                      n_dimensions=n_gravnet_dims,
                                                      n_propagate=64, #this is the number of features that are exchanged
                                                      n_filters=64, #output dense
                                                      feature_activation = 'elu',
                                                      )([x, rs])
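        # RaggedGravNet returns the block's output features (xgn), the learned
        # gravnet coordinates (gncoords), and the neighbour indices and distances
        # (gnnidx, gndist) that are reused for the message passing step below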

        x = Concatenate()([x,xgn])

        # mostly to record average distances etc. can be used to force coordinates
        # to be within reasonable range (but usually not needed)
        gndist = AverageDistanceRegularizer(strength=1e-6,
                                            record_metrics=True
                                            )(gndist)

        #for information / debugging, can also be safely removed
        gncoords = PlotCoordinates(plot_every = plot_debug_every, outdir = debug_outdir,
                                   name='gn_coords_'+str(i))([gncoords,
                                                              energy,
                                                              t_idx,
                                                              rs])
        # we have to pass them downwards, otherwise the layer above gets optimised away
        # but we don't want the gradient to be disturbed, so it gets stopped
        gncoords = StopGradient()(gncoords)
        x = Concatenate()([gncoords,x])

        # this repeats the distance weighted message passing step from gravnet
        # on the same graph topology
        x = DistanceWeightedMessagePassing([64,64],
                                           activation=dense_activation
                                           )([x,gnnidx,gndist])

        x = ScaledGooeyBatchNorm2()(x)

        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)
        x = Dense(64,activation=dense_activation)(x)

        x = ScaledGooeyBatchNorm2()(x)

        allfeat.append(x)



    x = Concatenate()([c_coords]+allfeat) #gives a prior to the clustering coords
    #create one global feature vector
    xg = Dense(512,activation=dense_activation,name='glob_dense_'+str(i))(x)
    x = RaggedGlobalExchange()([xg, rs])
    x = Concatenate()([x,xg])
    # last part of network
    x = Dense(64,activation=dense_activation)(x)
    x = ScaledGooeyBatchNorm2()(x)
    x = Dense(64,activation=dense_activation)(x)
    x = ScaledGooeyBatchNorm2()(x)
    x = Dense(64,activation=dense_activation)(x)
    x = ScaledGooeyBatchNorm2()(x)


    #######################################################################
    ########### the part below should remain almost unchanged #############
    ########### of course with the exception of the OC loss   #############
    ########### weights                                        #############
    #######################################################################

    #use a standard batch norm at the last stage


    pred_beta, pred_ccoords, pred_dist,\
        pred_energy_corr, pred_energy_low_quantile, pred_energy_high_quantile,\
        pred_pos, pred_time, pred_time_unc, pred_id = create_outputs(x, n_ccoords=n_cluster_space_coordinates)

    # loss
    pred_beta = LLFullObjectCondensation(scale=1.,
                                         record_metrics=True,
                                         print_loss=True,
                                         name="FullOCLoss",
                                         **loss_options
                                         )( # oc output and payload
        [pred_beta, pred_ccoords, pred_dist,
         pred_energy_corr,pred_energy_low_quantile,pred_energy_high_quantile,
         pred_pos, pred_time, pred_time_unc,
         pred_id] +
        [energy]+
        # truth information
        [input_list['t_idx'] ,
         input_list['t_energy'] ,
         input_list['t_pos'] ,
         input_list['t_time'] ,
         input_list['t_pid'] ,
         input_list['t_spectator_weight'],
         input_list['t_fully_contained'],
         input_list['t_rec_energy'],
         input_list['t_is_unique'],
         input_list['row_splits']])

    # fast feedback
    pred_ccoords = PlotCoordinates(plot_every=plot_debug_every, outdir = debug_outdir,
                                   name='condensation_coords')([pred_ccoords, pred_beta, input_list['t_idx'],
                                                                rs])

    # just to have a defined output, only adds names
    model_outputs = re_integrate_to_full_hits(
        input_list,
        pred_ccoords,
        pred_beta,
        pred_energy_corr,
        pred_energy_low_quantile,
        pred_energy_high_quantile,
        pred_pos,
        pred_time,
        pred_id,
        pred_dist
        )

    return DictModel(inputs=Inputs, outputs=model_outputs)



import training_base_hgcal
train = training_base_hgcal.HGCalTraining()

if not train.modelSet():
    train.setModel(gravnet_model,
                   td=train.train_data.dataclass(),
                   debug_outdir=train.outputDir+'/intplots')

train.setCustomOptimizer(tf.keras.optimizers.Nadam(clipnorm=1.,epsilon=1e-2))
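# clipnorm=1. caps each gradient's norm at 1 to guard against exploding gradients;
# the larger-than-default epsilon (Keras default is 1e-7) makes the Nadam update
# more conservative and numerically safer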
#
train.compileModel(learningrate=1e-4)

train.keras_model.summary()


verbosity = 2
import os

publishpath = None #this can be an ssh reachable path (be careful: needs tokens / keypairs)

# establish callbacks


cb = [
    simpleMetricsCallback(
        output_file=train.outputDir+'/metrics.html',
        record_frequency= record_frequency,
        plot_frequency = plotfrequency,
        select_metrics='FullOCLoss_*loss',
        publish=publishpath #no additional directory here (scp cannot create one)
        ),

    simpleMetricsCallback(
        output_file=train.outputDir+'/latent_space_metrics.html',
        record_frequency= record_frequency,
        plot_frequency = plotfrequency,
        select_metrics='average_distance_*',
        publish=publishpath
        ),

    simpleMetricsCallback(
        output_file=train.outputDir+'/val_metrics.html',
        call_on_epoch=True,
        select_metrics='val_*',
        publish=publishpath #no additional directory here (scp cannot create one)
        ),
    ]


cb += [
    plotClusterSummary(
        outputfile=train.outputDir + "/clustering/",
        samplefile=train.val_data.getSamplePath(train.val_data.samples[0]),
        after_n_batches=200
        )
    ]

#cb=[]

train.change_learning_rate(learningrate)

model, history = train.trainModel(nepochs=3,
                                  batchsize=nbatch,
                                  additional_callbacks=cb)
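# first stage: a short warm-up of 3 epochs at the base learning rate;
# training continues below with the learning rate halved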

print("freeze BN")
# Note the submodel here its not just train.keras_model
#for l in train.keras_model.layers:
# if 'FullOCLoss' in l.name:
# l.q_min/=2.

train.change_learning_rate(learningrate/2.)


model, history = train.trainModel(nepochs=121,
batchsize=nbatch,
additional_callbacks=cb)



