Skip to content

Commit

Permalink
reducing running time for EMICORR step (#8152)
Browse files Browse the repository at this point in the history
  • Loading branch information
penaguerrero committed Dec 21, 2023
1 parent c357f75 commit 8ddb61c
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 59 deletions.
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ emicorr
-------

- Fix another bug with subarray=Full. [#8151]
- Speed up the code and fix the case of a subarray not present in the reference file. [#8152]

1.13.1 (2023-12-19)
===================
Expand Down
121 changes: 62 additions & 59 deletions jwst/emicorr/emicorr.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@
"MIRIFULONG" : ["Hz390", "Hz10_slow_MIRIFULONG"],
"MIRIFUSHORT" : ["Hz390", "Hz10_slow_MIRIFUSHORT"]}}},

"MASK1065": {
"rowclocks": 82,
"frameclocks": 23968,
"freqs": {"FAST": ["Hz390", "Hz10"],
"SLOW": {"MIRIMAGE" : ["Hz390", "Hz10_slow_MIRIMAGE"],
"MIRIFULONG" : ["Hz390", "Hz10_slow_MIRIFULONG"],
"MIRIFUSHORT" : ["Hz390", "Hz10_slow_MIRIFUSHORT"]}}},

# 390Hz already in-phase for these, but may need corr for other
# frequencies (e.g. 10Hz heater noise)

Expand Down Expand Up @@ -224,16 +232,18 @@ def apply_emicorr(input_model, emicorr_model, save_onthefly_reffile,
log.info('Using reference file to get subarray case.')
subname, rowclocks, frameclocks, freqs2correct = get_subarcase(emicorr_model, subarray, readpatt, detector)
reference_wave_list = []
for fnme in freqs2correct:
freq, ref_wave = get_frequency_info(emicorr_model, fnme)
freqs_numbers.append(freq)
reference_wave_list.append(ref_wave)
if freqs2correct is not None:
for fnme in freqs2correct:
freq, ref_wave = get_frequency_info(emicorr_model, fnme)
freqs_numbers.append(freq)
reference_wave_list.append(ref_wave)
else:
log.info('Using default subarray case corrections.')
subname, rowclocks, frameclocks, freqs2correct = get_subarcase(default_subarray_cases, subarray, readpatt, detector)
for fnme in freqs2correct:
freq = get_frequency_info(default_emi_freqs, fnme)
freqs_numbers.append(freq)
if freqs2correct is not None:
for fnme in freqs2correct:
freq = get_frequency_info(default_emi_freqs, fnme)
freqs_numbers.append(freq)

log.info('With configuration: Subarray={}, Read_pattern={}, Detector={}'.format(subarray, readpatt, detector))
if rowclocks is None or len(freqs_numbers) == 0:
Expand Down Expand Up @@ -275,7 +285,35 @@ def apply_emicorr(input_model, emicorr_model, save_onthefly_reffile,
nx4 = int(nx/4)

dd_all = np.ones((nints, ngroups, ny, nx4))
log.info('Subtracting self-superbias from each group of each integration')
log.info('Subtracting self-superbias from each group of each integration and')

# Calculate times of all pixels in the input integration, then use that to calculate
# phase of all pixels. Times here is in integer numbers of 10us pixels starting from
# the first data pixel in the input image. Times can be a very large integer, so use
# a big datatype. Phaseall (below) is just 0-1.0.

# A safer option is to calculate times_per_integration and calculate the phase at each
# int separately. That way times array will have a smaller number of elements at each
# int, with less risk of datatype overflow. Still, use the largest datatype available
# for the time_this_int array.

times_this_int = np.zeros((ngroups, ny, nx4), dtype='ulonglong')
phaseall = np.zeros((nints, ngroups, ny, nx4))

# non-roi rowclocks between subarray frames (this will be 0 for fullframe)
extra_rowclocks = (1024. - ny) * (4 + 3.)
# e.g. ((1./390.625) / 10e-6) = 256.0 pix and ((1./218.52055) / 10e-6) = 457.62287 pix
period_in_pixels = (1./frequency) / 10.0e-6

start_time, ref_pix_sample = 0, 3

# Need colstop for phase calculation in case of last refpixel in a row. Technically,
# this number comes from the subarray definition (see subarray_cases dict above), but
# calculate it from the input image header here just in case the subarray definitions
# are not available to this routine.
colstop = int( xsize/4 + xstart - 1 )
log.info('doing phase calculation per integration')

for ninti in range(nints_to_phase):
log.debug(' Working on integration: {}'.format(ninti+1))

Expand Down Expand Up @@ -310,34 +348,6 @@ def apply_emicorr(input_model, emicorr_model, save_onthefly_reffile,
# This is the quad-averaged, cleaned, input image data for the exposure
dd_all[ninti, ngroupi, ...] = dd - np.median(dd)

# Calculate times of all pixels in the input integration, then use that to calculate
# phase of all pixels. Times here is in integer numbers of 10us pixels starting from
# the first data pixel in the input image. Times can be a very large integer, so use
# a big datatype. Phaseall (below) is just 0-1.0.

# A safer option is to calculate times_per_integration and calculate the phase at each
# int separately. That way times array will have a smaller number of elements at each
# int, with less risk of datatype overflow. Still, use the largest datatype available
# for the time_this_int array.

times_this_int = np.zeros((ngroups, ny, nx4), dtype='ulonglong')
phaseall = np.zeros((nints, ngroups, ny, nx4))

# non-roi rowclocks between subarray frames (this will be 0 for fullframe)
extra_rowclocks = (1024. - ny) * (4 + 3.)
# e.g. ((1./390.625) / 10e-6) = 256.0 pix and ((1./218.52055) / 10e-6) = 457.62287 pix
period_in_pixels = (1./frequency) / 10.0e-6

start_time, ref_pix_sample = 0, 3

# Need colstop for phase calculation in case of last refpixel in a row. Technically,
# this number comes from the subarray definition (see subarray_cases dict above), but
# calculate it from the input image header here just in case the subarray definitions
# are not available to this routine.
colstop = int( xsize/4 + xstart - 1 )
log.info('Phase calculation per integration')
for l in range(nints_to_phase):
log.debug(' Working on integration: {}'.format(l+1))
for k in range(ngroups): # frames
for j in range(ny): # rows
# nsamples= 1 for fast, 9 for slow (from metadata)
Expand Down Expand Up @@ -366,7 +376,7 @@ def apply_emicorr(input_model, emicorr_model, save_onthefly_reffile,
# number of 10us from the first data pixel in this integration, so to
# convert to phase, divide by the waveform *period* in float pixels
phase_this_int = times_this_int / period_in_pixels
phaseall[l, ...] = phase_this_int - phase_this_int.astype('ulonglong')
phaseall[ninti, ...] = phase_this_int - phase_this_int.astype('ulonglong')

# add a frame time to account for the extra frame reset between MIRI integrations
start_time += frameclocks
Expand Down Expand Up @@ -436,8 +446,8 @@ def apply_emicorr(input_model, emicorr_model, save_onthefly_reffile,
# and optionally amplitude scaled)
# shift and resample reference_wave at pa's phase
# u[0] is the phase shift of reference_wave *to* pa
u = np.where(cc >= max(cc))
lut_reference = rebin(np.roll(reference_wave, u[0]), [period_in_pixels])
u = np.argmax(cc)
lut_reference = rebin(np.roll(reference_wave, u), [period_in_pixels])

# Scale reference wave amplitude to match the pa amplitude from this dataset by
# fitting a line rather than taking a mean ratio so that any DC offset drops out
Expand Down Expand Up @@ -554,23 +564,17 @@ def minmed(data, minval=False, avgval=False, maxval=False):
ngroups, ny, nx = np.shape(data)
medimg = np.zeros((ny, nx))
# use a mask to ignore nans for calculations
masked_data = np.ma.array(data, mask=np.isnan(data))

for i in range(nx):
for j in range(ny):
vec = masked_data[:, j, i]
u = np.where(vec != 0)
n = vec[u].size
if n > 0:
if n <= 2 or minval:
medimg[j, i] = np.ma.min(vec[u])
if maxval:
medimg[j, i] = np.ma.max(vec[u])
if not minval and not maxval and not avgval:
medimg[j, i] = np.ma.median(vec[u])
if avgval:
dmean , _, _, _ = iter_stat_sig_clip(vec[u])
medimg[j, i] = dmean
vec = np.ma.array(data, mask=np.isnan(data))
n = vec.size
if n > 0:
if n <= 2 or minval:
medimg = np.ma.min(vec, axis=0)
if maxval:
medimg = np.ma.max(vec, axis=0)
if not minval and not maxval and not avgval:
medimg = np.ma.median(vec, axis=0)
if avgval:
medimg = np.ma.mean(vec, axis=0)
return medimg


Expand Down Expand Up @@ -734,9 +738,8 @@ def iter_stat_sig_clip(data, sigrej=3.0, maxiter=10):
# Compute the mean + standard deviation of the entire data array,
# these values will be returned if there are fewer than 2 good points.
dmask = np.ones(ngood, dtype='b') + 1
dmean = sum(data * dmask) / ngood
dmean = np.sum(data * dmask) / ngood
dsigma = np.sqrt(sum((data - dmean)**2) / (ngood - 1))
dsigma = dsigma
iiter = 1

# Iteratively compute the mean + stdev, updating the sigma-rejection thresholds
Expand All @@ -752,7 +755,7 @@ def iter_stat_sig_clip(data, sigrej=3.0, maxiter=10):
ngood = sum(dmask)

if ngood >= 2:
dmean = sum(data*dmask) / ngood
dmean = np.sum(data*dmask) / ngood
dsigma = np.sqrt( sum((data - dmean)**2 * dmask) / (ngood - 1) )
dsigma = dsigma

Expand Down

0 comments on commit 8ddb61c

Please sign in to comment.