Skip to content

Commit

Permalink
Merge pull request #13 from B612-Asteroid-Institute/impactor-debugging
Browse files Browse the repository at this point in the history
Impactor debugging
  • Loading branch information
akoumjian authored Jan 17, 2025
2 parents ccaf6d4 + a09ded5 commit ebbbbf8
Show file tree
Hide file tree
Showing 13 changed files with 330 additions and 130 deletions.
6 changes: 3 additions & 3 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ authors = [
{name = "Joachim Moeyens", email = "[email protected]"},
]
dependencies = [
"matplotlib",
"matplotlib>=3.10.0",
"adam_assist>=0.2.0",
"ray[default]==2.39.0",
"sorcha @ git+https://github.com/B612-Asteroid-Institute/sorcha.git@cd5be9a06c6d24e1277cf2345bda9984f7097ede",
Expand Down
2 changes: 2 additions & 0 deletions src/adam_impact_study/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Re-export the public entry point of the analysis subpackage.
# ruff: noqa: F401
from .main import summarize_impact_study_results
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import quivr as qv
from adam_core.time import Timestamp

from adam_impact_study.analysis.plots import make_analysis_plots
from adam_impact_study.types import (
ImpactorOrbits,
ImpactorResultSummary,
Expand Down Expand Up @@ -104,7 +105,6 @@ class DiscoveryDates(qv.Table):


def compute_discovery_dates(
impactor_orbits: ImpactorOrbits,
results: WindowResult,
) -> DiscoveryDates:
"""
Expand Down Expand Up @@ -484,7 +484,7 @@ def summarize_impact_study_object_results(
maximum_impact_probability=[0],
)

discovery_dates = compute_discovery_dates(impactor_orbits, impact_results)
discovery_dates = compute_discovery_dates(impact_results)
warning_times = compute_warning_time(impactor_orbits, impact_results)
realization_times = compute_realization_time(
impactor_orbits, impact_results, discovery_dates
Expand Down Expand Up @@ -522,10 +522,14 @@ def summarize_impact_study_object_results(
)


def summarize_impact_study_results(run_dir: str) -> ImpactorResultSummary:
def summarize_impact_study_results(
run_dir: str, out_dir: str, plot: bool = True
) -> ImpactorResultSummary:
"""
Summarize the impact study results.
"""
assert run_dir != out_dir, "run_dir and out_dir must be different"

orbit_ids = [os.path.basename(dir) for dir in glob.glob(f"{run_dir}/*")]
results = ImpactorResultSummary.empty()
for orbit_id in orbit_ids:
Expand All @@ -534,95 +538,9 @@ def summarize_impact_study_results(run_dir: str) -> ImpactorResultSummary:
[results, summarize_impact_study_object_results(run_dir, orbit_id)]
)

return results


def plot_discovery_by_h_mag(impactor_results: ImpactorResultSummary) -> None:
    """
    Show overlapping histograms of H magnitude (H_r) for discovered vs.
    not-discovered objects, then display the figure.

    Parameters
    ----------
    impactor_results : ImpactorResultSummary
        Summary results for each simulated impactor.
    """

    fig, ax = plt.subplots()

    # Shared 0.5-mag bins spanning the full H_r range of all objects.
    h_r = impactor_results.orbit.H_r
    bin_edges = np.arange(
        pc.floor(pc.min(h_r)).as_py(), pc.ceil(pc.max(h_r)).as_py(), 0.5
    )

    mask = impactor_results.discovered()
    subsets = (
        (impactor_results.apply_mask(mask), "Discovered", "blue"),
        (impactor_results.apply_mask(pc.invert(mask)), "Not Discovered", "red"),
    )
    for subset, label, color in subsets:
        ax.hist(
            subset.orbit.H_r,
            bins=bin_edges,
            alpha=0.5,
            label=label,
            color=color,
        )

    ax.set_xlabel("H-mag")
    ax.set_ylabel("Number of Objects")
    ax.legend()
    plt.show()


def summarize_discovery_rates_by_diameter(
    impactor_results: ImpactorResultSummary,
) -> pa.Table:
    """
    Summarize discovery counts grouped by orbit diameter and asteroid class.

    Parameters
    ----------
    impactor_results : ImpactorResultSummary
        Summary results for each simulated impactor.

    Returns
    -------
    pa.Table
        One row per (orbit.diameter, orbit.ast_class) pair with the sum of
        discovered objects and the total count of objects in that group.
    """

    discovered_mask = impactor_results.discovered()
    # Flatten nested columns so we can group on plain columns; the covariance
    # values column is dropped because it cannot participate in a group-by.
    table = impactor_results.flattened_table().append_column(
        "discovered", discovered_mask
    )
    table = table.drop_columns(["orbit.coordinates.covariance.values"])
    table_diameter_grouped = table.group_by(
        ["orbit.diameter", "orbit.ast_class"]
    ).aggregate([("discovered", "sum"), ("orbit.ast_class", "count")])
    return table_diameter_grouped


def plot_q_vs_i(impactor_results: ImpactorResultSummary) -> None:
    """
    Scatter plot of perihelion distance (q) vs. inclination (i), coloring
    discovered and not-discovered objects separately, then display the figure.

    Parameters
    ----------
    impactor_results : ImpactorResultSummary
        Summary results for each simulated impactor.
    """

    fig, ax = plt.subplots()

    discovered_mask = impactor_results.discovered()
    discovered = impactor_results.apply_mask(discovered_mask)
    not_discovered = impactor_results.apply_mask(pc.invert(discovered_mask))

    # Convert to Keplerian elements once per subset instead of once per axis
    # (the original recomputed the conversion four times).
    discovered_kep = discovered.orbit.coordinates.to_keplerian()
    not_discovered_kep = not_discovered.orbit.coordinates.to_keplerian()

    ax.scatter(
        discovered_kep.q,
        discovered_kep.i,
        alpha=0.2,
        label="Discovered",
        color="blue",
    )
    ax.scatter(
        not_discovered_kep.q,
        not_discovered_kep.i,
        alpha=0.2,
        label="Not Discovered",
        color="red",
    )
    ax.set_xlabel("q")
    ax.set_ylabel("i")
    ax.legend()
    plt.show()


if __name__ == "__main__":
results = ImpactorResultSummary.from_parquet("demo/data/summarized_results.parquet")
os.makedirs(out_dir)
results.to_parquet(os.path.join(out_dir, "impactor_results_summary.parquet"))
logger.info(f"Saved impact study results to {out_dir}")

table_diameter_grouped = summarize_discovery_rates_by_diameter(results)
print(table_diameter_grouped)
if plot:
make_analysis_plots(results, out_dir)
174 changes: 174 additions & 0 deletions src/adam_impact_study/analysis/plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import logging
import os
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc

from adam_impact_study.types import ImpactorResultSummary

logger = logging.getLogger(__name__)


def plot_warning_time_histogram(
    summary: ImpactorResultSummary,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot per-diameter PDFs of warning time (converted from days to years)
    as step histograms; objects with NaN warning time are binned at zero.
    """
    fig, ax = plt.subplots(1, 1, dpi=200)

    # Upper x-axis bound in years (warning_time is stored in days).
    warning_time_max = pc.ceil(pc.max(summary.warning_time)).as_py() / 365.25
    bins = np.arange(0, warning_time_max, 1)

    diameters = summary.orbit.diameter.unique().sort().to_pylist()
    palette = plt.cm.coolwarm(np.linspace(0, 1, len(diameters)))
    for diameter, color in zip(diameters, palette):
        subset = summary.select("orbit.diameter", diameter)
        wt = subset.warning_time.to_numpy(zero_copy_only=False)
        # NaN (never discovered) counts as zero warning; days -> years.
        ax.hist(
            np.where(np.isnan(wt), 0, wt) / 365.25,
            histtype="step",
            label=f"{diameter:.3f} km",
            color=color,
            bins=bins,
            density=True,
        )

    ax.set_xlim(0, warning_time_max)
    ax.set_xticks(np.arange(0, warning_time_max + 20, 20))
    ax.legend(frameon=False, bbox_to_anchor=(1.01, 0.75))
    ax.set_xlabel("Warning Time for Discoveries [years]")
    ax.set_ylabel("PDF")

    return fig, ax


def plot_realization_time_histogram(
    summary: ImpactorResultSummary,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot per-diameter PDFs of realization time (days) as step histograms,
    capping the x-axis at 100 days and annotating the count beyond it.
    """
    fig, ax = plt.subplots(1, 1, dpi=200)

    # Cap the axis at 100 days; the tail is summarized by the text label.
    realization_time_max = pc.ceil(pc.max(summary.realization_time)).as_py()
    if realization_time_max > 100:
        realization_time_max = 100
    bins = np.linspace(0, realization_time_max, 100)

    diameters = summary.orbit.diameter.unique().sort().to_pylist()
    palette = plt.cm.coolwarm(np.linspace(0, 1, len(diameters)))
    for diameter, color in zip(diameters, palette):
        subset = summary.select("orbit.diameter", diameter)
        rt = subset.realization_time.to_numpy(zero_copy_only=False)
        # NaN entries (never realized) are excluded from the histogram.
        ax.hist(
            rt[~np.isnan(rt)],
            histtype="step",
            label=f"{diameter:.3f} km",
            color=color,
            bins=bins,
            density=True,
        )

    # How many objects fall past the 100-day cutoff shown on the axis.
    all_rt = summary.realization_time.to_numpy(zero_copy_only=False)
    n_objects_beyond_100_days = np.sum(all_rt > 100)

    ax.text(
        99,
        0.01,
        rf"$N_{{objects}}$(>100 d)={n_objects_beyond_100_days}",
        ha="right",
        rotation=90,
    )

    ax.set_xlim(0, realization_time_max)
    ax.set_xlabel("Realization Time for Discoveries [days]")
    ax.set_ylabel("PDF")
    ax.legend(frameon=False, bbox_to_anchor=(1.01, 0.75))
    return fig, ax


def plot_discoveries_by_diameter(
    summary: ImpactorResultSummary,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Bar chart of the percentage of objects discovered in each diameter bin,
    with the percentage printed above each bar.
    """
    # Per-diameter discovered/total counts.
    discovery_summary = summary.summarize_discoveries()

    fig, ax = plt.subplots(1, 1, dpi=200)

    diameters = discovery_summary.diameter.unique().sort().to_pylist()
    palette = plt.cm.coolwarm(np.linspace(0, 1, len(diameters)))
    for i, (diameter, color) in enumerate(zip(diameters, palette)):
        at_diameter = discovery_summary.select("diameter", diameter)
        # discovered / total * 100, in float64 to avoid integer division.
        n_discovered = pc.cast(pc.sum(at_diameter.discovered), pa.float64())
        n_total = pc.cast(pc.sum(at_diameter.total), pa.float64())
        percent_discovered = pc.multiply(
            pc.divide(n_discovered, n_total), 100
        ).as_py()

        ax.bar(i, height=percent_discovered, color=color)
        ax.text(
            i,
            percent_discovered + 1,
            f"{percent_discovered:.2f}%",
            ha="center",
            fontsize=10,
        )

    ax.set_xticks(np.arange(0, len(diameters), 1))
    ax.set_xticklabels([f"{diameter:.3f}" for diameter in diameters])
    ax.set_ylim(0, 100)
    ax.set_xlabel("Diameter [km]")
    ax.set_ylabel("Discovered [%]")

    return fig, ax


def make_analysis_plots(
    summary: ImpactorResultSummary,
    out_dir: str,
) -> None:
    """
    Generate the standard analysis figures and save each as a JPEG in
    ``out_dir``, logging after each file is written.
    """
    # (plot function, output filename, log message) — saved in this order.
    plot_specs = (
        (
            plot_warning_time_histogram,
            "warning_time_histogram.jpg",
            "Generated warning time histogram",
        ),
        (
            plot_realization_time_histogram,
            "realization_time_histogram.jpg",
            "Generated realization time histogram",
        ),
        (
            plot_discoveries_by_diameter,
            "discoveries_by_diameter.jpg",
            "Generated discoveries by diameter plot",
        ),
    )

    for plot_fn, filename, message in plot_specs:
        fig, ax = plot_fn(summary)
        fig.savefig(
            os.path.join(out_dir, filename),
            bbox_inches="tight",
            dpi=200,
        )
        logger.info(message)
        plt.close(fig)

    return
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -230,11 +230,12 @@ def test_compute_warning_time():
warning_time_obj3 = warning_times.select("orbit_id", "test3")
assert warning_time_obj3.warning_time[0].as_py() == 50.0 # 60300 - 60250


# Make sure warning time still works if inputs are not sorted
scrambled_results = results.take([1, 0, 2, 3])
scrambled_impactor_orbits = impactor_orbits.take([1, 0, 2])
warning_times = compute_warning_time(scrambled_impactor_orbits, scrambled_results, threshold=0.25)
warning_times = compute_warning_time(
scrambled_impactor_orbits, scrambled_results, threshold=0.25
)
assert len(warning_times) == 3
assert warning_times.orbit_id.to_pylist() == ["test1", "test2", "test3"]
assert warning_times.warning_time.to_pylist() == [40.0, None, 50.0]
Expand Down Expand Up @@ -387,7 +388,9 @@ def test_compute_realization_time():
# Now test with scrambled inputs
scrambled_results = results.take([1, 0, 2, 3])
scrambled_impactor_orbits = impactor_orbits.take([1, 0, 2])
realization_times = compute_realization_time(scrambled_impactor_orbits, scrambled_results, discovery_dates, threshold=1e-9)
realization_times = compute_realization_time(
scrambled_impactor_orbits, scrambled_results, discovery_dates, threshold=1e-9
)
assert len(realization_times) == 3
assert realization_times.orbit_id.to_pylist() == ["test1", "test2", "test3"]
assert realization_times.realization_time.to_pylist() == [25.0, None, 0.0]
2 changes: 1 addition & 1 deletion src/adam_impact_study/cli/impact.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def run_impact_study(

# Run impact study
logger.info("Starting impact study...")
impact_study_results = run_impact_study_all(
impact_study_results, results_timings = run_impact_study_all(
filtered_orbits,
pointing_file,
run_dir,
Expand Down
Loading

0 comments on commit ebbbbf8

Please sign in to comment.