diff --git a/cctorch/transforms.py b/cctorch/transforms.py
index 222c4bf..8f51183 100644
--- a/cctorch/transforms.py
+++ b/cctorch/transforms.py
@@ -197,7 +197,7 @@ def forward(self, meta):
         x = np.array([-1, 0, 1])
         y = neighbor_score  # nb, nc, nx, 3
         spl = CubicSpline(x, y, axis=-1)
-        x_ = np.linspace(-1, 1, 101)
+        x_ = np.linspace(-1, 1, 201)
         y_ = spl(x_)  # nb, nc, nx, 101
         ii = np.argmax(y_, axis=-1, keepdims=True)  # nb, nc, nx, 1
         sub_shift = np.take_along_axis(x_[np.newaxis, np.newaxis, np.newaxis, :], ii, axis=-1).squeeze(
@@ -206,6 +206,8 @@ def forward(self, meta):
         topk_idx = topk_idx + sub_shift
         topk_score = np.take_along_axis(y_, ii, axis=-1).squeeze(-1)  # nb, nc, nx
 
+        # print(f"sub_shift: {sub_shift}, topk_idx: {topk_idx}, topk_score: {topk_score}")
+
         idx = np.argmax(weight, axis=1, keepdims=True)  # nb, 1, nx
         max_cc = np.take_along_axis(topk_score, idx, axis=1)  # nb, 1, nx
         shift_idx = np.take_along_axis(topk_idx, idx, axis=1)  # nb, 1, nx
diff --git a/run.py b/run.py
index b7def37..e87d43b 100644
--- a/run.py
+++ b/run.py
@@ -16,6 +16,7 @@ from tqdm import tqdm
 
 import pandas as pd
 import matplotlib.pyplot as plt
+import torch.distributed as dist
 
 
 def get_args_parser(add_help=True):
@@ -360,51 +361,60 @@ def __init__(self, config):
 
     if ccconfig.mode == "CC":
         # %%
-        result_df = pd.DataFrame(result_df)
-        result_df.to_csv(
-            os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}_origin.csv"), index=False
-        )
+        if len(result_df) > 0:
 
-        # %% filter based on cc values
-        result_df = result_df[
-            (result_df["cc"] >= ccconfig.min_cc)
-            & (result_df["shift"].abs() <= result_df["phase_type"].map(ccconfig.max_shift))
-        ]
-
-        # %% merge different instrument types of the same stations
-        stations["network_station"] = stations["network"] + "." + stations["station"]
-        result_df = result_df.merge(stations[["network_station", "idx_sta"]], on="idx_sta", how="left")
-        result_df.sort_values("weight", ascending=False, inplace=True)
-        result_df = result_df.groupby(["idx_eve1", "idx_eve2", "network_station", "phase_type"]).first().reset_index()
-        result_df.drop(columns=["network_station"], inplace=True)
-
-        # %% filter based on cc observations
-        result_df = (
-            result_df.groupby(["idx_eve1", "idx_eve2"])
-            .apply(lambda x: (x.nlargest(ccconfig.max_obs, "weight") if len(x) >= ccconfig.min_obs else None))
-            .reset_index(drop=True)
-        )
+            result_df = pd.DataFrame(result_df)
+            result_df.to_csv(
+                os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}_origin.csv"), index=False
+            )
 
-        # %%
-        event_idx_dict = events["event_index"].to_dict()  ## faster than using .loc
-        station_id_dict = stations["station"].to_dict()
+            # %% filter based on cc values
+            result_df = result_df[
+                (result_df["cc"] >= ccconfig.min_cc)
+                & (result_df["shift"].abs() <= result_df["phase_type"].map(ccconfig.max_shift))
+            ]
+
+            # %% merge different instrument types of the same stations
+            stations["network_station"] = stations["network"] + "." + stations["station"]
+            result_df = result_df.merge(stations[["network_station", "idx_sta"]], on="idx_sta", how="left")
+            result_df.sort_values("weight", ascending=False, inplace=True)
+            result_df = (
+                result_df.groupby(["idx_eve1", "idx_eve2", "network_station", "phase_type"]).first().reset_index()
+            )
+            result_df.drop(columns=["network_station"], inplace=True)
 
-        # %%
-        result_df.to_csv(
-            os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}.csv"), index=False
-        )
+            # %% filter based on cc observations
+            result_df = (
+                result_df.groupby(["idx_eve1", "idx_eve2"])
+                .apply(lambda x: (x.nlargest(ccconfig.max_obs, "weight") if len(x) >= ccconfig.min_obs else None))
+                .reset_index(drop=True)
+            )
+
+            # %%
+            event_idx_dict = events["event_index"].to_dict()  ## faster than using .loc
+            station_id_dict = stations["station"].to_dict()
+
+            # %%
+            result_df.to_csv(
+                os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}.csv"), index=False
+            )
 
-        # %% write to cc file
-        with open(os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}_dt.cc"), "w") as fp:
-            for (i, j), record in tqdm(result_df.groupby(["idx_eve1", "idx_eve2"])):
-                event_idx1 = event_idx_dict[i]
-                event_idx2 = event_idx_dict[j]
-                fp.write(f"# {event_idx1} {event_idx2} 0.000\n")
-                for k, record_ in record.iterrows():
-                    idx_sta = record_["idx_sta"]
-                    station_id = station_id_dict[idx_sta]
-                    phase_type = record_["phase_type"]
-                    fp.write(f"{station_id} {record_['dt']: .4f} {record_['weight']:.4f} {phase_type}\n")
+            # %% write to cc file
+            with open(os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}_dt.cc"), "w") as fp:
+                for (i, j), record in tqdm(result_df.groupby(["idx_eve1", "idx_eve2"])):
+                    event_idx1 = event_idx_dict[i]
+                    event_idx2 = event_idx_dict[j]
+                    fp.write(f"# {event_idx1} {event_idx2} 0.000\n")
+                    for k, record_ in record.iterrows():
+                        idx_sta = record_["idx_sta"]
+                        station_id = station_id_dict[idx_sta]
+                        phase_type = record_["phase_type"]
+                        fp.write(f"{station_id} {record_['dt']: .4f} {record_['weight']:.4f} {phase_type}\n")
+
+        if world_size > 1:
+            dist.barrier()
+        if rank == 0:
+            os.system(f"cat {args.result_path}/*_{world_size:03d}_dt.cc > {args.result_path}/dt.cc")
 
         # MAX_THREADS = 32
         # with h5py.File(os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}.h5"), "w") as fp: