Skip to content

Commit

Permalink
add merge
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuwq0 committed Aug 13, 2024
1 parent c48af33 commit 9466dd3
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 42 deletions.
4 changes: 3 additions & 1 deletion cctorch/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def forward(self, meta):
x = np.array([-1, 0, 1])
y = neighbor_score # nb, nc, nx, 3
spl = CubicSpline(x, y, axis=-1)
x_ = np.linspace(-1, 1, 101)
x_ = np.linspace(-1, 1, 201)
y_ = spl(x_) # nb, nc, nx, 201
ii = np.argmax(y_, axis=-1, keepdims=True) # nb, nc, nx, 1
sub_shift = np.take_along_axis(x_[np.newaxis, np.newaxis, np.newaxis, :], ii, axis=-1).squeeze(
Expand All @@ -206,6 +206,8 @@ def forward(self, meta):
topk_idx = topk_idx + sub_shift
topk_score = np.take_along_axis(y_, ii, axis=-1).squeeze(-1) # nb, nc, nx

# print(f"sub_shift: {sub_shift}, topk_idx: {topk_idx}, topk_score: {topk_score}")

idx = np.argmax(weight, axis=1, keepdims=True) # nb, 1, nx
max_cc = np.take_along_axis(topk_score, idx, axis=1) # nb, 1, nx
shift_idx = np.take_along_axis(topk_idx, idx, axis=1) # nb, 1, nx
Expand Down
92 changes: 51 additions & 41 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import torch.distributed as dist


def get_args_parser(add_help=True):
Expand Down Expand Up @@ -360,51 +361,60 @@ def __init__(self, config):
if ccconfig.mode == "CC":

# %%
result_df = pd.DataFrame(result_df)
result_df.to_csv(
os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}_origin.csv"), index=False
)
if len(result_df) > 0:

# %% filter based on cc values
result_df = result_df[
(result_df["cc"] >= ccconfig.min_cc)
& (result_df["shift"].abs() <= result_df["phase_type"].map(ccconfig.max_shift))
]

# %% merge different instrument types of the same stations
stations["network_station"] = stations["network"] + "." + stations["station"]
result_df = result_df.merge(stations[["network_station", "idx_sta"]], on="idx_sta", how="left")
result_df.sort_values("weight", ascending=False, inplace=True)
result_df = result_df.groupby(["idx_eve1", "idx_eve2", "network_station", "phase_type"]).first().reset_index()
result_df.drop(columns=["network_station"], inplace=True)

# %% filter based on cc observations
result_df = (
result_df.groupby(["idx_eve1", "idx_eve2"])
.apply(lambda x: (x.nlargest(ccconfig.max_obs, "weight") if len(x) >= ccconfig.min_obs else None))
.reset_index(drop=True)
)
result_df = pd.DataFrame(result_df)
result_df.to_csv(
os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}_origin.csv"), index=False
)

# %%
event_idx_dict = events["event_index"].to_dict() ## faster than using .loc
station_id_dict = stations["station"].to_dict()
# %% filter based on cc values
result_df = result_df[
(result_df["cc"] >= ccconfig.min_cc)
& (result_df["shift"].abs() <= result_df["phase_type"].map(ccconfig.max_shift))
]

# %% merge different instrument types of the same stations
stations["network_station"] = stations["network"] + "." + stations["station"]
result_df = result_df.merge(stations[["network_station", "idx_sta"]], on="idx_sta", how="left")
result_df.sort_values("weight", ascending=False, inplace=True)
result_df = (
result_df.groupby(["idx_eve1", "idx_eve2", "network_station", "phase_type"]).first().reset_index()
)
result_df.drop(columns=["network_station"], inplace=True)

# %%
result_df.to_csv(
os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}.csv"), index=False
)
# %% filter based on cc observations
result_df = (
result_df.groupby(["idx_eve1", "idx_eve2"])
.apply(lambda x: (x.nlargest(ccconfig.max_obs, "weight") if len(x) >= ccconfig.min_obs else None))
.reset_index(drop=True)
)

# %%
event_idx_dict = events["event_index"].to_dict() ## faster than using .loc
station_id_dict = stations["station"].to_dict()

# %%
result_df.to_csv(
os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}.csv"), index=False
)

# %% write to cc file
with open(os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}_dt.cc"), "w") as fp:
for (i, j), record in tqdm(result_df.groupby(["idx_eve1", "idx_eve2"])):
event_idx1 = event_idx_dict[i]
event_idx2 = event_idx_dict[j]
fp.write(f"# {event_idx1} {event_idx2} 0.000\n")
for k, record_ in record.iterrows():
idx_sta = record_["idx_sta"]
station_id = station_id_dict[idx_sta]
phase_type = record_["phase_type"]
fp.write(f"{station_id} {record_['dt']: .4f} {record_['weight']:.4f} {phase_type}\n")
# %% write to cc file
with open(os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}_dt.cc"), "w") as fp:
for (i, j), record in tqdm(result_df.groupby(["idx_eve1", "idx_eve2"])):
event_idx1 = event_idx_dict[i]
event_idx2 = event_idx_dict[j]
fp.write(f"# {event_idx1} {event_idx2} 0.000\n")
for k, record_ in record.iterrows():
idx_sta = record_["idx_sta"]
station_id = station_id_dict[idx_sta]
phase_type = record_["phase_type"]
fp.write(f"{station_id} {record_['dt']: .4f} {record_['weight']:.4f} {phase_type}\n")

if world_size > 1:
dist.barrier()
if rank == 0:
os.system(f"cat {args.result_path}/*_{world_size:03d}_dt.cc > {args.result_path}/dt.cc")

# MAX_THREADS = 32
# with h5py.File(os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}.h5"), "w") as fp:
Expand Down

0 comments on commit 9466dd3

Please sign in to comment.