diff --git a/eqnet/utils/visualization.py b/eqnet/utils/visualization.py index 479d0ad..016eee2 100644 --- a/eqnet/utils/visualization.py +++ b/eqnet/utils/visualization.py @@ -398,7 +398,7 @@ def plot_phasenet_plus( fig, axes = plt.subplots(9, 1, figsize=(10, 10)) t = pd.date_range(pd.Timestamp(begin_time[i]), periods=nt, freq=pd.Timedelta(seconds=dt)) for j in range(3): - axes[j].plot(t, meta["raw_data"][i, j, :, 0], linewidth=0.5, color="k", label=f"{chn_name[j]}") + axes[j].plot(t, meta["raw_data"][i, j, :, 0], lw=0.5, color="k", label=f"{chn_name[j]}") axes[j].set_xlim(t[0], t[-1]) axes[j].set_xticklabels([]) axes[j].grid("on") @@ -409,7 +409,7 @@ def plot_phasenet_plus( for j in range(3): t = pd.date_range(pd.Timestamp(begin_time[i]), periods=nt, freq=pd.Timedelta(seconds=dt)) - axes[j + shift].plot(t, meta["data"][i, j, :, 0], linewidth=0.5, color="k", label=f"{chn_name[j]}") + axes[j + shift].plot(t, meta["data"][i, j, :, 0], lw=0.5, color="k", label=f"{chn_name[j]}") axes[j + shift].set_xlim(t[0], t[-1]) axes[j + shift].set_xticklabels([]) axes[j + shift].grid("on") @@ -417,13 +417,15 @@ def plot_phasenet_plus( k = 3 + shift t = pd.date_range(pd.Timestamp(begin_time[i]), periods=nt, freq=pd.Timedelta(seconds=dt)) - axes[k].plot(t, phase[i, 1, :, 0], "b", label="P-phase") - axes[k].plot(t, phase[i, 2, :, 0], "r", label="S-phase") + axes[k].plot(t, phase[i, 2, :, 0], "r", lw=1.0) + axes[k].plot(t, phase[i, 1, :, 0], "b", lw=1.0) color = {"P": "b", "S": "r"} for ii, pick in enumerate(phase_picks[i]): tt = pd.to_datetime(pick["phase_time"]) - axes[k].plot([tt, tt], [-0.05, 1.05], f"--{color[pick['phase_type']]}", linewidth=0.5) + axes[k].plot([tt, tt], [-0.05, 1.05], f"--{color[pick['phase_type']]}", linewidth=0.8) + axes[k].plot([], [], "-b", label="P-phase") + axes[k].plot([], [], "-r", label="S-phase") axes[k].set_xlim(t[0], t[-1]) axes[k].set_ylim(-0.05, 1.05) axes[k].set_xticklabels([]) @@ -432,7 +434,20 @@ def plot_phasenet_plus( t = pd.date_range(pd.Timestamp(begin_time[i]), periods=nt_polarity, freq=pd.Timedelta(seconds=dt_polarity)) # axes[k + 1].plot(t, (polarity[i, 0, :, 0] - 0.5) * 2.0, "b", label="polarity") - axes[k + 1].plot(t, polarity[i, 1, :, 0] - polarity[i, 2, :, 0], "b", label="Polarity") + # axes[k + 1].plot(t, polarity[i, 1, :, 0] - polarity[i, 2, :, 0], "b", label="Polarity") + for ii, pick in enumerate(phase_picks[i]): + tt = pd.to_datetime(pick["phase_time"]) + amp = pick["phase_polarity"] + if abs(amp) > 0.15: + axes[k + 1].annotate( + "", + xy=(tt, -0.03 * np.sign(amp)), + xytext=(tt, amp), + arrowprops=dict(arrowstyle="<-", color=f"{color[pick['phase_type']]}", lw=1.5), + ) + axes[k + 1].plot([], [], "-b", label="P-polarity") + axes[k + 1].plot([], [], "-r", label="S-polarity") + axes[k + 1].plot([t[0], t[-1]], [0.0, 0.0], "-", color="blue", lw=1.0) axes[k + 1].set_xlim(t[0], t[-1]) axes[k + 1].set_ylim(-1.05, 1.05) axes[k + 1].set_xticklabels([]) @@ -458,8 +473,8 @@ def plot_phasenet_plus( axes[k + 2].plot([at, at], [-0.05, 1.05], "--C0", linewidth=2.0) axes[k + 2].annotate( "", - xy=(at, 0.3), - xytext=(ot, 0.3), + xy=(max(t[0], at), 0.3), + xytext=(max(t[0], ot), 0.3), arrowprops=dict(arrowstyle="<-", color="C1", lw=2), ) axes[k + 2].plot([], [], "--C3", label="Origin time") diff --git a/predict.py b/predict.py index 8dea338..d4e2980 100644 --- a/predict.py +++ b/predict.py @@ -415,7 +415,6 @@ def main(args): model_url = "https://github.com/AI4EPS/models/releases/download/PhaseNet-Plus-LCSN/model_99.pth" elif args.model == "phasenet_das": if args.location is None: - # model_url = "ai4eps/model-registry/PhaseNet-DAS:latest" # model_url = "https://github.com/AI4EPS/models/releases/download/PhaseNet-DAS-v0/PhaseNet-DAS-v0.pth" model_url = "https://github.com/AI4EPS/models/releases/download/PhaseNet-DAS-v1/PhaseNet-DAS-v1.pth" elif args.location == "forge":