diff --git a/cctorch/data.py b/cctorch/data.py index 6e53262..7b67c0a 100644 --- a/cctorch/data.py +++ b/cctorch/data.py @@ -131,8 +131,6 @@ def __init__( ): super(CCIterableDataset).__init__() - self.pair_list, self.data_list1, self.data_list2 = self.init_pairs(pair_list, data_list1, data_list2, config) - self.mode = config.mode self.config = config self.block_size1 = block_size1 @@ -146,10 +144,10 @@ def __init__( self.device = device self.dtype = dtype self.num_batch = None + self.pair_list, self.data_list1, self.data_list2 = self.init_pairs(pair_list, data_list1, data_list2, config) if self.mode == "CC": self.symmetric = True - self.data_list2 = self.data_list1 self.data_format2 = self.data_format1 self.data_path2 = self.data_path1 @@ -172,6 +170,12 @@ def __init__( self.block_index = list(itertools.product(range(len(self.group1)), range(len(self.group2))))[ rank::world_size ] + print( + f"Pairs: {len(self.pair_list)}, Blocks: {len(self.group1)} x {len(self.group2)} = {len(self.block_index)}" + ) + print( + f"data_list1: {len(self.data_list1)}, data_list2: {len(self.data_list2)}, block_size1: {block_size1}, block_size2: {block_size2}" + ) if (self.data_format1 == "memmap") or (self.data_format2 == "memmap"): self.templates = np.memmap( @@ -414,43 +418,43 @@ def __len__(self): return self.num_batch def count_sample(self, num_workers, worker_id): - # num_samples = 0 - # for i, j in tqdm(self.block_index[worker_id::num_workers], desc="Counting batches"): - # event1, event2 = self.group1[i], self.group2[j] - # pairs = generate_pairs(event1, event2, self.config.auto_xcorr, self.symmetric) - # if self.pair_list is None: - # num_samples += (len(pairs) - 1) // self.batch_size + 1 + # if self.pair_list is not None: + # num_samples = ( + # len(self.pair_list) // min(int((self.block_size1 - 1) * (self.block_size2 - 1) / 2), self.batch_size) + # + 1 + # ) + # else: + # if self.symmetric: + # num_samples = ( + # len(self.data_list1) + # * (len(self.data_list1) - 1) + # / 2 + # // min(self.batch_size, int((self.block_size1 - 1) * (self.block_size2 - 1) / 2)) + # + 1 + # ) # else: - # tmp = 0 - # for pair in pairs: - # if pair in self.pair_list: - # tmp += 1 - # if tmp % self.batch_size == 0: - # num_samples += 1 - # tmp = 0 - # if tmp > 0: - # num_samples += 1 - if self.pair_list is not None: - num_samples = ( - len(self.pair_list) // min(int((self.block_size1 - 1) * (self.block_size2 - 1) / 2), self.batch_size) - + 1 - ) + # num_samples = ( + # len(self.data_list1) + # * len(self.data_list2) + # // min(self.batch_size, int((self.block_size1 - 1) * (self.block_size2 - 1) / 2)) + # + 1 + # ) + + if self.mode == "CC": + num_samples = 0 + for i, j in tqdm(self.block_index[worker_id::num_workers], desc="Counting batches"): + event1, event2 = self.group1[i], self.group2[j] + num = 0 + for x, y in itertools.product(event1, event2): + if (x < y) and ((x, y) in self.pair_list): + num += 1 + num_samples += (num - 1) // self.batch_size + 1 else: - if self.symmetric: - num_samples = ( - len(self.data_list1) - * (len(self.data_list1) - 1) - / 2 - // min(self.batch_size, int((self.block_size1 - 1) * (self.block_size2 - 1) / 2)) - + 1 - ) - else: - num_samples = ( - len(self.data_list1) - * len(self.data_list2) - // min(self.batch_size, int((self.block_size1 - 1) * (self.block_size2 - 1) / 2)) - + 1 - ) + num_samples = 0 + for i, j in tqdm(self.block_index[worker_id::num_workers], desc="Counting batches"): + event1, event2 = self.group1[i], self.group2[j] + num_samples += (len(event1) * len(event2) - 1) // self.batch_size + 1 + return num_samples def count_sample_ambient_noise(self, num_workers, worker_id): @@ -529,6 +533,7 @@ def read_pair_list(file_pair_list): # data_list2 = sorted(list(set(pairs_df["event2"].tolist()))) pair_list = np.loadtxt(file_pair_list, delimiter=",", dtype=np.int64) + # pair_list = pair_list[:10_000_000] data_list1 = np.unique(pair_list[:, 0]).tolist() data_list2 = np.unique(pair_list[:, 1]).tolist() pair_list = pair_list.tolist() diff --git a/run.py b/run.py index b9872b8..2cdd4e7 100644 --- a/run.py +++ b/run.py @@ -23,6 +23,7 @@ from cctorch.transforms import * from cctorch.utils import write_cc_pairs, write_results from torch.utils.data import DataLoader +from tqdm import tqdm def get_args_parser(add_help=True): @@ -328,9 +329,10 @@ def __init__(self, config): threads = [] fp = h5py.File(os.path.join(args.result_path, f"{ccconfig.mode}_{rank:03d}_{world_size:03d}.h5"), "w") - metric_logger = utils.MetricLogger(delimiter=" ") - log_freq = max(1, 10240 // args.batch_size) if args.mode == "CC" else 1 - for data in metric_logger.log_every(dataloader, log_freq, ""): + # metric_logger = utils.MetricLogger(delimiter=" ") + # log_freq = max(1, 10240 // args.batch_size) if args.mode == "CC" else 1 + # for data in metric_logger.log_every(dataloader, log_freq, ""): + for data in tqdm(dataloader): result = ccmodel(data) thread = threading.Thread(