-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #17 from esa/ml-dsgp4
ml-dsgp4: python file, tutorials, and docs
- Loading branch information
Showing
16 changed files
with
436 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,51 +1,64 @@ | ||
.. _api: | ||
|
||
API | ||
==== | ||
####### | ||
|
||
$\partial$SGP4 API | ||
|
||
.. currentmodule:: dsgp4 | ||
|
||
.. autosummary:: | ||
:toctree: _autosummary | ||
:toctree: _autosummary/ | ||
:recursive: | ||
|
||
dsgp4 | ||
dsgp4.plot.plot_orbit | ||
dsgp4.plot.plot_tles | ||
dsgp4.tle.compute_checksum | ||
dsgp4.tle.read_satellite_catalog_number | ||
dsgp4.tle.load_from_lines | ||
dsgp4.tle.load_from_data | ||
dsgp4.tle.load | ||
dsgp4.tle.TLE | ||
dsgp4.tle.TLE.copy | ||
dsgp4.tle.TLE.perigee_alt | ||
dsgp4.tle.TLE.apogee_alt | ||
dsgp4.tle.TLE.set_time | ||
dsgp4.tle.TLE.update | ||
dsgp4.util.get_gravity_constants | ||
dsgp4.util.propagate_batch | ||
dsgp4.util.propagate | ||
dsgp4.util.initialize_tle | ||
dsgp4.util.from_year_day_to_date | ||
dsgp4.util.gstime | ||
dsgp4.util.clone_w_grad | ||
dsgp4.util.jday | ||
dsgp4.util.invjday | ||
dsgp4.util.days2mdhms | ||
dsgp4.util.from_string_to_datetime | ||
dsgp4.util.from_mjd_to_epoch_days_after_1_jan | ||
dsgp4.util.from_mjd_to_datetime | ||
dsgp4.util.from_jd_to_datetime | ||
dsgp4.util.get_non_empty_lines | ||
dsgp4.util.from_datetime_to_fractional_day | ||
dsgp4.util.from_datetime_to_mjd | ||
dsgp4.util.from_datetime_to_jd | ||
dsgp4.util.from_cartesian_to_tle_elements | ||
dsgp4.util.from_cartesian_to_keplerian | ||
dsgp4.util.from_cartesian_to_keplerian_torch | ||
dsgp4.sgp4 | ||
dsgp4.sgp4_batched | ||
dsgp4.sgp4init.sgp4init | ||
dsgp4.sgp4init_batch.sgp4init_batch | ||
dsgp4.sgp4init_batch.initl_batch | ||
plot.plot_orbit | ||
plot.plot_tles | ||
tle.compute_checksum | ||
tle.read_satellite_catalog_number | ||
tle.load_from_lines | ||
tle.load_from_data | ||
tle.load | ||
tle.TLE | ||
tle.TLE.copy | ||
tle.TLE.perigee_alt | ||
tle.TLE.apogee_alt | ||
tle.TLE.set_time | ||
tle.TLE.update | ||
util.get_gravity_constants | ||
util.propagate_batch | ||
util.propagate | ||
util.initialize_tle | ||
util.from_year_day_to_date | ||
util.gstime | ||
util.clone_w_grad | ||
util.jday | ||
util.invjday | ||
util.days2mdhms | ||
util.from_string_to_datetime | ||
util.from_mjd_to_epoch_days_after_1_jan | ||
util.from_mjd_to_datetime | ||
util.from_jd_to_datetime | ||
util.get_non_empty_lines | ||
util.from_datetime_to_fractional_day | ||
util.from_datetime_to_mjd | ||
util.from_datetime_to_jd | ||
util.from_cartesian_to_tle_elements | ||
util.from_cartesian_to_keplerian | ||
util.from_cartesian_to_keplerian_torch | ||
sgp4 | ||
sgp4_batched | ||
sgp4init.sgp4init | ||
sgp4init_batch.sgp4init_batch | ||
sgp4init_batch.initl_batch | ||
initl | ||
newton_method | ||
sgp4init | ||
sgp4init_batch | ||
|
||
.. currentmodule:: dsgp4 | ||
|
||
.. toctree:: | ||
:maxdepth: 2 | ||
:caption: dsgp4 ML-dSGP4 Module | ||
|
||
dsgp4.mldsgp4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,9 +7,9 @@ | |
"source": [ | ||
"# Credits\n", | ||
"\n", | ||
"$\\partial\\textrm{SGP4}$ was developed during a project sponsored by the University of Oxford, while Giacomo Acciarini was at the [OX4AILab](https://oxai4science.github.io/) collaborating with Dr. Atılım Güneş Baydin.\n", | ||
"$\\partial\\textrm{SGP4}$ was developed during a project sponsored by the University of Oxford, while Giacomo Acciarini was at the [Oxford AI4Science Lab](https://oxai4science.github.io/) collaborating with Dr. Atılım Güneş Baydin.\n", | ||
"\n", | ||
"The main developers are: Giacomo Acciarini ( [email protected] ), Atılım Güneş Baydin ( [email protected] )." | ||
"The main developers is: Giacomo Acciarini ( [email protected] )." | ||
] | ||
} | ||
], | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
.. _mldsgp4: | ||
|
||
mldsgp4 model | ||
############## | ||
|
||
This module defines the ``mldsgp4`` class within the ``dsgp4`` library. | ||
|
||
.. currentmodule:: dsgp4 | ||
|
||
.. autoclass:: dsgp4.mldsgp4.mldsgp4 | ||
:members: __init__, forward, load_model | ||
:undoc-members: | ||
:exclude-members: __del__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
$\partial\textrm{SGP4}$ Documentation | ||
================================ | ||
|
||
**dsgp4** is a differentiable SGP4 program written leveraging the [PyTorch](https://pytorch.org/) machine learning framework: this enables features like automatic differentiation and batch propagation (across different TLEs) that were not previously available in the original implementation. | ||
**dsgp4** is a differentiable SGP4 program written leveraging the [PyTorch](https://pytorch.org/) machine learning framework: this enables features like automatic differentiation and batch propagation (across different TLEs) that were not previously available in the original implementation. Furthermore, it also offers a hybrid propagation scheme called ML-dSGP4 where dSGP4 and ML models can be combined to enhance SGP4 accuracy when higher-precision simulated (e.g. from a numerical integrator) or observed (e.g. from ephemerides) data is available. | ||
|
||
For more details on the model and results, check out our publication: [Acciarini, Giacomo, Atılım Güneş Baydin, and Dario Izzo. "*Closing the Gap Between SGP4 and High-Precision Propagation via Differentiable Programming*" (2024) Vol. 226(1), pages: 694-701](https://doi.org/10.1016/j.actaastro.2024.10.063) | ||
|
||
|
||
The authors are [Giacomo Acciarini](https://www.esa.int/gsp/ACT/team/giacomo_acciarini/), [Atılım Güneş Baydin](https://gbaydin.github.io/), [Dario Izzo](https://www.esa.int/gsp/ACT/team/dario_izzo/). The main developer is Giacomo Acciarini ([email protected]). | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
from .util import initialize_tle, propagate, propagate_batch | ||
from torch.nn.parameter import Parameter | ||
|
||
class mldsgp4(nn.Module): | ||
def __init__(self, | ||
normalization_R=6958.137, | ||
normalization_V=7.947155867983262, | ||
hidden_size=100, | ||
input_correction=1e-2, | ||
output_correction=0.8): | ||
""" | ||
This class implements the ML-dSGP4 model, where dSGP4 inputs and outputs are corrected via neural networks, | ||
better match simulated or observed higher-precision data. | ||
Parameters: | ||
---------------- | ||
normalization_R (``float``): normalization constant for x,y,z coordinates. | ||
normalization_V (``float``): normalization constant for vx,vy,vz coordinates. | ||
hidden_size (``int``): number of neurons in the hidden layers. | ||
input_correction (``float``): correction factor for the input layer. | ||
output_correction (``float``): correction factor for the output layer. | ||
""" | ||
super().__init__() | ||
self.fc1=nn.Linear(6, hidden_size) | ||
self.fc2=nn.Linear(hidden_size,hidden_size) | ||
self.fc3=nn.Linear(hidden_size, 6) | ||
self.fc4=nn.Linear(6,hidden_size) | ||
self.fc5=nn.Linear(hidden_size, hidden_size) | ||
self.fc6=nn.Linear(hidden_size, 6) | ||
|
||
self.tanh = nn.Tanh() | ||
self.leaky_relu = nn.LeakyReLU(negative_slope=0.01) | ||
self.normalization_R=normalization_R | ||
self.normalization_V=normalization_V | ||
self.input_correction = Parameter(input_correction*torch.ones((6,))) | ||
self.output_correction = Parameter(output_correction*torch.ones((6,))) | ||
|
||
def forward(self, tles, tsinces): | ||
""" | ||
This method computes the forward pass of the ML-dSGP4 model. | ||
It can take either a single or a list of `dsgp4.tle.TLE` objects, | ||
and a torch.tensor of times since the TLE epoch in minutes. | ||
It then returns the propagated state in the TEME coordinate system. The output | ||
is normalized, to unnormalize and obtain km and km/s, you can use self.normalization_R constant for the position | ||
and self.normalization_V constant for the velocity. | ||
Parameters: | ||
---------------- | ||
tles (``dsgp4.tle.TLE`` or ``list``): a TLE object or a list of TLE objects. | ||
tsinces (``torch.tensor``): a torch.tensor of times since the TLE epoch in minutes. | ||
Returns: | ||
---------------- | ||
(``torch.tensor``): a tensor of len(tsince)x6 representing the corrected satellite position and velocity in normalized units (to unnormalize to km and km/s, use `self.normalization_R` for position, and `self.normalization_V` for velocity). | ||
""" | ||
is_batch=hasattr(tles, '__len__') | ||
if is_batch: | ||
#this is the batch case, so we proceed and initialize the batch: | ||
_,tles=initialize_tle(tles,with_grad=True) | ||
x0 = torch.stack((tles._ecco, tles._argpo, tles._inclo, tles._mo, tles._no_kozai, tles._nodeo), dim=1) | ||
else: | ||
#this handles the case in which a singlee TLE is passed | ||
initialize_tle(tles,with_grad=True) | ||
x0 = torch.stack((tles._ecco, tles._argpo, tles._inclo, tles._mo, tles._no_kozai, tles._nodeo), dim=0).reshape(-1,6) | ||
x=self.leaky_relu(self.fc1(x0)) | ||
x=self.leaky_relu(self.fc2(x)) | ||
x=x0*(1+self.input_correction*self.tanh(self.fc3(x))) | ||
#now we need to substitute them back into the tles: | ||
tles._ecco=x[:,0] | ||
tles._argpo=x[:,1] | ||
tles._inclo=x[:,2] | ||
tles._mo=x[:,3] | ||
tles._no_kozai=x[:,4] | ||
tles._nodeo=x[:,5] | ||
if is_batch: | ||
#we propagate the batch: | ||
states_teme=propagate_batch(tles,tsinces) | ||
else: | ||
states_teme=propagate(tles,tsinces) | ||
states_teme=states_teme.reshape(-1,6) | ||
#we now extract the output parameters to correct: | ||
x_out=torch.cat((states_teme[:,:3]/self.normalization_R, states_teme[:,3:]/self.normalization_V),dim=1) | ||
|
||
x=self.leaky_relu(self.fc4(x_out)) | ||
x=self.leaky_relu(self.fc5(x)) | ||
x=x_out*(1+self.output_correction*self.tanh(self.fc6(x))) | ||
return x | ||
|
||
def load_model(self, path, device='cpu'): | ||
""" | ||
This method loads a model from a file. | ||
Parameters: | ||
---------------- | ||
path (``str``): path to the file where the model is stored. | ||
device (``str``): device where the model will be loaded. Default is 'cpu'. | ||
""" | ||
self.load_state_dict(torch.load(path,map_location=torch.device(device))) | ||
self.eval() |
Oops, something went wrong.