Skip to content

Commit

Permalink
update paper figure
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuwq0 committed May 11, 2024
1 parent bdc9edf commit 753c1d1
Showing 1 changed file with 32 additions and 20 deletions.
52 changes: 32 additions & 20 deletions tests/phasenet_plus/paper_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,18 @@ def read_labels(file):
if __name__ == "__main__":

# %% Load origin PhaseNet model
phasenet_csv = "phasenet_picks.csv"
# if not os.path.exists(phasenet_csv):
if True:
phasenet_picks = pd.read_csv(
"https://huggingface.co/datasets/AI4EPS/quakeflow_sc/resolve/main/models/phasenet_picks.csv"
)
reload = False
region = "NC"
phasenet_csv = f"phasenet_picks_{region}.csv"
if reload or not os.path.exists(phasenet_csv):
if region == "NC":
phasenet_picks = pd.read_csv(
"https://huggingface.co/datasets/AI4EPS/quakeflow_nc/resolve/main/models/phasenet_picks.csv"
)
elif region == "SC":
phasenet_picks = pd.read_csv(
"https://huggingface.co/datasets/AI4EPS/quakeflow_sc/resolve/main/models/phasenet_picks.csv"
)
# phasenet_picks = pd.read_csv("phasenet_origin/picks_sc.csv")
# phasenet_picks[["event_index", "station_id"]] = phasenet_picks["file_name"].str.split("/", expand=True)
# phasenet_picks.drop(columns=["file_name", "begin_time"], inplace=True)
Expand All @@ -168,12 +174,16 @@ def read_labels(file):
phasenet_picks["phase_time"] = pd.to_datetime(phasenet_picks["phase_time"])

# %% Load PhaseNet+ model
phasenet_plus_csv = "phasenet_plus_picks.csv"
# if not os.path.exists(phasenet_plus_csv):
if True:
phasenet_plus_picks = pd.read_csv(
"https://huggingface.co/datasets/AI4EPS/quakeflow_sc/resolve/main/models/phasenet_plus_picks.csv"
)
phasenet_plus_csv = f"phasenet_plus_picks_{region}.csv"
if reload or not os.path.exists(phasenet_plus_csv):
if region == "NC":
phasenet_plus_picks = pd.read_csv(
"https://huggingface.co/datasets/AI4EPS/quakeflow_nc/resolve/main/models/phasenet_plus_picks.csv"
)
elif region == "SC":
phasenet_plus_picks = pd.read_csv(
"https://huggingface.co/datasets/AI4EPS/quakeflow_sc/resolve/main/models/phasenet_plus_picks.csv"
)
# pick_path = "../../results_ps_test_sc/picks_phasenet_plus"
# event_ids = glob(f"{pick_path}/*")
# picks_list = []
Expand Down Expand Up @@ -204,10 +214,12 @@ def read_labels(file):
phasenet_plus_picks["phase_time"] = pd.to_datetime(phasenet_plus_picks["phase_time"])

# %% Load labels
label_csv = "labels.csv"
label_csv = f"labels_{region}.csv"
if not os.path.exists(label_csv):
# labels = read_labels(file="/nfs/quakeflow_dataset/NC/quakeflow_nc/waveform_test.h5")
labels = read_labels(file="/nfs/quakeflow_dataset/SC/quakeflow_sc/waveform_test.h5")
if region == "NC":
labels = read_labels(file="/nfs/quakeflow_dataset/NC/quakeflow_nc/waveform_test.h5")
elif region == "SC":
labels = read_labels(file="/nfs/quakeflow_dataset/SC/quakeflow_sc/waveform_test.h5")
labels.to_csv(label_csv, index=False)
else:
labels = pd.read_csv(label_csv, parse_dates=["phase_time"])
Expand Down Expand Up @@ -244,16 +256,16 @@ def read_labels(file):

# %%
polarity_threshold = 0.5
num_u = len(phasenet_plus_picks[phasenet_plus_picks["phase_polarity"] > polarity_threshold])
num_n = len(phasenet_plus_picks[np.abs(phasenet_plus_picks["phase_polarity"].values) <= polarity_threshold])
num_d = len(phasenet_plus_picks[phasenet_plus_picks["phase_polarity"] < -polarity_threshold])
plt.figure()
plt.bar(["U", "N", "D"], [num_u, num_n, num_d], label="PhaseNet+")

num_u = len(labels[labels["phase_polarity"] == "U"])
num_n = len(labels[labels["phase_polarity"] == "N"])
num_d = len(labels[labels["phase_polarity"] == "D"])
plt.bar(["U", "N", "D"], [num_u, num_n, num_d], alpha=0.5, label="Labels")

num_u = len(phasenet_plus_picks[phasenet_plus_picks["phase_polarity"] > polarity_threshold])
num_n = len(phasenet_plus_picks[np.abs(phasenet_plus_picks["phase_polarity"].values) <= polarity_threshold])
num_d = len(phasenet_plus_picks[phasenet_plus_picks["phase_polarity"] < -polarity_threshold])
plt.bar(["U", "N", "D"], [num_u, num_n, num_d], alpha=0.5, label="PhaseNet+")
plt.legend()

# %%
Expand Down

0 comments on commit 753c1d1

Please sign in to comment.