diff --git a/cctorch/data.py b/cctorch/data.py index 7b67c0a..82afc01 100644 --- a/cctorch/data.py +++ b/cctorch/data.py @@ -170,9 +170,7 @@ def __init__( self.block_index = list(itertools.product(range(len(self.group1)), range(len(self.group2))))[ rank::world_size ] - print( - f"Pairs: {len(self.pair_list)}, Blocks: {len(self.group1)} x {len(self.group2)} = {len(self.block_index)}" - ) + print(f"Pairs: {len(self.pair_list)}, Blocks: {len(self.group1)} x {len(self.group2)}") print( f"data_list1: {len(self.data_list1)}, data_list2: {len(self.data_list2)}, block_size1: {block_size1}, block_size2: {block_size2}" ) @@ -533,7 +531,7 @@ def read_pair_list(file_pair_list): # data_list2 = sorted(list(set(pairs_df["event2"].tolist()))) pair_list = np.loadtxt(file_pair_list, delimiter=",", dtype=np.int64) - # pair_list = pair_list[:10_000_000] + # pair_list = pair_list[:1_000_000] data_list1 = np.unique(pair_list[:, 0]).tolist() data_list2 = np.unique(pair_list[:, 1]).tolist() pair_list = pair_list.tolist() diff --git a/run.py b/run.py index 2cdd4e7..8cd82b8 100644 --- a/run.py +++ b/run.py @@ -332,7 +332,7 @@ def __init__(self, config): # metric_logger = utils.MetricLogger(delimiter=" ") # log_freq = max(1, 10240 // args.batch_size) if args.mode == "CC" else 1 # for data in metric_logger.log_every(dataloader, log_freq, ""): - for data in tqdm(dataloader): + for data in tqdm(dataloader, position=rank, desc=f"CC {rank}/{world_size}"): result = ccmodel(data) thread = threading.Thread(