Add LimitedThreadPoolExecutor class and lock to write_cc_pairs function
zhuwq0 committed Jan 5, 2024
1 parent 4012439 commit a7bf5ec
Showing 4 changed files with 91 additions and 214 deletions.
187 changes: 57 additions & 130 deletions cctorch/data.py
@@ -10,6 +10,7 @@
import torch
import torch.nn.functional as F
import torchaudio
from scipy.sparse import coo_matrix
from torch.utils.data import Dataset, IterableDataset
from tqdm import tqdm

@@ -144,8 +145,8 @@ def __init__(
self.device = device
self.dtype = dtype
self.num_batch = None
self.pair_list, self.data_list1, self.data_list2 = self.init_pairs(pair_list, data_list1, data_list2, config)

self.pair_matrix, self.row_matrix, self.col_matrix, unique_row, unique_col = self.read_pairs(pair_list)
if self.mode == "CC":
self.symmetric = True
self.data_format2 = self.data_format1
@@ -155,27 +156,16 @@ def __init__(
## For ambient noise, we split chunks in the sampling function
self.data_list1 = self.data_list1[rank::world_size]
else:
block_num1 = int(np.ceil(len(self.data_list1) / block_size1))
block_num2 = int(np.ceil(len(self.data_list2) / block_size2))
self.group1 = [list(x) for x in np.array_split(self.data_list1, block_num1) if len(x) > 0]
self.group2 = [list(x) for x in np.array_split(self.data_list2, block_num2) if len(x) > 0]
# self.block_index = generate_block_index(
# self.group1,
# self.group2,
# auto_xcorr=config.auto_xcorr,
# pair_list=self.pair_list,
# symmetric=self.symmetric,
# min_sample_per_block=1,
# )[rank::world_size]
# self.block_index = list(itertools.product(range(len(self.group1)), range(len(self.group2))))[
# rank::world_size
# ]
blocks = list(itertools.product(range(len(self.group1)), range(len(self.group2))))
self.block_index = self.filt_empty_block(blocks)[rank::world_size]

print(f"Pairs: {len(self.pair_list)}, Blocks: {len(self.group1)} x {len(self.group2)}")
block_num1 = int(np.ceil(len(unique_row) / block_size1))
block_num2 = int(np.ceil(len(unique_col) / block_size2))
self.group1 = [list(x) for x in np.array_split(unique_row, block_num1) if len(x) > 0]
self.group2 = [list(x) for x in np.array_split(unique_col, block_num2) if len(x) > 0]

blocks = list(itertools.product(range(len(self.group1)), range(len(self.group2))))[rank::world_size]
self.block_index, self.num_batch = self.count_blocks(blocks)

print(
f"data_list1: {len(self.data_list1)}, data_list2: {len(self.data_list2)}, block_size1: {block_size1}, block_size2: {block_size2}"
f"pair_matrix: {self.pair_matrix.shape}, blocks: {len(self.block_index)}, block_size: {self.block_size1} x {self.block_size2}"
)

if (self.data_format1 == "memmap") or (self.data_format2 == "memmap"):
@@ -195,6 +185,47 @@ def __init__(
config.station_index_file, header=None, names=["index", "station_id", "component"], index_col=0
)

def read_pairs(self, pair_list):
"""
Assume pair_list is a path to a CSV file of (event1, event2) index pairs
"""
pair_list = np.loadtxt(pair_list, delimiter=",", dtype=np.int64)
# For TEST
# pair_list = np.array(list(itertools.product(range(6000), range(6000))))
# pair_list = pair_list[pair_list[:, 0] < pair_list[:, 1]]
# pair_list = pair_list[pair_list[:, 1] - pair_list[:, 0] < 10]
unique_row = np.sort(np.unique(pair_list[:, 0]))
unique_col = np.sort(np.unique(pair_list[:, 1]))
print(f"Number of pairs: {len(pair_list)}, list1: {len(unique_row)}, list2: {len(unique_col)}")

rows, cols = pair_list[:, 0], pair_list[:, 1]
data = [True] * len(pair_list)
shape = (max(rows) + 1, max(cols) + 1)
pair_matrix = coo_matrix((data, (rows, cols)), shape=shape, dtype=bool)
pair_matrix = pair_matrix.tocsr()

row_index = coo_matrix((rows, (rows, cols)), shape=shape, dtype=int)
row_index = row_index.tocsr()
col_index = coo_matrix((cols, (rows, cols)), shape=shape, dtype=int)
col_index = col_index.tocsr()

return pair_matrix, row_index, col_index, unique_row, unique_col
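
For reference, a minimal standalone sketch of the two sparse-matrix tricks used above (toy indices, not part of the commit): the boolean pair_matrix turns "how many pairs fall in this block?" into a slice-and-sum, while row_index/col_index store each pair's own coordinates as values, so the global ids can be read back out of a sliced block.

    import numpy as np
    from scipy.sparse import coo_matrix

    pairs = np.array([[1, 2], [1, 4], [3, 4], [3, 5]])  # toy (event1, event2) pairs
    rows, cols = pairs[:, 0], pairs[:, 1]
    shape = (rows.max() + 1, cols.max() + 1)

    # Boolean CSR matrix: entry (i, j) is True iff pair (i, j) exists.
    pair_matrix = coo_matrix((np.ones(len(pairs), dtype=bool), (rows, cols)), shape=shape).tocsr()

    # Matrices storing each pair's own row (resp. column) id as the value.
    # Caveat: an id of 0 would be stored as an explicit zero, which some
    # scipy operations may drop, hence ids >= 1 in this toy example.
    row_index = coo_matrix((rows, (rows, cols)), shape=shape).tocsr()
    col_index = coo_matrix((cols, (rows, cols)), shape=shape).tocsr()

    # Counting the pairs inside a block is a fancy-index slice plus a sum ...
    block_rows, block_cols = [1, 3], [4, 5]
    print(pair_matrix[block_rows, :][:, block_cols].sum())  # 3

    # ... and the global pair ids come back from the sliced index matrices,
    # which is what sample() does below via .tocoo().data.
    rm = row_index[block_rows, :][:, block_cols].tocoo()
    cm = col_index[block_rows, :][:, block_cols].tocoo()
    print(list(zip(rm.data, cm.data)))  # -> pairs (1, 4), (3, 4), (3, 5)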

def count_blocks(self, blocks):
num_batch = 0
non_empty = []
for i, j in tqdm(blocks, desc="Counting batch"):
index1, index2 = self.group1[i], self.group2[j]
count = (self.pair_matrix[index1, :][:, index2]).sum()
if count > 0:
non_empty.append((i, j))
num_batch += (count - 1) // self.batch_size + 1

return non_empty, num_batch
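
Here `(count - 1) // self.batch_size + 1` is the integer ceiling-division idiom, i.e. ceil(count / batch_size) for count >= 1 (e.g. 7 pairs with batch_size 4 gives (7 - 1) // 4 + 1 = 2 batches). Precomputing num_batch once also makes the new __len__ a constant-time lookup instead of the per-worker recount performed by the count_sample method removed below.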

def __len__(self):
return self.num_batch

def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
@@ -209,33 +240,17 @@ def __iter__(self):
else:
return iter(self.sample(self.block_index[worker_id::num_workers]))
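
A minimal sketch of this IterableDataset sharding pattern (toy class and names, not the commit's code): every DataLoader worker re-runs __iter__ and takes a disjoint stride of the precomputed block list, so no block is yielded twice.

    import torch
    from torch.utils.data import IterableDataset

    class BlockDataset(IterableDataset):
        def __init__(self, num_blocks):
            self.block_index = list(range(num_blocks))

        def __iter__(self):
            info = torch.utils.data.get_worker_info()
            if info is None:
                # Single-process loading: yield every block.
                return iter(self.block_index)
            # With num_workers=N, worker k yields blocks k, k+N, k+2N, ...
            return iter(self.block_index[info.id :: info.num_workers])

    # DataLoader(BlockDataset(10), num_workers=2, batch_size=None) yields each
    # of the 10 blocks exactly once, split across the two workers.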

def init_pairs(self, pair_list, data_list1, data_list2, config):
if data_list1 is not None:
data_list1 = pd.unique(pd.read_csv(data_list1, header=None)[0]).tolist()

if data_list2 is not None:
data_list2 = pd.unique(pd.read_csv(data_list2, header=None)[0]).tolist()
else:
data_list2 = data_list1

if pair_list is not None:
pair_list, data_list1, data_list2 = read_pair_list(pair_list)

return pair_list, data_list1, data_list2

def sample(self, block_index):
for i, j in block_index:
local_dict = {}
event1, event2 = self.group1[i], self.group2[j]
pairs = generate_pairs(event1, event2, self.config.auto_xcorr, self.symmetric)
row_index, col_index = self.group1[i], self.group2[j]
row_matrix = self.row_matrix[row_index, :][:, col_index].tocoo()
col_matrix = self.col_matrix[row_index, :][:, col_index].tocoo()

data1, index1, info1, data2, index2, info2 = [], [], [], [], [], []
num = 0

for ii, jj in pairs:
if self.pair_list is not None:
if (ii, jj) not in self.pair_list:
continue

for ii, jj in zip(row_matrix.data, col_matrix.data):
if ii not in local_dict:
if self.data_format1 == "memmap":
meta1 = {
@@ -400,77 +415,6 @@ def sample_ambient_noise(self, data_list):
"info": {},
}, {"data": data_j, "index": [index_j], "info": {}}

def __len__(self):
if self.num_batch is None:
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
num_workers = 1
worker_id = 0
else:
num_workers = worker_info.num_workers
worker_id = worker_info.id

if self.mode == "AN":
num_samples = self.count_sample_ambient_noise(num_workers, worker_id)
else:
num_samples = self.count_sample(num_workers, worker_id)
self.num_batch = num_samples

return self.num_batch

def count_sample(self, num_workers, worker_id):
# if self.pair_list is not None:
# num_samples = (
# len(self.pair_list) // min(int((self.block_size1 - 1) * (self.block_size2 - 1) / 2), self.batch_size)
# + 1
# )
# else:
# if self.symmetric:
# num_samples = (
# len(self.data_list1)
# * (len(self.data_list1) - 1)
# / 2
# // min(self.batch_size, int((self.block_size1 - 1) * (self.block_size2 - 1) / 2))
# + 1
# )
# else:
# num_samples = (
# len(self.data_list1)
# * len(self.data_list2)
# // min(self.batch_size, int((self.block_size1 - 1) * (self.block_size2 - 1) / 2))
# + 1
# )

if self.mode == "CC":
num_samples = 0
for i, j in tqdm(self.block_index[worker_id::num_workers], desc="Counting batches"):
event1, event2 = self.group1[i], self.group2[j]
num = 0
for x, y in itertools.product(event1, event2):
if (x < y) and ((x, y) in self.pair_list):
num += 1
num_samples += (num - 1) // self.batch_size + 1
else:
num_samples = 0
for i, j in tqdm(self.block_index[worker_id::num_workers], desc="Counting batches"):
event1, event2 = self.group1[i], self.group2[j]
num_samples += (len(event1) * len(event2) - 1) // self.batch_size + 1

return num_samples

def filt_empty_block(self, blocks):
non_empty_blocks = []
if self.mode == "CC":
for i, j in tqdm(blocks, desc="Filtering empty blocks"):
event1, event2 = self.group1[i], self.group2[j]
for x, y in itertools.product(event1, event2):
if (x < y) and ((x, y) in self.pair_list):
non_empty_blocks.append((i, j))
break
else:
non_empty_blocks = blocks
return non_empty_blocks

def count_sample_ambient_noise(self, num_workers, worker_id):
num_samples = 0
for fd in self.data_list1:
@@ -540,23 +484,6 @@ def generate_pairs(event1, event2, auto_xcorr=False, symmetric=False):
return pairs


def read_pair_list(file_pair_list):
# read pair ids from a text file
# pairs_df = pd.read_csv(file_pair_list, header=None, names=["event1", "event2"])
# # pair_list = {(x["event1"], x["event2"]) for _, x in pairs_df.iterrows()}
# pair_list = pairs_df[["event1", "event2"]].values.tolist()
# data_list1 = sorted(list(set(pairs_df["event1"].tolist())))
# data_list2 = sorted(list(set(pairs_df["event2"].tolist())))

pair_list = np.loadtxt(file_pair_list, delimiter=",", dtype=np.int64)
# pair_list = pair_list[:1_000_000]
data_list1 = np.unique(pair_list[:, 0]).tolist()
data_list2 = np.unique(pair_list[:, 1]).tolist()
pair_list = pair_list.tolist()
pair_list = set(map(tuple, pair_list))
return pair_list, data_list1, data_list2
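
Worth noting: this deleted helper materialized the pairs as a Python set of tuples, and the deleted count_sample/filt_empty_block above probed that set once per candidate (x, y) pair. The sparse pair_matrix introduced in read_pairs replaces those per-pair membership tests with vectorized slice-and-sum lookups, which accounts for most of the 130 lines deleted from data.py.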


def generate_block_index(group1, group2, pair_list=None, auto_xcorr=False, symmetric=False, min_sample_per_block=1):
block_index = [(i, j) for i in range(len(group1)) for j in range(len(group2))]
num_empty_index = []
21 changes: 13 additions & 8 deletions cctorch/utils.py
@@ -4,6 +4,7 @@
import math
import multiprocessing as mp
import os
from contextlib import nullcontext
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from multiprocessing import shared_memory
@@ -48,7 +49,7 @@ def write_tm_events(results, result_path, ccconfig, rank=0, world_size=1):
events.to_csv(result_path / f"cctorch_events_{rank:03d}_{world_size:03d}.csv", index=False)


def write_cc_pairs(results, fp, ccconfig, plot_figure=False):
def write_cc_pairs(results, fp, ccconfig, lock=nullcontext(), plot_figure=False):
"""
Write cross-correlation results to disk.
Parameters
@@ -86,16 +87,18 @@ def write_cc_pairs(results, fp, ccconfig, plot_figure=False):
topk_index = -topk_index

if f"{id1}/{id2}" not in fp:
gp = fp.create_group(f"{id1}/{id2}")
with lock:
gp = fp.create_group(f"{id1}/{id2}")
else:
gp = fp[f"{id1}/{id2}"]

gp.create_dataset(f"cc_index", data=topk_index[i])
gp.create_dataset(f"cc_score", data=topk_score[i])
gp.create_dataset(f"cc_diff", data=cc_diff[i])
gp.create_dataset(f"neighbor_score", data=neighbor_score[i])
if cc_sum is not None:
gp.create_dataset(f"cc_sum", data=cc_sum[i])
with lock:
gp.create_dataset(f"cc_index", data=topk_index[i])
gp.create_dataset(f"cc_score", data=topk_score[i])
gp.create_dataset(f"cc_diff", data=cc_diff[i])
gp.create_dataset(f"neighbor_score", data=neighbor_score[i])
if cc_sum is not None:
gp.create_dataset(f"cc_sum", data=cc_sum[i])

# if id2 != id1:
# fp[f"{id2}/{id1}"] = h5py.SoftLink(f"/{id1}/{id2}")
@@ -124,6 +127,8 @@ def write_cc_pairs(results, fp, ccconfig, plot_figure=False):
print(f"debug/test_{pair_id[0]}_{pair_id[1]}_{j}.png")
plt.close(fig)

return 0

# with h5py.File(result_path / f"{ccconfig.mode}_{rank:03d}_{world_size:03d}.h5", "a") as fp:
# for meta in results:
# topk_index = meta["topk_index"]
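
The LimitedThreadPoolExecutor named in the commit title is defined in one of the changed files not shown on this page, so its implementation is not visible here. A common pattern behind that name, sketched under that assumption, is a ThreadPoolExecutor whose submit() blocks once a bounded number of tasks are in flight, so producers cannot outrun the writer; the lock argument added to write_cc_pairs (defaulting to the no-op nullcontext()) then lets such concurrent callers serialize their HDF5 writes.

    import threading
    from concurrent.futures import ThreadPoolExecutor

    class LimitedThreadPoolExecutor(ThreadPoolExecutor):
        """ThreadPoolExecutor whose submit() blocks once max_pending tasks are in flight."""

        def __init__(self, max_workers=None, max_pending=8, **kwargs):
            super().__init__(max_workers=max_workers, **kwargs)
            self._slots = threading.Semaphore(max_pending)

        def submit(self, fn, *args, **kwargs):
            self._slots.acquire()  # block until a slot frees up
            try:
                future = super().submit(fn, *args, **kwargs)
            except BaseException:
                self._slots.release()
                raise
            future.add_done_callback(lambda _: self._slots.release())
            return future

    # Hypothetical usage mirroring the new write_cc_pairs signature:
    # lock = threading.Lock()
    # with LimitedThreadPoolExecutor(max_workers=4) as pool:
    #     pool.submit(write_cc_pairs, results, fp, ccconfig, lock=lock)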
(The remaining two changed files were not loaded on this page.)
