Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuwq0 committed Jan 28, 2024
1 parent d4dec64 commit 8441109
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 40 deletions.
78 changes: 46 additions & 32 deletions cctorch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,47 +266,61 @@ def sample(self, block_index):
num = 0

for ii, jj in zip(row_matrix.data, col_matrix.data):
if ii not in local_dict:
if self.data_format1 == "memmap":
if self.data_format1 == "memmap":
if ii not in local_dict:
meta1 = {
"data": self.templates[ii],
"index": ii,
"info": {"shift_index": self.traveltime_index[ii]},
}
data = torch.tensor(meta1["data"], dtype=self.dtype).to(self.device)
if self.transforms is not None:
data = self.transforms(data)
meta1["data"] = data
local_dict[ii] = meta1
else:
meta1 = local_dict[ii]
else:
if self.data_list1[ii] not in local_dict:
meta1 = read_data(
self.data_list1[ii], self.data_path1, self.data_format1, mode=self.mode, config=self.config
)
meta1["index"] = ii
data = torch.tensor(meta1["data"], dtype=self.dtype).to(self.device)
if self.transforms is not None:
data = self.transforms(data)

meta1["data"] = data
local_dict[ii if self.data_format1 == "memmap" else self.data_list1[ii]] = meta1
else:
meta1 = local_dict[ii if self.data_format1 == "memmap" else self.data_list1[ii]]
data = torch.tensor(meta1["data"], dtype=self.dtype).to(self.device)
if self.transforms is not None:
data = self.transforms(data)
meta1["data"] = data
local_dict[self.data_list1[ii]] = meta1
else:
meta1 = local_dict[self.data_list1[ii]]

if jj not in local_dict:
if self.data_format2 == "memmap":
if self.data_format2 == "memmap":
if jj not in local_dict:
meta2 = {
"data": self.templates[jj],
"index": jj,
"info": {"shift_index": self.traveltime_index[jj]},
}
data = torch.tensor(meta2["data"], dtype=self.dtype).to(self.device)
if self.transforms is not None:
data = self.transforms(data)
meta2["data"] = data
local_dict[jj] = meta2
else:
meta2 = local_dict[jj]
else:
if self.data_list2[jj] not in local_dict:
meta2 = read_data(
self.data_list2[jj], self.data_path2, self.data_format2, mode=self.mode, config=self.config
)
meta2["index"] = jj
data = torch.tensor(meta2["data"], dtype=self.dtype).to(self.device)
if self.transforms is not None:
data = self.transforms(data)

meta2["data"] = data
local_dict[jj if self.data_format2 == "memmap" else self.data_list2[jj]] = meta2
else:
meta2 = local_dict[jj if self.data_format2 == "memmap" else self.data_list2[jj]]
data = torch.tensor(meta2["data"], dtype=self.dtype).to(self.device)
if self.transforms is not None:
data = self.transforms(data)
meta2["data"] = data
local_dict[self.data_list2[jj]] = meta2
else:
meta2 = local_dict[self.data_list2[jj]]

data1.append(meta1["data"])
index1.append(meta1["index"])
Expand All @@ -319,12 +333,12 @@ def sample(self, block_index):
if num == self.batch_size:
data_batch1 = torch.stack(data1)
data_batch2 = torch.stack(data2)
if (
(self.mode == "TM")
and (data_batch2.shape[1] != data_batch1.shape[1])
and (data_batch2.shape[1] % data_batch1.shape[1] == 0)
):
data_batch1 = data_batch1.repeat(1, data_batch2.shape[1] // data_batch1.shape[1], 1, 1)
# if (
# (self.mode == "TM")
# and (data_batch2.shape[1] != data_batch1.shape[1])
# and (data_batch2.shape[1] % data_batch1.shape[1] == 0)
# ):
# data_batch1 = data_batch1.repeat(1, data_batch2.shape[1] // data_batch1.shape[1], 1, 1)

info_batch1 = {k: [x[k] for x in info1] for k in info1[0].keys()}
info_batch2 = {k: [x[k] for x in info2] for k in info2[0].keys()}
Expand All @@ -349,12 +363,12 @@ def sample(self, block_index):
if num > 0:
data_batch1 = torch.stack(data1)
data_batch2 = torch.stack(data2)
if (
(self.mode == "TM")
and (data_batch2.shape[1] != data_batch1.shape[1])
and (data_batch2.shape[1] % data_batch1.shape[1] == 0)
):
data_batch1 = data_batch1.repeat(1, data_batch2.shape[1] // data_batch1.shape[1], 1, 1)
# if (
# (self.mode == "TM")
# and (data_batch2.shape[1] != data_batch1.shape[1])
# and (data_batch2.shape[1] % data_batch1.shape[1] == 0)
# ):
# data_batch1 = data_batch1.repeat(1, data_batch2.shape[1] // data_batch1.shape[1], 1, 1)
info_batch1 = {k: [x[k] for x in info1] for k in info1[0].keys()}
info_batch2 = {k: [x[k] for x in info2] for k in info2[0].keys()}
if "shift_index" in info_batch1:
Expand Down
7 changes: 4 additions & 3 deletions cctorch/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,10 @@ def forward(self, meta):
keep = (smax == xcorr).float()
topk_scores, topk_inds = torch.topk(xcorr * keep, self.K, sorted=True)

top_inds = topk_inds[..., 0]
nearby = torch.stack([top_inds - 1, top_inds, top_inds + 1], dim=-1).clamp(0, xcorr.shape[-1] - 1)
meta["neighbor_score"] = torch.gather(xcorr, -1, nearby)
if self.K == 3: # CC pairs
top_inds = topk_inds[..., 0]
nearby = torch.stack([top_inds - 1, top_inds, top_inds + 1], dim=-1).clamp(0, xcorr.shape[-1] - 1)
meta["neighbor_score"] = torch.gather(xcorr, -1, nearby)

meta["topk_score"] = topk_scores
meta["topk_index"] = topk_inds - nlag
Expand Down
8 changes: 3 additions & 5 deletions cctorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,13 @@ def write_tm_detects(results, fp, ccconfig, lock=nullcontext(), plot_figure=Fals
for meta in results:
topk_index = meta["topk_index"].numpy()
topk_score = meta["topk_score"].numpy()
neighbor_score = meta["neighbor_score"].numpy()
pair_index = meta["pair_index"]

nb, nch, nx, nk = topk_index.shape

for i in range(nb):
if topk_score[i].max() < ccconfig.min_cc_score:
continue
select_index = np.where(topk_score[i] >= ccconfig.min_cc_score)

pair_id = pair_index[i]
id1, id2 = pair_id
Expand All @@ -86,9 +84,9 @@ def write_tm_detects(results, fp, ccconfig, lock=nullcontext(), plot_figure=Fals
gp = fp[f"{id1}/{id2}"]

with lock:
gp.create_dataset(f"cc_index", data=topk_index[i, ..., select_index])
gp.create_dataset(f"cc_score", data=topk_score[i, ..., select_index])
gp.create_dataset(f"neighbor_score", data=neighbor_score[i, ..., select_index])
idx = np.where(topk_score[i] >= ccconfig.min_cc_score)
gp.create_dataset(f"cc_index", data=topk_index[i][idx])
gp.create_dataset(f"cc_score", data=topk_score[i][idx])

return 0

Expand Down
2 changes: 2 additions & 0 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,11 @@ def __init__(self, config):
if args.mode == "CC":
thread = executor.submit(write_cc_pairs, [result], fp, ccconfig, lock)
futures.add(thread)
# write_cc_pairs([result], fp, ccconfig, lock)
if args.mode == "TM":
thread = executor.submit(write_tm_detects, [result], fp, ccconfig, lock)
futures.add(thread)
# write_tm_detects([result], fp, ccconfig, lock)
if len(futures) >= MAX_THREADS:
done, futures = wait(futures, return_when=FIRST_COMPLETED)
executor.shutdown(wait=True)
Expand Down

0 comments on commit 8441109

Please sign in to comment.