Skip to content

Commit

Permalink
improve visulization
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuwq0 committed May 8, 2024
1 parent 99c03ca commit b158e05
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
31 changes: 23 additions & 8 deletions eqnet/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def plot_phasenet_plus(
fig, axes = plt.subplots(9, 1, figsize=(10, 10))
t = pd.date_range(pd.Timestamp(begin_time[i]), periods=nt, freq=pd.Timedelta(seconds=dt))
for j in range(3):
axes[j].plot(t, meta["raw_data"][i, j, :, 0], linewidth=0.5, color="k", label=f"{chn_name[j]}")
axes[j].plot(t, meta["raw_data"][i, j, :, 0], lw=0.5, color="k", label=f"{chn_name[j]}")
axes[j].set_xlim(t[0], t[-1])
axes[j].set_xticklabels([])
axes[j].grid("on")
Expand All @@ -409,21 +409,23 @@ def plot_phasenet_plus(

for j in range(3):
t = pd.date_range(pd.Timestamp(begin_time[i]), periods=nt, freq=pd.Timedelta(seconds=dt))
axes[j + shift].plot(t, meta["data"][i, j, :, 0], linewidth=0.5, color="k", label=f"{chn_name[j]}")
axes[j + shift].plot(t, meta["data"][i, j, :, 0], lw=0.5, color="k", label=f"{chn_name[j]}")
axes[j + shift].set_xlim(t[0], t[-1])
axes[j + shift].set_xticklabels([])
axes[j + shift].grid("on")
axes[j + shift].legend(loc="upper right")

k = 3 + shift
t = pd.date_range(pd.Timestamp(begin_time[i]), periods=nt, freq=pd.Timedelta(seconds=dt))
axes[k].plot(t, phase[i, 1, :, 0], "b", label="P-phase")
axes[k].plot(t, phase[i, 2, :, 0], "r", label="S-phase")
axes[k].plot(t, phase[i, 2, :, 0], "r", lw=1.0)
axes[k].plot(t, phase[i, 1, :, 0], "b", lw=1.0)
color = {"P": "b", "S": "r"}
for ii, pick in enumerate(phase_picks[i]):
tt = pd.to_datetime(pick["phase_time"])
axes[k].plot([tt, tt], [-0.05, 1.05], f"--{color[pick['phase_type']]}", linewidth=0.5)
axes[k].plot([tt, tt], [-0.05, 1.05], f"--{color[pick['phase_type']]}", linewidth=0.8)

axes[k].plot([], [], "-b", label="P-phase")
axes[k].plot([], [], "-r", label="S-phase")
axes[k].set_xlim(t[0], t[-1])
axes[k].set_ylim(-0.05, 1.05)
axes[k].set_xticklabels([])
Expand All @@ -432,7 +434,20 @@ def plot_phasenet_plus(

t = pd.date_range(pd.Timestamp(begin_time[i]), periods=nt_polarity, freq=pd.Timedelta(seconds=dt_polarity))
# axes[k + 1].plot(t, (polarity[i, 0, :, 0] - 0.5) * 2.0, "b", label="polarity")
axes[k + 1].plot(t, polarity[i, 1, :, 0] - polarity[i, 2, :, 0], "b", label="Polarity")
# axes[k + 1].plot(t, polarity[i, 1, :, 0] - polarity[i, 2, :, 0], "b", label="Polarity")
for ii, pick in enumerate(phase_picks[i]):
tt = pd.to_datetime(pick["phase_time"])
amp = pick["phase_polarity"]
if abs(amp) > 0.15:
axes[k + 1].annotate(
"",
xy=(tt, -0.03 * np.sign(amp)),
xytext=(tt, amp),
arrowprops=dict(arrowstyle="<-", color=f"{color[pick['phase_type']]}", lw=1.5),
)
axes[k + 1].plot([], [], "-b", label="P-polarity")
axes[k + 1].plot([], [], "-r", label="S-polarity")
axes[k + 1].plot([t[0], t[-1]], [0.0, 0.0], "-", color="blue", lw=1.0)
axes[k + 1].set_xlim(t[0], t[-1])
axes[k + 1].set_ylim(-1.05, 1.05)
axes[k + 1].set_xticklabels([])
Expand All @@ -458,8 +473,8 @@ def plot_phasenet_plus(
axes[k + 2].plot([at, at], [-0.05, 1.05], "--C0", linewidth=2.0)
axes[k + 2].annotate(
"",
xy=(at, 0.3),
xytext=(ot, 0.3),
xy=(max(t[0], at), 0.3),
xytext=(max(t[0], ot), 0.3),
arrowprops=dict(arrowstyle="<-", color="C1", lw=2),
)
axes[k + 2].plot([], [], "--C3", label="Origin time")
Expand Down
1 change: 0 additions & 1 deletion predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,6 @@ def main(args):
model_url = "https://github.com/AI4EPS/models/releases/download/PhaseNet-Plus-LCSN/model_99.pth"
elif args.model == "phasenet_das":
if args.location is None:
# model_url = "ai4eps/model-registry/PhaseNet-DAS:latest"
# model_url = "https://github.com/AI4EPS/models/releases/download/PhaseNet-DAS-v0/PhaseNet-DAS-v0.pth"
model_url = "https://github.com/AI4EPS/models/releases/download/PhaseNet-DAS-v1/PhaseNet-DAS-v1.pth"
elif args.location == "forge":
Expand Down

0 comments on commit b158e05

Please sign in to comment.