Skip to content

Commit

Permalink
prepare return dict in a separate function
Browse files Browse the repository at this point in the history
  • Loading branch information
md-arif-shaikh committed Oct 8, 2023
1 parent 665aa8e commit bf3d9c5
Showing 1 changed file with 70 additions and 74 deletions.
144 changes: 70 additions & 74 deletions gw_eccentricity/eccDefinition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,11 +1261,12 @@ def measure_ecc(self, tref_in=None, fref_in=None):
raise KeyError("Exactly one of tref_in and fref_in"
" should be specified.")
elif tref_in is not None:
self.tref_in_ndim = np.ndim(tref_in)
self.domain = "time"
self.ref_ndim = np.ndim(tref_in)
self.tref_in = np.atleast_1d(tref_in)
else:
self.fref_in_ndim = np.ndim(fref_in)
self.tref_in_ndim = self.fref_in_ndim
self.domain = "frequency"
self.ref_ndim = np.ndim(fref_in)
self.fref_in = np.atleast_1d(fref_in)
# Get the pericenters and apocenters
pericenters = self.find_extrema("pericenters")
Expand Down Expand Up @@ -1299,8 +1300,7 @@ def measure_ecc(self, tref_in=None, fref_in=None):
if any([self.probably_quasicircular_pericenter,
self.probably_quasicircular_apocenter]) \
and self.set_failures_to_zero:
return self.set_eccentricity_and_mean_anomaly_to_zero(
tref_in)
return self.set_eccentricity_and_mean_anomaly_to_zero()

# Choose good extrema
self.pericenters_location, self.apocenters_location \
Expand Down Expand Up @@ -1357,7 +1357,7 @@ def measure_ecc(self, tref_in=None, fref_in=None):

# Check if tref_out is reasonable
if len(self.tref_out) == 0:
self.check_input_limits(self.tref_in, self.tmin, self.tmax, "time")
self.check_input_limits(self.tref_in, self.tmin, self.tmax)
raise Exception(
"tref_out is empty. This can happen if the "
"waveform has insufficient identifiable "
Expand Down Expand Up @@ -1394,75 +1394,83 @@ def measure_ecc(self, tref_in=None, fref_in=None):
# check if eccentricity is monotonic and convex
self.check_monotonicity_and_convexity()

# If tref_in is a scalar, return a scalar
if self.tref_in_ndim == 0:
self.mean_anomaly = self.mean_anomaly[0]
self.eccentricity = self.eccentricity[0]
self.tref_out = self.tref_out[0]

if fref_in is not None and self.fref_in_ndim == 0:
self.fref_out = self.fref_out[0]

if self.debug_plots:
# make a plot for diagnostics
fig, axes = self.make_diagnostic_plots()
self.save_debug_fig(fig, f"gwecc_{self.method}_diagnostics.pdf")
plt.close(fig)
return_dict = {"eccentricity": self.eccentricity,
"mean_anomaly": self.mean_anomaly}
if fref_in is not None:
return_dict.update({"fref_out": self.fref_out})
else:
return_dict.update({"tref_out": self.tref_out})
return return_dict
# return measured eccentricity, mean anomaly and reference time or
# frequency where these are measured.
return self.make_return_dict_for_eccentricity_and_mean_anomaly()

def set_eccentricity_and_mean_anomaly_to_zero(self, tref_in):
def set_eccentricity_and_mean_anomaly_to_zero(self):
"""Set eccentricity and mean_anomaly to zero."""
if tref_in is not None:
if self.domain == "time":
# This function sets eccentricity and mean anomaly to zero
# when a method fails to detect any apoceneters or pericenrers,
# and therefore in such cases, we can set the tref_out to be
# the times that falls within the range of self.t.
ref_arr = self.tref_in[
self.tref_out = self.tref_in[
np.logical_and(self.tref_in >= min(self.t),
self.tref_in <= max(self.t))]
if len(ref_arr) == 0:
out_len = len(self.tref_out)
if out_len == 0:
# check that tref_in is in the allowed range
self.check_input_limits(
self.tref_in, min(self.t), max(self.t), "time")
# To match the type of tref_in
ref_arr = ref_arr[0] if self.tref_in_ndim == 0 else ref_arr
# Finally make tref_out available to self
self.tref_out = ref_arr
self.tref_in, min(self.t), max(self.t))
else:
# Since we don't have the maximum and minimum allowed reference
# frequencies computed from the frequencies at the pericenetrs and
# apoceneters, we simply set the maximum and minimum value to be
# the maximum and minimum of instantaneous f22, respectively.
f22_min = min(self.omega22) / (2 * np.pi)
f22_max = max(self.omega22) / (2 * np.pi)
ref_arr = self.fref_in[
self.fref_out = self.fref_in[
np.logical_and(self.fref_in >= f22_min,
self.fref_in <= f22_max)]
# check that fref_in is in the allowed range
if len(ref_arr) == 0:
out_len = len(self.fref_out)
if out_len == 0:
# check that fref_in is in the allowed range
self.check_input_limits(
self.fref_in, f22_min, f22_max, "frequency")
# To match the type of the fref_in.
ref_arr = ref_arr[0] if self.fref_in_ndim == 0 else ref_arr
self.fref_out = ref_arr
# At the top of measure_ecc we set tref_in_ndim and fref_in_ndim, and
# even in case of fref_in, we set tref_in_ndim to the same as
# fref_in_ndim, therefore, below we can just use tref_in_ndim.
self.eccentricity \
= 0 if self.tref_in_ndim == 0 else np.zeros(len(ref_arr))
self.mean_anomaly \
= 0 if self.tref_in_ndim == 0 else np.zeros(len(ref_arr))
return {
"tref_out" if tref_in is not None else "fref_out": ref_arr,
self.fref_in, f22_min, f22_max)
self.eccentricity = np.zeros(out_len)
self.mean_anomaly = np.zeros(out_len)
return self.make_return_dict_for_eccentricity_and_mean_anomaly()

def make_return_dict_for_eccentricity_and_mean_anomaly(self):
"""Prepare a dictionary with reference time/freq, ecc and mean ano.
In this function, we prepare a dictionary containing the measured
eccentricity, mean anomaly and the reference time or frequency where
these are measured at.
We also make sure that if the input reference time/frequency is scalar
then the returned eccentricity and mean anomaly is also a scalar. To do
this, we use the information about the tref_in/fref_in that is provided
by the user. At the top of measure_ecc we set ref_ndim to identify
whether the original input was scalar or array-like and use that here.
"""
if self.ref_ndim == 0:
self.eccentricity = self.eccentricity[0]
self.mean_anomaly = self.mean_anomaly[0]
if self.domain == "time":
self.tref_out = self.tref_out[0]
else:
self.fref_out = self.fref_out[0]

return_dict = {
"eccentricity": self.eccentricity,
"mean_anomaly": self.mean_anomaly
}
if self.domain == "time":
return_dict.update({
"tref_out": self.tref_out
})
else:
return_dict.update({
"fref_out": self.fref_out
})
return return_dict

def et_from_ew22_0pn(self, ew22):
"""Get temporal eccentricity at Newtonian order.
Expand Down Expand Up @@ -1500,7 +1508,7 @@ def compute_eccentricity(self, t):
Eccentricity at t.
"""
# Check that t is within tmin and tmax to avoid extrapolation
self.check_input_limits(t, self.tmin, self.tmax, "time")
self.check_input_limits(t, self.tmin, self.tmax)

omega22_pericenter_at_t = self.omega22_pericenters_interp(t)
omega22_apocenter_at_t = self.omega22_apocenters_interp(t)
Expand All @@ -1527,7 +1535,7 @@ def derivative_of_eccentricity(self, t, n=1):
nth order time derivative of eccentricity.
"""
# Check that t is within tmin and tmax to avoid extrapolation
self.check_input_limits(t, self.tmin, self.tmax, "time")
self.check_input_limits(t, self.tmin, self.tmax)

if self.ecc_for_checks is None:
self.ecc_for_checks = self.compute_eccentricity(
Expand Down Expand Up @@ -1566,7 +1574,7 @@ def compute_mean_anomaly(self, t):
Mean anomaly at t.
"""
# Check that t is within tmin and tmax to avoid extrapolation
self.check_input_limits(t, self.tmin, self.tmax, "time")
self.check_input_limits(t, self.tmin, self.tmax)

# Get the mean anomaly at the pericenters
mean_ano_pericenters = np.arange(len(self.t_pericenters)) * 2 * np.pi
Expand All @@ -1576,8 +1584,7 @@ def compute_mean_anomaly(self, t):
# Modulo 2pi to make the mean anomaly vary between 0 and 2pi
return mean_ano % (2 * np.pi)

def check_input_limits(self, input_vals, min_allowed_val, max_allowed_val,
input_type):
def check_input_limits(self, input_vals, min_allowed_val, max_allowed_val):
"""Check that the input time/frequency is within allowed range.
To avoid any extrapolation, check that the times or frequencies are
Expand All @@ -1597,33 +1604,28 @@ def check_input_limits(self, input_vals, min_allowed_val, max_allowed_val,
max_allowed_val: float
Maximum allowed time or frequency where eccentricity/mean anomaly
can be measured.
input_type: str
Description of the input. Can be tref_in or fref_in
"""
if input_type not in ["time", "frequency"]:
raise ValueError("Input type must be `time` or `frequency`.")
input_vals = np.atleast_1d(input_vals)
add_extra_info = (input_type == "time" and
add_extra_info = (self.domain == "time" and
not any([self.probably_quasicircular_apocenter,
self.probably_quasicircular_pericenter]))
if any(input_vals > max_allowed_val):
message = (f"Found reference {input_type} later than maximum "
f"allowed {input_type}={max_allowed_val}")
message = (f"Found reference {self.domain} later than maximum "
f"allowed {self.domain}={max_allowed_val}")
if add_extra_info:
message += (" which corresponds to min(last pericenter "
"time, last apocenter time).")
raise NotInAllowedInputRange(
"Reference " + input_type, min_allowed_val, max_allowed_val,
f"Reference {self.domain}", min_allowed_val, max_allowed_val,
message)
if any(input_vals < min_allowed_val):
message = (f"Found reference {input_type} earlier than minimum "
f"allowed {input_type}={min_allowed_val}")
message = (f"Found reference {self.domain} earlier than minimum "
f"allowed {self.domain}={min_allowed_val}")
if add_extra_info:
message += (" which corresponds to max(first pericenter "
"time, first apocenter time).")
raise NotInAllowedInputRange(
"Reference " + input_type, min_allowed_val, max_allowed_val,
f"Reference {self.domain}", min_allowed_val, max_allowed_val,
message)

def check_extrema_separation(self, extrema_location,
Expand Down Expand Up @@ -2318,16 +2320,10 @@ def get_fref_out(self, fref_in, method):
np.logical_and(fref_in >= fref_min,
fref_in < fref_max)]
if len(fref_out) == 0:
if fref_in[0] < fref_min:
raise Exception("fref_in is earlier than minimum available "
f"frequency {fref_min}")
if fref_in[-1] > fref_max:
raise Exception("fref_in is later than maximum available "
f"frequency {fref_max}")
else:
raise Exception("fref_out is empty. This can happen if the "
"waveform has insufficient identifiable "
"pericenters/apocenters.")
self.check_input_limits(fref_in, fref_min, fref_max)
raise Exception("fref_out is empty. This can happen if the "
"waveform has insufficient identifiable "
"pericenters/apocenters.")
return fref_out

def make_diagnostic_plots(
Expand Down

0 comments on commit bf3d9c5

Please sign in to comment.