Skip to content

Commit

Permalink
BT_piecewise
Browse files Browse the repository at this point in the history
  • Loading branch information
abhisrkckl committed Jan 22, 2025
1 parent 6909d5b commit dd3c152
Showing 1 changed file with 72 additions and 89 deletions.
161 changes: 72 additions & 89 deletions src/pint/models/stand_alone_psr_binaries/BT_piecewise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@
import astropy.units as u
import numpy as np

import pint.toa
from pint import GMsun, Tsun, ls
from pint import ls
from pint.models.stand_alone_psr_binaries.BT_model import BTmodel

from .binary_generic import PSR_BINARY


class BTpiecewise(BTmodel):
"""
Expand Down Expand Up @@ -82,9 +79,8 @@ def __init__(self, axis_store_initial=None, t=None, input_params=None):
self.d_binarydelay_d_par_funcs = [self.d_BTdelay_d_par]
if t is not None:
self._t = t
if input_params is not None:
if self.T0X is None:
self.update_input(input_params)
if input_params is not None and self.T0X is None:
self.update_input(input_params)
self.binary_params = list(self.param_default_value.keys())

def set_param_values(self, valDict=None):
Expand All @@ -98,8 +94,6 @@ def setup_internal_structures(self, valDict=None):
# initialise arrays to store piecewise group boundaries
self.lower_group_edge = []
self.upper_group_edge = []
# initialise array that will be 5 x n in shape. Where n is the number of pieces required by the model
piecewise_parameter_information = []
# If there are no updates passed by binary_instance, sets default value (usually overwritten when reading from parfile)

if valDict is None:
Expand All @@ -120,42 +114,40 @@ def setup_internal_structures(self, valDict=None):
# Searches through updates for keys prefixes matching T0X/A1X, can be allowed to be more flexible with param+"X_" provided param is defined earlier.
for key, value in valDict.items():
if (
key[0:4] == "T0X_"
or key[0:4] == "A1X_"
and not (key[4:8] in piece_index)
key[:4] == "T0X_"
or key[:4] == "A1X_"
and key[4:8] not in piece_index
):
# appends index to array
piece_index.append((key[4:8]))
# makes sure only one instance of each index is present returns order indeces
piece_index = np.unique(piece_index)
# initialise array that will be 5 x n in shape. Where n is the number of pieces required by the model
piecewise_parameter_information = []
# looping through each index in order they are given (0 -> n)
for index in piece_index:
# array to store specific piece i's information in the order [index,T0X,A1X,Group's lower edge, Group's upper edge,]
param_pieces = []
piece_number = f"{int(index):04d}"
param_pieces.append(piece_number)
param_pieces = [piece_number]
string = [
"T0X_" + index,
"A1X_" + index,
"XR1_" + index,
"XR2_" + index,
f"T0X_{index}",
f"A1X_{index}",
f"XR1_{index}",
f"XR2_{index}",
]

# if string[0] not in param_pieces:
for i in range(0, len(string)):
if string[i] in valDict:
param_pieces.append(valDict[string[i]])
elif string[i] not in valDict:
attr = string[i][0:2]
for item in string:
if item in valDict:
param_pieces.append(valDict[item])
else:
attr = item[:2]

if hasattr(self, attr):
param_pieces.append(getattr(self, attr))
else:
raise AttributeError(
"Malformed valDict being used, attempting to set an attribute that doesn't exist. Likely a corner case slipping through validate() in binary_piecewise."
)
# Raises error if range not defined as there is no Piece upper/lower bound in the model.

piecewise_parameter_information.append(param_pieces)

self.valDict = valDict
Expand All @@ -166,14 +158,13 @@ def setup_internal_structures(self, valDict=None):
)

# Uses the index for each toa array to create arrays where elements are the A1X/T0X to use with that toa
if len(self.piecewise_parameter_information) > 0:
if self._t is not None:
self.group_index_array = self.toa_belongs_in_group(self._t)
if len(self.piecewise_parameter_information) > 0 and self._t is not None:
self.group_index_array = self.toa_belongs_in_group(self._t)

(
self.T0X_per_toa,
self.A1X_per_toa,
) = self.piecewise_parameter_from_information_array(self._t)
(
self.T0X_per_toa,
self.A1X_per_toa,
) = self.piecewise_parameter_from_information_array(self._t)

def piecewise_parameter_from_information_array(self, t):
"""Creates a list of piecewise orbital parameters to use in calculations. It is the same dimensions as the TOAs loaded in. Each entry is the piecewise parameter value from the group it belongs to.
Expand All @@ -192,8 +183,8 @@ def piecewise_parameter_from_information_array(self, t):
if len(self.group_index_array) != len(t):
self.group_index_array = self.toa_belongs_in_group(t)
# searches the 5 x n array to find the index matching the toa_index
possible_groups = [item[0] for item in self.piecewise_parameter_information]
if len(self.group_index_array) > 1 and len(t) > 1:
possible_groups = [item[0] for item in self.piecewise_parameter_information]
for i in self.group_index_array:
if i != -1:
for k, j in enumerate(possible_groups):
Expand Down Expand Up @@ -247,10 +238,7 @@ def toa_belongs_in_group(self, toas):
for i in toas.value:
lower_bound = np.searchsorted(np.array(lower_edge), i) - 1
upper_bound = np.searchsorted(np.array(upper_edge), i)
if lower_bound == upper_bound:
index_no = lower_bound
else:
index_no = -1
index_no = lower_bound if lower_bound == upper_bound else -1
if index_no != -1:
group_no.append(self.piecewise_parameter_information[index_no][0])
else:
Expand All @@ -265,10 +253,10 @@ def get_group_boundaries(self):
list (length: number of pieces). Contains all pieces' lower edge.
list (length: number of pieces). Contains all pieces' upper edge.
"""
lower_group_edge = []
upper_group_edge = []
if hasattr(self, "piecewise_parameter_information"):
for i in range(0, len(self.piecewise_parameter_information)):
lower_group_edge = []
upper_group_edge = []
for i in range(len(self.piecewise_parameter_information)):
lower_group_edge.append(self.piecewise_parameter_information[i][3])
upper_group_edge.append(self.piecewise_parameter_information[i][4])
return [lower_group_edge, upper_group_edge]
Expand All @@ -281,11 +269,11 @@ def a1(self):
1
]

if hasattr(self, "A1X_per_toa"):
ret = self.A1X_per_toa + self.tt0 * self.A1DOT
else:
ret = self.A1 + self.tt0 * self.A1DOT
return ret
return (
self.A1X_per_toa + self.tt0 * self.A1DOT
if hasattr(self, "A1X_per_toa")
else self.A1 + self.tt0 * self.A1DOT
)

def get_tt0(self, barycentricTOA):
"""Finds (barycentricTOA - T0_x). Where T0_x is the piecewise T0 value, if it exists, correponding to the group the TOA belongs to. If T0_x does not exist, use the global T0 vlaue.
Expand All @@ -306,58 +294,55 @@ def get_tt0(self, barycentricTOA):
T0 = self.T0X_per_toa
else:
T0 = self.T0
if not hasattr(barycentricTOA, "unit") or barycentricTOA.unit == None:
if not hasattr(barycentricTOA, "unit") or barycentricTOA.unit is None:
barycentricTOA = barycentricTOA * u.day
tt0 = (barycentricTOA - T0).to("second")
return tt0
return (barycentricTOA - T0).to("second")

def d_delayL1_d_par(self, par):
if par not in self.binary_params:
raise ValueError(f"{par} is not in binary parameter list.")
par_obj = getattr(self, par)
index, par_temp = self.in_piece(par)
if par_temp is None:
if hasattr(self, "d_delayL1_d_" + par):
func = getattr(self, "d_delayL1_d_" + par)
return func() * index
else:
if par in self.orbits_cls.orbit_params:
return self.d_delayL1_d_E() * self.d_E_d_par(par)
else:
return np.zeros(len(self.t)) * u.second / par_obj.unit
if not hasattr(self, f"d_delayL1_d_{par}"):
return (
self.d_delayL1_d_E() * self.d_E_d_par(par)
if par in self.orbits_cls.orbit_params
else np.zeros(len(self.t)) * u.second / par_obj.unit
)
func = getattr(self, f"d_delayL1_d_{par}")
return func() * index
elif hasattr(self, f"d_delayL1_d_{par_temp}"):
func = getattr(self, f"d_delayL1_d_{par_temp}")
return func() * index
else:
if hasattr(self, "d_delayL1_d_" + par_temp):
func = getattr(self, "d_delayL1_d_" + par_temp)
return func() * index
if par in self.orbits_cls.orbit_params:
return self.d_delayL1_d_E() * self.d_E_d_par()
else:
if par in self.orbits_cls.orbit_params:
return self.d_delayL1_d_E() * self.d_E_d_par()
else:
return np.zeros(len(self.t)) * u.second / par_obj.unit
return np.zeros(len(self.t)) * u.second / par_obj.unit

def d_delayL2_d_par(self, par):
if par not in self.binary_params:
raise ValueError(f"{par} is not in binary parameter list.")
par_obj = getattr(self, par)
index, par_temp = self.in_piece(par)
if par_temp is None:
if hasattr(self, "d_delayL2_d_" + par):
func = getattr(self, "d_delayL2_d_" + par)
return func() * index
else:
if par in self.orbits_cls.orbit_params:
return self.d_delayL2_d_E() * self.d_E_d_par(par)
else:
return np.zeros(len(self.t)) * u.second / par_obj.unit
if not hasattr(self, f"d_delayL2_d_{par}"):
return (
self.d_delayL2_d_E() * self.d_E_d_par(par)
if par in self.orbits_cls.orbit_params
else np.zeros(len(self.t)) * u.second / par_obj.unit
)
func = getattr(self, f"d_delayL2_d_{par}")
return func() * index
elif hasattr(self, f"d_delayL2_d_{par_temp}"):
func = getattr(self, f"d_delayL2_d_{par_temp}")
return func() * index
else:
if hasattr(self, "d_delayL2_d_" + par_temp):
func = getattr(self, "d_delayL2_d_" + par_temp)
return func() * index
if par in self.orbits_cls.orbit_params:
return self.d_delayL2_d_E() * self.d_E_d_par()
else:
if par in self.orbits_cls.orbit_params:
return self.d_delayL2_d_E() * self.d_E_d_par()
else:
return np.zeros(len(self.t)) * u.second / par_obj.unit
return np.zeros(len(self.t)) * u.second / par_obj.unit

def in_piece(self, par):
"""Finds which TOAs reference which piecewise binary parameter group using the group_index_array property.
Expand Down Expand Up @@ -437,11 +422,11 @@ def prtl_der(self, y, x):
The derivatives pdy/pdx
"""
if y not in self.binary_params + self.inter_vars:
errorMesg = y + " is not in binary parameter and variables list."
errorMesg = f"{y} is not in binary parameter and variables list."
raise ValueError(errorMesg)

if x not in self.inter_vars + self.binary_params:
errorMesg = x + " is not in binary parameters and variables list."
errorMesg = f"{x} is not in binary parameters and variables list."
raise ValueError(errorMesg)
# derivative to itself
if x == y:
Expand All @@ -454,23 +439,21 @@ def prtl_der(self, y, x):
# If attr is a PINT Parameter class type
if hasattr(attr, "units"):
U[i] = attr.units
# If attr is a Quantity type
elif hasattr(attr, "unit"):
U[i] = attr.unit
# If attr is a method
elif hasattr(attr, "__call__"):
U[i] = attr().unit
else:
raise TypeError(type(attr) + "can not get unit")
raise TypeError(f"{type(attr)}can not get unit")
yU = U[0]
xU = U[1]
# Call derivtive functions
derU = yU / xU
if hasattr(self, "d_" + y + "_d_" + x):
dername = "d_" + y + "_d_" + x
if hasattr(self, f"d_{y}_d_{x}"):
dername = f"d_{y}_d_{x}"
result = getattr(self, dername)()
elif hasattr(self, "d_" + y + "_d_par"):
dername = "d_" + y + "_d_par"
elif hasattr(self, f"d_{y}_d_par"):
dername = f"d_{y}_d_par"
result = getattr(self, dername)(x)
else:
result = np.longdouble(np.zeros(len(self.tt0)))
Expand All @@ -490,7 +473,7 @@ def d_M_d_par(self, par):
Derivitve of M respect to par
"""
if par not in self.binary_params:
errorMesg = par + " is not in binary parameter list."
errorMesg = f"{par} is not in binary parameter list."
raise ValueError(errorMesg)
par_obj = getattr(self, par)
result = self.orbits_cls.d_orbits_d_par(par)
Expand Down

0 comments on commit dd3c152

Please sign in to comment.