Skip to content

Commit

Permalink
Fake_factor method update
Browse files Browse the repository at this point in the history
  • Loading branch information
hephysicist committed Feb 14, 2025
1 parent e2e2ef4 commit b3cc09f
Showing 1 changed file with 138 additions and 49 deletions.
187 changes: 138 additions & 49 deletions columnflow/tasks/data_driven_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,23 @@ def run(self):
enable=["configs", "skip_configs", "datasets", "skip_datasets", "shifts", "skip_shifts"],
)

class dict_creator():
def init_dict(self, ax_list):
if not ax_list:
return -1.
else:
ax = ax_list[0]
updated_ax = ax_list[1:]
get_ax_dict = lambda ax, ax_list, func : {ax.bin(i): func(ax_list) for i in range(ax.size)}
return get_ax_dict(ax,updated_ax, self.init_dict)


class ComputeFakeFactors(
DatasetsProcessesMixin,
CategoriesMixin,
WeightProducerMixin,
ProducersMixin,
dict_creator,
):
sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox"))

Expand Down Expand Up @@ -279,7 +291,7 @@ def requires(self):
d: self.reqs.PrepareFakeFactorHistograms.req(
self,
dataset=d,
branch=-1,
branch=-1
)
for d in self.datasets
}
Expand All @@ -297,9 +309,15 @@ def run(self):
import hist
import numpy as np
from scipy.optimize import curve_fit
from scipy.special import erf
import matplotlib.pyplot as plt
import correctionlib
import correctionlib.convert as cl_convert
import correctionlib.schemav2 as cs
plt.figure(dpi=200)
plt.rcParams.update({
"text.usetex": True,
"font.family": "monospace",
"font.monospace": 'Computer Modern Typewriter'
})
# preare inputs and outputs
inputs = self.input()
outputs = self.output()
Expand Down Expand Up @@ -346,6 +364,8 @@ def get_dr_hist(self, h, det_reg):
cat_name = self.categories[0]
cat = self.config_inst.get_category(cat_name.replace('sr',det_reg))
return h[{"category": hist.loc(cat.id)}]

get_id = lambda ax, key: [i in enumerate(ax.keys)]

data_num = get_dr_hist(self, h_data, num_reg)
data_den = get_dr_hist(self, h_data, den_reg)
Expand All @@ -360,58 +380,80 @@ def get_dr_hist(self, h, det_reg):
def rel_err(x):
return x.variances()/np.maximum(x.values()**2, 1)

ff_err2 = np.abs(1./den) * (data_num.variances()**0.5 + mc_num.variances()**0.5) + np.abs(num)/(den**2) * (data_den.variances()**0.5 + mc_den.variances()**0.5)

def fitf(x, a, b):
return a + b * x
#make interpolation of the ff values
ipt_range = ff_val.shape[0]
x = data_num.axes[0].centers
ff_err = ff_val * ((data_num.variances() + mc_num.variances())**0.5 / np.abs(num) + (data_den.variances() + mc_den.variances())**0.5 / np.abs(den))


ff_fit = np.zeros((*np.shape(ff_val),3))
for idm in range(ff_val.shape[1]):
mask = ff_val[:,idm] > 0
y = ff_val[mask,idm]
y_err = ff_err2[mask,idm]
x_masked = x[mask]
popt, pcov = curve_fit(fitf,
x_masked,
y,
sigma=y_err,
absolute_sigma=True)
ff_fit[:,idm,0] = fitf(x, *popt)
ff_fit[:,idm,1] = fitf(x, *popt + np.sqrt(np.diag(pcov)))
ff_fit[:,idm,2] = fitf(x, *popt - np.sqrt(np.diag(pcov)))
h = hist.Hist.new
for (var_name, var_axis) in self.config_inst.x.fake_factor_method.axes.items():
h = eval(f'h.{var_axis.ax_str}')
h = h.StrCategory(['nominal', 'up', 'down'], name='syst', label='Statistical uncertainty of the fake factor')
ff_fitted = h.Weight()

ff_fitted.view().value = ff_fit
ff_fitted.name = name
ff_fitted.label = label

ff_raw = ff_fitted.copy().reset()
ff_raw = h.Weight()
ff_raw.view().value[...,0] = ff_val
ff_raw.view().variance[...,0] = ff_err2
ff_raw.view().variance[...,0] = ff_err**2
ff_raw.name = name + '_raw'
ff_raw.label = label + '_raw'

#Make an approximation of tau pt dependance
formula_str = 'p0 + p1*x+p2*x*x'
def fitf(x, p0, p1, p2):
return eval(formula_str)
def jac(x):
from numpy import array
out = array([[ 1, x, x**2],[x, x**2, x**3],[x**2, x**3, x**4]])
return out

def eval_formula(formula_str, popt):
for i,p in enumerate(popt):
formula_str = formula_str.replace(f'p{i}',str(popt[i]))
return formula_str

ff_fitted = ff_raw.copy().reset()
ff_fitted.name = name
ff_fitted.label = label
fitres = {}


axes = list(ff_raw.axes[1:2])
fitres = {}
dc = dict_creator()
for the_field in ['chi2','ndf','popt', 'pcov', 'fitf_str']:
fitres[the_field]= dc.init_dict(axes)

return ff_raw, ff_fitted
dm_axis = ff_raw.axes['tau_dm_pnet']
for dm in dm_axis:
h1d = ff_raw[{'tau_dm_pnet': hist.loc(dm),
'syst': hist.loc('nominal')}]
mask = h1d.values() > 0
y = h1d.values()[mask]
y_err = (h1d.variances()[mask])**0.5
x = h1d.axes[0].centers[mask]
popt, pcov = curve_fit(fitf,x,y,
sigma=y_err,
absolute_sigma=True,
)
fitres['chi2'][dm] = sum(((y - fitf(x, *popt))/y_err)**2)
fitres['ndf'][dm] = len(y) - len(popt)
fitres['popt'][dm] = popt
fitres['pcov'][dm] = pcov

fitres['fitf_str'][dm] = eval_formula(formula_str,popt)
for c, shift_name in enumerate(['down', 'nominal', 'up']): # if down then c=-1, if up c=+1, nominal => c=0
ff_fitted.view().value[:,
ff_fitted.axes[1].index(dm),
ff_fitted.axes[2].index(shift_name)] = fitf(x, *popt + (c-1) * np.sqrt(np.diag(pcov)))
fitres['name'] = name
fitres['jac'] = jac
fitres['fitf'] = fitf
return ff_raw, ff_fitted, fitres

wj_raw, wj_fitted = get_ff_corr(self,
wj_raw, wj_fitted, wj_fitres = get_ff_corr(self,
data_hists,
mc_hists,
num_reg = 'dr_num_wj',
den_reg = 'dr_den_wj',
name='ff_wjets',
label='Fake factor W+jets')

qcd_raw, qcd_fitted = get_ff_corr(self,
qcd_raw, qcd_fitted, qcd_fitres = get_ff_corr(self,
data_hists,
mc_hists,
num_reg = 'dr_num_qcd',
Expand All @@ -420,32 +462,59 @@ def fitf(x, a, b):
label='Fake factor QCD')

corr_list = []
for h in [wj_raw, wj_fitted, qcd_raw, qcd_fitted]:
corr = cl_convert.from_histogram(h)
corr.data.flow = "clamp"
corr.version = 2
corr_list.append(corr)
cset = correctionlib.schemav2.CorrectionSet(
for fitres in [wj_fitres, qcd_fitres]:
formula_str = fitres['fitf_str']
dm_bins = []
for (dm, the_formula) in formula_str.items():
x_max = 100
last_val = fitres['fitf'](x_max,* fitres['popt'][dm])

dm_bins.append(cs.CategoryItem(
key=dm,
value=cs.Formula(
nodetype="formula",
variables=["tau_pt"],
parser="TFormula",
expression=f'({the_formula})/(1. + exp(10.*(x-{x_max}))) + ({last_val})/(1. + exp(-10.*(x-{x_max})))',
)))
corr_list.append(cs.Correction(
name=fitres['name'],
description=f"fake factor correcton for {fitres['name'].split('_')[1]}",
version=2,
inputs=[
cs.Variable(name="tau_pt", type="real",description="pt of tau"),
cs.Variable(name="tau_dm_pnet", type="int", description="PNet decay mode of tau"),
],
output=cs.Variable(name="weight", type="real", description="Multiplicative event weight"),
data=cs.Category(
nodetype="category",
input="tau_dm_pnet",
content=dm_bins,)
))

cset = cs.CorrectionSet(
schema_version=2,
description="Fake factors",
corrections=corr_list
)
self.output()['ff_json'].dump(cset.json(exclude_unset=True), formatter="json")



#Plot fake factors:
for h_name in ['wj', 'qcd']:
h_raw = eval(f'{h_name}_raw')
h_fitted = eval(f'{h_name}_fitted')

fig, ax = plt.subplots(figsize=(12, 8))
h_raw[...,'nominal'].plot2d(ax=ax)
self.output()['plots']['_'.join((h_name,'nominal'))].dump(fig, formatter="mpl")

fitres = wj_fitres if h_name == 'wj' else qcd_fitres
dm_axis = h_raw.axes['tau_dm_pnet']
for dm in dm_axis:
h1d = h_raw[{'tau_dm_pnet': hist.loc(dm),
'syst': hist.loc('nominal')}]

hfit = h_fitted[{'tau_dm_pnet': hist.loc(dm)}]

fig, ax = plt.subplots(figsize=(8, 6))
mask = h1d.counts() > 0
x = h1d.axes[0].centers[mask]
Expand All @@ -457,13 +526,33 @@ def fitf(x, a, b):
marker='o',
fmt='o',
line=None, color='#2478B7', capsize=4)
ax.plot(hfit.axes[0].centers,
hfit[:,0].counts(),
x_fine = np.linspace(x[0],x[-1],num=100)
popt = fitres['popt'][dm]
pcov = fitres['pcov'][dm]
jac = fitres['jac']
def err(x,jac,pcov):
from numpy import sqrt,einsum
return sqrt(einsum('ij,ij',jac(x),pcov))

import functools
err_y = list(map(functools.partial(err, jac=jac,pcov=pcov), x_fine))

y_fitf = fitres['fitf'](x_fine,*popt)
y_fitf_up = fitres['fitf'](x_fine,*popt) + err_y
y_fitf_down = fitres['fitf'](x_fine,*(popt)) - err_y

ax.plot(x_fine,
y_fitf,
color='#FF867B')
ax.fill_between(hfit.axes[0].centers, hfit[:,2].counts(), hfit[:,1].counts(), color='#83d55f', alpha=0.5)
ax.fill_between(x_fine, y_fitf_up, y_fitf_down, color='#83d55f', alpha=0.5)
ax.set_ylabel('Fake Factor')
ax.set_xlabel('Tau pT [GeV]')
ax.set_title(f'Jet Fake Factors (Tau PNet Decay Mode {(dm)}')
ax.set_title(f'Jet Fake Factors :Tau PNet Decay Mode {(dm)}')
ax.annotate(rf"$\frac{{\chi^2}}{{ndf}} = \frac{{{np.round(fitres['chi2'][dm],2)}}}{{{fitres['ndf'][dm]}}}$",
(0.8, 0.9),
xycoords='axes fraction',
fontsize=20)

self.output()['plots1d']['_'.join((h_name,str(dm)))].dump(fig, formatter="mpl")


Expand Down

0 comments on commit b3cc09f

Please sign in to comment.