improve template matching
zhuwq0 committed Aug 26, 2024
1 parent 7cbc701 commit 5aa9bd6
Showing 3 changed files with 175 additions and 57 deletions.
32 changes: 23 additions & 9 deletions cctorch/data.py
@@ -159,13 +159,19 @@ def __init__(

if self.mode == "TM":
if data_list1 is not None:
with open(data_list1, "r") as fp:
self.data_list1 = fp.read().splitlines()
if data_list1.endswith(".txt"):
with open(data_list1, "r") as fp:
self.data_list1 = fp.read().splitlines()
else:
self.data_list1 = pd.read_csv(data_list1).set_index("idx_pick")
else:
self.data_list1 = None
if data_list2 is not None:
with open(data_list2, "r") as fp:
self.data_list2 = fp.read().splitlines()
if data_list2.endswith(".txt"):
with open(data_list2, "r") as fp:
self.data_list2 = fp.read().splitlines()
else:
self.data_list2 = pd.read_csv(data_list2).set_index("idx_pick")
else:
self.data_list2 = None
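
For context, a minimal sketch (not from the commit; the file name and values are hypothetical) of the CSV pick list this new branch appears to expect: a table indexed by "idx_pick" whose columns include the fields that sample() reads below.

import pandas as pd

# Hypothetical pick list; the column names follow the fields used in sample().
picks = pd.DataFrame(
    {
        "idx_pick": [0, 1, 2],
        "idx_eve": [10, 10, 11],
        "idx_sta": [3, 4, 3],
        "phase_type": ["P", "S", "P"],
    }
)
picks.to_csv("data_list2.csv", index=False)

# Mirrors the loading logic above: ".txt" falls back to plain lines,
# anything else is parsed as a CSV indexed by "idx_pick".
data_list2 = pd.read_csv("data_list2.csv").set_index("idx_pick")
print(data_list2.loc[1, "phase_type"])  # -> "S"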

@@ -321,9 +327,9 @@ def sample(self, block_index):
"data": self.templates[jj],
"index": jj,
"info": {
"idx_eve": self.data_list1.loc[jj, "idx_eve"],
"idx_sta": self.data_list1.loc[jj, "idx_sta"],
"phase_type": self.data_list1.loc[jj, "phase_type"],
"idx_eve": self.data_list2.loc[jj, "idx_eve"],
"idx_sta": self.data_list2.loc[jj, "idx_sta"],
"phase_type": self.data_list2.loc[jj, "phase_type"],
"traveltime": self.traveltime[jj],
"traveltime_mask": self.traveltime_mask[jj],
"traveltime_index": self.traveltime_index[jj],
@@ -378,6 +384,8 @@ def sample(self, block_index):
info_batch2["traveltime"] = np.stack(info_batch2["traveltime"])
info_batch2["traveltime_mask"] = np.stack(info_batch2["traveltime_mask"])
info_batch2["traveltime_index"] = np.stack(info_batch2["traveltime_index"])
if "begin_time" in info_batch1:
info_batch1["begin_time"] = np.stack(info_batch1["begin_time"])

yield {
"data": data_batch1,
@@ -412,6 +420,8 @@ def sample(self, block_index):
info_batch2["traveltime"] = np.stack(info_batch2["traveltime"])
info_batch2["traveltime_mask"] = np.stack(info_batch2["traveltime_mask"])
info_batch2["traveltime_index"] = np.stack(info_batch2["traveltime_index"])
if "begin_time" in info_batch1:
info_batch1["begin_time"] = np.stack(info_batch1["begin_time"])

yield {
"data": data_batch1,
@@ -680,9 +690,13 @@ def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None):
tmp = trace.data.astype("float32")
data[j, i, : len(tmp)] = tmp[:nt]

# return data, {
# "begin_time": begin_time.datetime, # .strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
# "end_time": end_time.datetime, # .strftime("%Y-%m-%dT%H:%M:%S.%fZ"),
# }
return data, {
"begin_time": begin_time.datetime,
"end_time": end_time.datetime,
"begin_time": np.datetime64(begin_time.datetime),
"end_time": np.datetime64(end_time.datetime),
}
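
Returning numpy datetime64 values (rather than raw datetime objects) lets the new DetectPeaksTM transform below add timedelta64[ms] shifts to info1["begin_time"] directly. A small sketch of that arithmetic, with made-up values:

import numpy as np

begin_time = np.datetime64("2024-08-26T00:00:00.000")  # as returned by read_mseed
shift_time = np.timedelta64(1250, "ms")                 # (topk_idx - nlag) / sampling_rate
traveltime = np.timedelta64(3400, "ms")                 # template travel time

origin_time = begin_time + shift_time               # when the matched event occurred
phase_time = begin_time + shift_time + traveltime   # when the phase arrives at the station
print(origin_time, phase_time)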


48 changes: 46 additions & 2 deletions cctorch/transforms.py
@@ -6,10 +6,10 @@
import torch.nn.functional as F
import torchaudio
from scipy import sparse
from scipy.interpolate import CubicSpline
from scipy.signal import tukey
from scipy.sparse.linalg import lsmr
from tqdm import tqdm
from scipy.interpolate import CubicSpline


#### Common ####
@@ -162,7 +162,7 @@ def fft_real_normalize(x):
return fft_real(x)


class DetectPeaks(torch.nn.Module):
class DetectPeaksCC(torch.nn.Module):
def __init__(self, kernel=3, stride=1, topk=2, vabs=True, interp=True, sampling_rate=100.0):
super().__init__()
self.kernel = kernel
@@ -228,6 +228,50 @@ def forward(self, meta):
return meta


class DetectPeaksTM(torch.nn.Module):
def __init__(self, vmin=0.6, kernel=300, stride=1, topk=2, vabs=True, interp=True, sampling_rate=100.0):
super().__init__()
self.vmin = vmin
self.kernel = kernel
self.stride = stride
self.topk = topk
self.vabs = vabs
self.interp = interp
self.sampling_rate = sampling_rate

def forward(self, meta):
xcorr = meta["xcorr"]
nlag = meta["nlag"]
nb, nc, nx, nt = xcorr.shape # nc = 1 by reduce_c, nx = 1 based on picks

## consider both positive and negative peaks
if self.vabs:
xcorr = torch.abs(xcorr)

smax = F.max_pool2d(xcorr, (1, self.kernel), stride=(1, self.stride), padding=(0, self.kernel // 2))

keep = (smax == xcorr).float()
topk_score, topk_idx = torch.topk(xcorr * keep, self.topk, sorted=True) # nb, 1, 1, k
topk_score, topk_idx = topk_score.cpu().numpy(), topk_idx.cpu().numpy()

if ("begin_time" in meta["info1"]) and ("traveltime" in meta["info2"]):
shift_time = (topk_idx - nlag) / self.sampling_rate
shift_time = np.array((shift_time * 1e3).astype(int), dtype="timedelta64[ms]")
traveltime = [x[0].item() for x in meta["info2"]["traveltime"]]
traveltime = np.array((np.array(traveltime) * 1e3).astype(int), dtype="timedelta64[ms]")[
:, np.newaxis, np.newaxis, np.newaxis
]
begin_time = np.array(meta["info1"]["begin_time"])[:, np.newaxis, np.newaxis, np.newaxis]
phase_time = begin_time + shift_time + traveltime
origin_time = begin_time + shift_time

meta["origin_time"] = origin_time
meta["phase_time"] = phase_time
meta["max_cc"] = topk_score

return meta
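
A minimal usage sketch of DetectPeaksTM (not from the commit; the shapes, keys, and values are assumptions inferred from forward() above): xcorr is (nb, 1, 1, nt), info1 carries one begin_time per batch element, and info2 carries the template travel times in seconds.

import numpy as np
import torch
from cctorch.transforms import DetectPeaksTM  # the class added in this commit

# Hypothetical batch: 2 picks, 60 s of 100 Hz cross-correlation traces.
nb, nt = 2, 6000
meta = {
    "xcorr": torch.randn(nb, 1, 1, nt),
    "nlag": 0,
    "info1": {"begin_time": np.array(["2024-08-26T00:00:00", "2024-08-26T01:00:00"], dtype="datetime64[ms]")},
    "info2": {"traveltime": np.array([[3.4], [2.1]])},  # seconds, one value per pick
}

detector = DetectPeaksTM(vmin=0.6, kernel=301, stride=1, topk=5, sampling_rate=100.0)
meta = detector(meta)
print(meta["max_cc"].shape, meta["phase_time"].shape)  # both (2, 1, 1, 5)

Note that an odd kernel (e.g. 301, as run.py uses) keeps the max-pooled trace the same length as xcorr, so the smax == xcorr comparison lines up sample by sample.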


## Template Matching
class DetectTM(torch.nn.Module):
def __init__(self, ratio=10, maxpool_kernel=101, median_kernel=6000, K=100, sampling_rate=100.0):
152 changes: 106 additions & 46 deletions run.py
@@ -11,13 +11,13 @@
import torch
import torch.distributed as dist
import torchvision.transforms as T
from torch.utils.data import DataLoader
from tqdm import tqdm

import utils
from cctorch import CCDataset, CCIterableDataset, CCModel
from cctorch.transforms import *
from cctorch.utils import write_cc_pairs, write_tm_detects
from sklearn.cluster import DBSCAN
from torch.utils.data import DataLoader
from tqdm import tqdm


def get_args_parser(add_help=True):
@@ -83,9 +83,9 @@ def get_args_parser(add_help=True):
parser.add_argument("--temporal_gradient", action="store_true", help="use temporal gradient")

# cross-correlation parameters
parser.add_argument("--picks_csv", default="local/cctorch/cctorch_picks.csv", type=str, help="picks file")
parser.add_argument("--events_csv", default="local/cctorch/cctorch_events.csv", type=str, help="events file")
parser.add_argument("--stations_csv", default="local/cctorch/cctorch_stations.csv", type=str, help="stations file")
parser.add_argument("--picks_csv", default="cctorch_picks.csv", type=str, help="picks file")
parser.add_argument("--events_csv", default="cctorch_events.csv", type=str, help="events file")
parser.add_argument("--stations_csv", default="cctorch_stations.csv", type=str, help="stations file")
parser.add_argument("--taper", action="store_true", help="taper two data window")
parser.add_argument("--interp", action="store_true", help="interpolate the data window along time axs")
parser.add_argument("--scale_factor", default=10, type=int, help="interpolation scale up factor")
@@ -253,9 +253,11 @@ def __init__(self, config):
postprocess = []
if args.mode == "CC":
## TODO: add postprocess for cross-correlation
postprocess.append(DetectPeaks(kernel=3, stride=1, topk=2))
postprocess.append(DetectPeaksCC(kernel=3, stride=1, topk=2))
elif args.mode == "TM":
postprocess.append(DetectPeaks(vmin=0.6, kernel=301, stride=1, K=3600 // 5)) # assume 100Hz and 1 hour file
postprocess.append(
DetectPeaksTM(vmin=0.6, kernel=301, stride=1, topk=3600 // 5)
) # assume 100Hz and 1 hour file
elif args.mode == "AN":
## TODO: add postprocess for ambient noise
pass
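
A quick reading of the TM postprocess settings above (inferred; the commit only states the inline comment): at 100 Hz sampling, kernel=301 samples is roughly a 3 s non-maximum-suppression window, and topk = 3600 // 5 = 720 keeps at most one candidate peak per 5 s of a 1-hour file.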
@@ -333,31 +335,54 @@ def __init__(self, config):
result_df = []
for i, data in enumerate(tqdm(dataloader, position=rank, desc=f"{rank}/{world_size}: computing")):

idx_eve1 = data[0]["info"]["idx_eve"]
idx_eve2 = data[1]["info"]["idx_eve"]
idx_sta = data[0]["info"]["idx_sta"]
phase_type = data[0]["info"]["phase_type"]
if args.mode == "CC":
idx_eve1 = data[0]["info"]["idx_eve"]
idx_eve2 = data[1]["info"]["idx_eve"]
if args.mode == "TM":
idx_mseed = data[0]["index"]
idx_eve = data[1]["info"]["idx_eve"]
idx_sta = data[1]["info"]["idx_sta"]
phase_type = data[1]["info"]["phase_type"]

result = ccmodel(data)

cc_max = result["cc_max"]
cc_weight = result["cc_weight"]
cc_shift = result["cc_shift"]
cc_dt = result["cc_dt"]

for ii in range(len(idx_eve1)):
result_df.append(
{
"idx_eve1": idx_eve1[ii],
"idx_eve2": idx_eve2[ii],
"idx_sta": idx_sta[ii],
"phase_type": phase_type[ii],
"dt": cc_dt[ii].squeeze().item(),
"shift": cc_shift[ii].squeeze().item(),
"cc": cc_max[ii].squeeze().item(),
"weight": cc_weight[ii].squeeze().item(),
}
)
if args.mode == "CC":
cc_max = result["cc_max"]
cc_weight = result["cc_weight"]
cc_shift = result["cc_shift"]
cc_dt = result["cc_dt"]
for ii in range(len(idx_sta)):
result_df.append(
{
"idx_eve1": idx_eve1[ii],
"idx_eve2": idx_eve2[ii],
"idx_sta": idx_sta[ii],
"phase_type": phase_type[ii],
"dt": cc_dt[ii].squeeze().item(),
"shift": cc_shift[ii].squeeze().item(),
"cc": cc_max[ii].squeeze().item(),
"weight": cc_weight[ii].squeeze().item(),
}
)

if args.mode == "TM":
origin_time = result["origin_time"][:, 0, 0, :]
phase_time = result["phase_time"][:, 0, 0, :]
max_cc = result["max_cc"][:, 0, 0, :]
for ii in range(len(idx_sta)):
for jj in range(len(origin_time[ii])):
if max_cc[ii][jj].item() > ccconfig.min_cc:
result_df.append(
{
"idx_mseed": idx_mseed[ii],
"idx_eve": idx_eve[ii],
"idx_sta": idx_sta[ii],
"phase_type": phase_type[ii],
"phase_time": phase_time[ii][jj].item(),
"origin_time": origin_time[ii][jj].item(),
"cc": max_cc[ii][jj].item(),
}
)

if ccconfig.mode == "CC":

os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}_origin.csv"), index=False
)

##### More accurate by merging all results
# if world_size > 1:
# dist.barrier()

# if rank == 0:
# result_df = []
# for i in tqdm(range(world_size), desc="Merging"):
# if os.path.exists(
# os.path.join(args.result_path, f"{ccconfig.mode}_{i:03d}_{world_size:03d}_origin.csv")
# ):
# result_df.append(
# pd.read_csv(
# os.path.join(args.result_path, f"{ccconfig.mode}_{i:03d}_{world_size:03d}_origin.csv")
# )
# )
# result_df = pd.concat(result_df)
##### More accurate by merging all results
# if world_size > 1:
# dist.barrier()

# if rank == 0:
# result_df = []
# for i in tqdm(range(world_size), desc="Merging"):
# if os.path.exists(
# os.path.join(args.result_path, f"{ccconfig.mode}_{i:03d}_{world_size:03d}_origin.csv")
# ):
# result_df.append(
# pd.read_csv(
# os.path.join(args.result_path, f"{ccconfig.mode}_{i:03d}_{world_size:03d}_origin.csv")
# )
# )
# result_df = pd.concat(result_df)

### Efficient but less accurate when event pairs split into different files
# %% filter based on cc values
@@ -439,6 +464,41 @@ def __init__(self, config):
f"cat {args.result_path}/CC_*_{world_size:03d}_dt.cc > {args.result_path}/CC_{world_size:03d}_dt.cc"
)

if ccconfig.mode == "TM":

if len(result_df) > 0:
result_df = pd.DataFrame(result_df)
result_df.to_csv(
os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}_origin.csv"), index=False
)

t0 = result_df["origin_time"].min()
result_df["timestamp"] = result_df["origin_time"].apply(lambda x: (x - t0).total_seconds())
clustering = DBSCAN(eps=2, min_samples=3).fit(result_df[["timestamp"]].values)
result_df["event_index"] = clustering.labels_
result_df["event_time"] = result_df.groupby("event_index")["timestamp"].transform("median")
result_df["event_time"] = result_df["event_time"].apply(lambda x: t0 + pd.Timedelta(seconds=x))
result_df.sort_values(by="event_time", inplace=True)
result_df.to_csv(
os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}.csv"), index=False
)
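
For reference, a self-contained sketch of the clustering step above (same DBSCAN parameters as the commit; the detection times are invented): detections whose origin times fall within eps=2 seconds of one another, with at least min_samples=3 picks, are grouped into one candidate event whose time is the cluster median.

import pandas as pd
from sklearn.cluster import DBSCAN

# Hypothetical per-pick detections (origin times from template matching).
df = pd.DataFrame({"origin_time": pd.to_datetime([
    "2024-08-26T00:00:01.10", "2024-08-26T00:00:01.30", "2024-08-26T00:00:01.45",
    "2024-08-26T00:10:07.00", "2024-08-26T00:10:07.20", "2024-08-26T00:10:08.90",
])})

t0 = df["origin_time"].min()
df["timestamp"] = df["origin_time"].apply(lambda x: (x - t0).total_seconds())
clustering = DBSCAN(eps=2, min_samples=3).fit(df[["timestamp"]].values)
df["event_index"] = clustering.labels_  # -1 would mark unclustered detections
df["event_time"] = df.groupby("event_index")["timestamp"].transform("median")
df["event_time"] = df["event_time"].apply(lambda x: t0 + pd.Timedelta(seconds=x))
print(df[["event_index", "event_time"]].drop_duplicates())  # two candidate events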

if world_size > 1:
dist.barrier()

if rank == 0:
result_df = []
for i in tqdm(range(world_size), desc="Merging"):
if os.path.exists(os.path.join(args.result_path, f"{ccconfig.mode}_{i:03d}_{world_size:03d}.csv")):
result_df.append(
pd.read_csv(os.path.join(args.result_path, f"{ccconfig.mode}_{i:03d}_{world_size:03d}.csv"))
)
result_df = pd.concat(result_df)
result_df.sort_values(by="event_time", inplace=True)
result_df.to_csv(os.path.join(args.result_path, f"{ccconfig.mode}_{world_size:03d}.csv"), index=False)
result_df = result_df[["event_index", "event_time"]].drop_duplicates()
result_df.to_csv(os.path.join(args.result_path, f"{ccconfig.mode}_{world_size:03d}_event.csv"), index=False)

# MAX_THREADS = 32
# with h5py.File(os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}.h5"), "w") as fp:
# with ThreadPoolExecutor(max_workers=16) as executor:
