Skip to content

Commit

Permalink
Merge pull request #32 from rahil-makadia/dev
Browse files Browse the repository at this point in the history
minor bugfix for analytic simultaneous nongrav fits
  • Loading branch information
rahil-makadia authored Dec 5, 2023
2 parents 0d25a6c + 8d08094 commit f8f8c82
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ on:
paths:
- "grss/version.txt"

permissions:
id-token: write
contents: read

jobs:
build-and-upload:
runs-on: ubuntu-latest
Expand Down
15 changes: 7 additions & 8 deletions grss/fit/fit_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ def __init__(self, x_init, cov_init=None, obs_array_optical=None, observer_codes
self.covariance = None
self.obs_array = None
self.observer_codes = None
self.observer_info = None
self.check_initial_solution(x_init, cov_init)
self.check_input_observation_arrays(obs_array_optical, observer_codes_optical,
obs_array_radar, observer_codes_radar)
Expand Down Expand Up @@ -781,6 +782,7 @@ def assemble_observation_arrays(self):
self.obs_array_radar[:, 0])
self.observer_codes = tuple(np.array(self.observer_codes_radar,
dtype=tuple)[sort_idx])
self.observer_info = get_observer_info(self.observer_codes)
# number of observations is the number of non-nan values
# in the second and third columns of the observation array
self.n_obs = np.count_nonzero(~np.isnan(self.obs_array[:, 1:3]))
Expand Down Expand Up @@ -953,7 +955,7 @@ def get_prop_sims(self, name):
t_eval_utc = True
eval_apparent_state = True
converged_light_time = True
observer_info = np.array(get_observer_info(self.observer_codes), dtype=tuple)
observer_info = np.array(self.observer_info, dtype=tuple)
observer_info_past = tuple(observer_info[self.past_obs_idx])
observer_info_future = tuple(observer_info[self.future_obs_idx])
prop_sim_past = None
Expand Down Expand Up @@ -1300,9 +1302,8 @@ def get_computed_obs(self, prop_sim_past, prop_sim_future, integ_body_idx):
radar_obs = radar_obs[:,integ_body_idx]
measured_obs = self.obs_array[:, 1:3]
computed_obs = np.nan*np.ones_like(measured_obs)
observer_info = get_observer_info(self.observer_codes)
for i in range(len(self.obs_array)):
obs_info_len = len(observer_info[i])
obs_info_len = len(self.observer_info[i])
if obs_info_len in {4, 7}:
computed_obs[i, :] = optical_obs[i]
elif obs_info_len == 9: # delay measurement
Expand All @@ -1319,7 +1320,7 @@ def _get_analytic_stm(self, t_eval, prop_sim):
if self.fit_cometary:
stm[:, 3:6] /= 180.0/np.pi # covert partial w.r.t. rad -> partial w.r.t. deg
if len(stm_state_full) > 42:
param_block = stm_state_full[42:].reshape((6, -1))
param_block = stm_state_full[42:].reshape((6, -1), order='F')
stm = np.hstack((stm, param_block))
num_params = param_block.shape[1]
if num_params > 0:
Expand All @@ -1345,12 +1346,11 @@ def get_analytic_partials(self, prop_sim_past, prop_sim_future):
NotImplementedError
Because analytic partials are not yet implemented.
"""
observer_info = get_observer_info(self.observer_codes)
partials = np.zeros((self.n_obs, self.n_fit))
len_past_idx = len(self.past_obs_idx) if self.past_obs_exist else 0
partials_idx = 0
for i in range(self.obs_array.shape[0]):
obs_info_len = len(observer_info[i])
obs_info_len = len(self.observer_info[i])
if obs_info_len in {4, 7}:
is_optical = True
size = 2
Expand Down Expand Up @@ -1498,12 +1498,11 @@ def apply_outlier_rejection(self, partials, weights, residuals):
chi_reject = 3.0
chi_recover = 2.8
full_cov = np.linalg.inv(partials.T @ self.obs_weight @ partials)
observer_info = get_observer_info(self.observer_codes)
j = 0
residual_chi_squared = np.zeros(len(self.obs_array))
rejected_indices = []
for i in range(len(self.obs_array)):
obs_info_len = len(observer_info[i])
obs_info_len = len(self.observer_info[i])
if obs_info_len in {4, 7}:
size = 2
elif obs_info_len in {9, 10}:
Expand Down
2 changes: 1 addition & 1 deletion grss/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.0.0
2.0.2

0 comments on commit f8f8c82

Please sign in to comment.