Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuwq0 committed May 12, 2024
1 parent 753c1d1 commit 99c513c
Show file tree
Hide file tree
Showing 3 changed files with 306 additions and 6 deletions.
5 changes: 3 additions & 2 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,8 @@ def main(args):
# print("Loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint["epoch"]))
else:
if args.model == "phasenet":
raise ("No pretrained model for phasenet, please use phasenet_plus instead")
if args.location is None:
model_url = "https://github.com/AI4EPS/models/releases/download/PhaseNet-v1/model_99.pth"
elif args.model == "phasenet_plus":
if args.location is None:
model_url = "https://github.com/AI4EPS/models/releases/download/PhaseNet-Plus-v1/model_99.pth"
Expand All @@ -426,7 +427,7 @@ def main(args):
else:
raise
checkpoint = torch.hub.load_state_dict_from_url(
model_url, model_dir="./", progress=True, check_hash=True, map_location="cpu"
model_url, model_dir=f"./model_{args.model}", progress=True, check_hash=True, map_location="cpu"
)

## load model from wandb
Expand Down
100 changes: 96 additions & 4 deletions tests/phasenet_plus/paper_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,46 @@ def read_labels(file):
phasenet_picks["event_index"] = phasenet_picks["event_id"]
phasenet_picks["phase_time"] = pd.to_datetime(phasenet_picks["phase_time"])

# %%
phasenet_pt_csv = f"phasenet_pt_picks_{region}.csv"
if reload or not os.path.exists(phasenet_pt_csv):
if region == "NC":
phasenet_pt_picks = pd.read_csv(
"https://huggingface.co/datasets/AI4EPS/quakeflow_nc/resolve/main/models/phasenet_pt_picks.csv"
)
elif region == "SC":
phasenet_pt_picks = pd.read_csv(
"https://huggingface.co/datasets/AI4EPS/quakeflow_sc/resolve/main/models/phasenet_pt_picks.csv"
)
# pick_path = "../../results_phasenet_quakeflow_sc/picks_phasenet"
# event_ids = glob(f"{pick_path}/*")
# picks_list = []
# for event_id in tqdm(event_ids):
# station_ids = glob(f"{event_id}/*.csv")
# for station_id in station_ids:
# if os.stat(station_id).st_size == 0:
# # print(f"Empty file: {station_id}")
# continue
# phasenet_pt_picks = pd.read_csv(station_id)
# phasenet_pt_picks["event_index"] = event_id.split("/")[-1]
# picks_list.append(phasenet_pt_picks)
# phasenet_pt_picks = pd.concat(picks_list, ignore_index=True)
# phasenet_pt_picks["phase_time"] = pd.to_datetime(phasenet_pt_picks["phase_time"])
# phasenet_pt_picks.drop(columns=["dt_s"], inplace=True)
# phasenet_pt_picks = filter_duplicates(phasenet_pt_picks)
# phasenet_pt_picks.drop(columns=["phase_index"], inplace=True)
# phasenet_pt_picks.rename(columns={"event_index": "event_id"}, inplace=True)
# phasenet_pt_picks.to_csv(
# phasenet_pt_csv,
# columns=["event_id", "station_id", "phase_time", "phase_score", "phase_type"],
# index=False,
# )
else:
phasenet_pt_picks = pd.read_csv(phasenet_pt_csv, parse_dates=["phase_time"])

phasenet_pt_picks["event_index"] = phasenet_pt_picks["event_id"]
phasenet_pt_picks["phase_time"] = pd.to_datetime(phasenet_pt_picks["phase_time"])

# %% Load PhaseNet+ model
phasenet_plus_csv = f"phasenet_plus_picks_{region}.csv"
if reload or not os.path.exists(phasenet_plus_csv):
Expand Down Expand Up @@ -233,6 +273,11 @@ def read_labels(file):
idx = phasenet_picks["phase_type"] == "P"
plt.hist(phasenet_picks[idx]["phase_score"], bins=np.linspace(0.3, 1, 70 // 2 + 1), alpha=0.5, label="PhaseNet")
plt.legend()
idx = phasenet_pt_picks["phase_type"] == "P"
plt.hist(
phasenet_pt_picks[idx]["phase_score"], bins=np.linspace(0.3, 1, 70 // 2 + 1), alpha=0.5, label="PhaseNet (PT)"
)
plt.legend()

plt.figure()
idx = phasenet_plus_picks["phase_type"] == "S"
Expand All @@ -242,6 +287,11 @@ def read_labels(file):
idx = phasenet_picks["phase_type"] == "S"
plt.hist(phasenet_picks[idx]["phase_score"], bins=np.linspace(0.3, 1, 70 // 2 + 1), alpha=0.5, label="PhaseNet")
plt.legend()
idx = phasenet_pt_picks["phase_type"] == "S"
plt.hist(
phasenet_pt_picks[idx]["phase_score"], bins=np.linspace(0.3, 1, 70 // 2 + 1), alpha=0.5, label="PhaseNet (PT)"
)
plt.legend()

# %%
plt.figure()
Expand Down Expand Up @@ -276,12 +326,15 @@ def read_labels(file):
num_s_phasenet_plus = len(phasenet_plus_picks[phasenet_plus_picks["phase_type"] == "S"])
num_p_phasenet = len(phasenet_picks[phasenet_picks["phase_type"] == "P"])
num_s_phasenet = len(phasenet_picks[phasenet_picks["phase_type"] == "S"])
num_p_phasenet_pt = len(phasenet_pt_picks[phasenet_pt_picks["phase_type"] == "P"])
num_s_phasenet_pt = len(phasenet_pt_picks[phasenet_pt_picks["phase_type"] == "S"])
bar_width = 0.2
index = np.arange(2)
plt.bar(index, [num_p_labels, num_s_labels], bar_width, label="Labels")
plt.bar(index + bar_width, [num_p_phasenet_plus, num_s_phasenet_plus], bar_width, label="PhaseNet+")
plt.bar(index + 2 * bar_width, [num_p_phasenet, num_s_phasenet], bar_width, label="PhaseNet")
plt.xticks(index + bar_width, ["P", "S"])
plt.bar(index + 3 * bar_width, [num_p_phasenet_pt, num_s_phasenet_pt], bar_width, label="PhaseNet (PT)")
plt.xticks(index + 1.5 * bar_width, ["P", "S"])
plt.legend(loc="lower center")

# %%
Expand All @@ -294,7 +347,10 @@ def read_labels(file):
# %%
for phase_type in ["P", "S"]:
plt.figure()
for picks, name in zip([phasenet_plus_picks, phasenet_picks], ["PhaseNet+", "PhaseNet"]):
# for picks, name in zip([phasenet_plus_picks, phasenet_picks], ["PhaseNet+", "PhaseNet"]):
for picks, name in zip(
[phasenet_plus_picks, phasenet_picks, phasenet_pt_picks], ["PhaseNet+", "PhaseNet", "PhaseNet (PT)"]
):

merged = calc_error(picks, labels, score_threshold, phase_type)

Expand All @@ -312,7 +368,10 @@ def read_labels(file):
# %%
metrics_summary = []
for phase_type in ["P", "S"]:
for picks, name in zip([phasenet_plus_picks, phasenet_picks], ["PhaseNet+", "PhaseNet"]):
# for picks, name in zip([phasenet_plus_picks, phasenet_picks], ["PhaseNet+", "PhaseNet"]):
for picks, name in zip(
[phasenet_plus_picks, phasenet_picks, phasenet_pt_picks], ["PhaseNet+", "PhaseNet", "PhaseNet (PT)"]
):
metrics_list = []
for score_threshold in tqdm(np.linspace(0.1, 1, 20)):
metrics = calc_metrics(
Expand Down Expand Up @@ -349,6 +408,16 @@ def read_labels(file):
)
plt.scatter(metrics_summary[idx].iloc[idx_f1]["recall"], metrics_summary[idx].iloc[idx_f1]["precision"])

idx = (metrics_summary["phase_type"] == "P") & (metrics_summary["model"] == "PhaseNet (PT)")
plt.plot(metrics_summary[idx]["recall"], metrics_summary[idx]["precision"], label="PhaseNet (PT)")
idx_f1 = np.argmax(metrics_summary[idx]["f1"])
print(
metrics_summary[idx].iloc[idx_f1]["recall"],
metrics_summary[idx].iloc[idx_f1]["precision"],
metrics_summary[idx].iloc[idx_f1]["f1"],
)
plt.scatter(metrics_summary[idx].iloc[idx_f1]["recall"], metrics_summary[idx].iloc[idx_f1]["precision"])

plt.legend()
plt.xlabel("Recall")
plt.ylabel("Precision")
Expand Down Expand Up @@ -376,6 +445,16 @@ def read_labels(file):
)
plt.scatter(metrics_summary[idx].iloc[idx_f1]["recall"], metrics_summary[idx].iloc[idx_f1]["precision"])

idx = (metrics_summary["phase_type"] == "S") & (metrics_summary["model"] == "PhaseNet (PT)")
plt.plot(metrics_summary[idx]["recall"], metrics_summary[idx]["precision"], label="PhaseNet (PT)")
idx_f1 = np.argmax(metrics_summary[idx]["f1"])
print(
metrics_summary[idx].iloc[idx_f1]["recall"],
metrics_summary[idx].iloc[idx_f1]["precision"],
metrics_summary[idx].iloc[idx_f1]["f1"],
)
plt.scatter(metrics_summary[idx].iloc[idx_f1]["recall"], metrics_summary[idx].iloc[idx_f1]["precision"])

plt.legend()
plt.xlabel("Recall")
plt.ylabel("Precision")
Expand All @@ -395,6 +474,12 @@ def read_labels(file):
[1.0] + list(metrics_summary[idx]["TPR"]) + [0.0],
label="PhaseNet",
)
idx = (metrics_summary["phase_type"] == "P") & (metrics_summary["model"] == "PhaseNet (PT)")
plt.plot(
[1.0] + list(metrics_summary[idx]["FPR"]) + [0.0],
[1.0] + list(metrics_summary[idx]["TPR"]) + [0.0],
label="PhaseNet (PT)",
)
plt.legend()
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
Expand All @@ -414,6 +499,12 @@ def read_labels(file):
[1.0] + list(metrics_summary[idx]["TPR"]) + [0.0],
label="PhaseNet",
)
idx = (metrics_summary["phase_type"] == "S") & (metrics_summary["model"] == "PhaseNet (PT)")
plt.plot(
[1.0] + list(metrics_summary[idx]["FPR"]) + [0.0],
[1.0] + list(metrics_summary[idx]["TPR"]) + [0.0],
label="PhaseNet (PT)",
)
plt.legend()
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
Expand All @@ -422,7 +513,8 @@ def read_labels(file):

# %%
phase_type = "P"
polarity_threshold = 0.8
polarity_threshold = 0.5
score_threshold = 0.5
picks = phasenet_plus_picks
metrics = calc_metrics(picks, labels, polarity_threshold, score_threshold, time_tolerance, phase_type, "phase_time")
for key, value in metrics.items():
Expand Down
Loading

0 comments on commit 99c513c

Please sign in to comment.