From 8aa96264e5a0d2c411999c4e716be6dfaa3e3cd0 Mon Sep 17 00:00:00 2001 From: zhuwq Date: Sun, 24 Sep 2023 16:32:14 -0700 Subject: [PATCH] template matching works --- cctorch/data.py | 235 +++++++++++++++++++++++++++++------------- cctorch/model.py | 60 +++++++---- cctorch/transforms.py | 157 +++++++++++++++++++--------- cctorch/utils.py | 89 ++++++++++------ run.py | 103 ++++++++++++------ tests/test_qtm.py | 6 +- 6 files changed, 442 insertions(+), 208 deletions(-) diff --git a/cctorch/data.py b/cctorch/data.py index 481c402..b305bb3 100644 --- a/cctorch/data.py +++ b/cctorch/data.py @@ -11,6 +11,7 @@ import itertools import matplotlib.pyplot as plt from tqdm import tqdm +import obspy class CCDataset(Dataset): @@ -20,8 +21,10 @@ def __init__( pair_list=None, data_list1=None, data_list2=None, - data_path="./", - data_format="h5", + data_path1="./", + data_path2="./", + data_format1="h5", + data_format2="h5", block_size1=1, block_size2=1, device="cpu", @@ -55,13 +58,14 @@ def __init__( self.group1 = [list(x) for x in np.array_split(self.data_list1, block_num1) if len(x) > 0] self.group2 = [list(x) for x in np.array_split(self.data_list2, block_num2) if len(x) > 0] self.block_index = generate_block_index(self.group1, self.group2, self.pair_list) - self.data_path = Path(data_path) - self.data_format = data_format + self.data_path1 = Path(data_path1) + self.data_path2 = Path(data_path2) + self.data_format1 = data_format1 + self.data_format2 = data_format2 self.transforms = transforms self.device = device def __getitem__(self, idx): - i, j = self.block_index[idx] event1, event2 = self.group1[i], self.group2[j] @@ -73,14 +77,14 @@ def __getitem__(self, idx): continue if event1[ii] not in index_dict: - data_dict = read_data(event1[ii], self.data_path, self.data_format) + data_dict = read_data(event1[ii], self.data_path1, self.data_format1) data.append(torch.tensor(data_dict["data"])) info.append(data_dict["info"]) index_dict[event1[ii]] = len(data) - 1 idx1 = index_dict[event1[ii]] if event2[jj] not in index_dict: - data_dict = read_data(event2[jj], self.data_path, self.data_format) + data_dict = read_data(event2[jj], self.data_path2, self.data_format2) data.append(torch.tensor(data_dict["data"])) info.append(data_dict["info"]) index_dict[event2[jj]] = len(data) - 1 @@ -111,8 +115,10 @@ def __init__( pair_list=None, data_list1=None, data_list2=None, - data_path="./", - data_format="h5", + data_path1="./", + data_path2="./", + data_format1="h5", + data_format2="h5", block_size1=1, block_size2=1, dtype=torch.float32, @@ -131,29 +137,52 @@ def __init__( self.config = config self.block_size1 = block_size1 self.block_size2 = block_size2 - self.data_path = Path(data_path) - self.data_format = data_format + self.data_path1 = Path(data_path1) + self.data_path2 = Path(data_path2) + self.data_format1 = data_format1 + self.data_format2 = data_format2 self.transforms = transforms self.batch_size = batch_size self.device = device self.dtype = dtype self.num_batch = None + self.symmetric = True if self.mode == "CC" else False if self.mode == "AN": - ## For ambient noise, we split chunks in the sampling function + ## For ambient noise, we split chunks in the sampling function self.data_list1 = self.data_list1[rank::world_size] else: block_num1 = int(np.ceil(len(self.data_list1) / block_size1)) block_num2 = int(np.ceil(len(self.data_list2) / block_size2)) self.group1 = [list(x) for x in np.array_split(self.data_list1, block_num1) if len(x) > 0] self.group2 = [list(x) for x in np.array_split(self.data_list2, block_num2) if len(x) > 0] - self.block_index = generate_block_index(self.group1, self.group2, auto_xcorr=config.auto_xcorr, pair_list=self.pair_list, min_sample_per_block=1)[rank::world_size] - - if self.data_format == "memmap": - self.ndarray = np.memmap(self.data_path, dtype=np.float32, mode="r", shape=tuple(config.template_shape)) + self.block_index = generate_block_index( + self.group1, + self.group2, + auto_xcorr=config.auto_xcorr, + pair_list=self.pair_list, + symmetric=self.symmetric, + min_sample_per_block=1, + )[rank::world_size] + + if (self.data_format1 == "memmap") or (self.data_format2 == "memmap"): + self.templates = np.memmap( + config.template_file, + dtype=np.float32, + mode="r", + shape=tuple(config.template_shape), + ) + self.traveltime_index = np.memmap( + config.traveltime_index_file, + dtype=np.int32, + mode="r", + shape=tuple(config.traveltime_shape), + ) + config.stations = pd.read_csv( + config.station_index_file, header=None, names=["index", "station_id", "component"], index_col=0 + ) def __iter__(self): - worker_info = torch.utils.data.get_worker_info() if worker_info is None: num_workers = 1 @@ -167,9 +196,7 @@ def __iter__(self): else: return iter(self.sample(self.block_index[worker_id::num_workers])) - def init_pairs(self, pair_list, data_list1, data_list2, config): - if data_list1 is not None: data_list1 = pd.unique(pd.read_csv(data_list1, header=None)[0]).tolist() @@ -183,28 +210,29 @@ def init_pairs(self, pair_list, data_list1, data_list2, config): return pair_list, data_list1, data_list2 - def sample(self, block_index): - for i, j in block_index: local_dict = {} event1, event2 = self.group1[i], self.group2[j] - pairs = generate_pairs(event1, event2, self.config.auto_xcorr) + pairs = generate_pairs(event1, event2, self.config.auto_xcorr, self.symmetric) data1, index1, info1, data2, index2, info2 = [], [], [], [], [], [] num = 0 - for (ii, jj) in pairs: - + for ii, jj in pairs: if self.pair_list is not None: - if (ii, jj) not in self.pair_list: continue if ii not in local_dict: - if self.data_format == "memmap": - meta1 = {"data": self.ndarray[ii], "index":ii, "info": {}} + if self.data_format1 == "memmap": + meta1 = { + "data": self.templates[ii], + "index": ii, + "info": {"shift_index": self.traveltime_index[ii]}, + } else: - meta1 = read_data(ii, self.data_path, self.data_format) + meta1 = read_data(ii, self.data_path1, self.data_format1, mode=self.mode, config=self.config) + meta1["index"] = ii data = torch.tensor(meta1["data"], dtype=self.dtype).to(self.device) if self.transforms is not None: data = self.transforms(data) @@ -215,10 +243,15 @@ def sample(self, block_index): meta1 = local_dict[ii] if jj not in local_dict: - if self.data_format == "memmap": - meta2 = {"data": self.ndarray[jj], "index": jj, "info": {}} + if self.data_format2 == "memmap": + meta2 = { + "data": self.templates[jj], + "index": jj, + "info": {"shift_index": self.traveltime_index[jj]}, + } else: - meta2 = read_data(jj, self.data_path, self.data_format) + meta2 = read_data(jj, self.data_path2, self.data_format2, mode=self.mode, config=self.config) + meta2["index"] = jj data = torch.tensor(meta2["data"], dtype=self.dtype).to(self.device) if self.transforms is not None: data = self.transforms(data) @@ -237,14 +270,31 @@ def sample(self, block_index): num += 1 if num == self.batch_size: - data_batch1 = torch.stack(data1) data_batch2 = torch.stack(data2) - # info_batch1 = {k: [x[k] for x in info1] for k in info1[0].keys()} - # info_batch2 = {k: [x[k] for x in info2] for k in info2[0].keys()} - # yield {"data": data_batch1, "info": info_batch1}, {"data": data_batch2, "info": info_batch2} - yield {"data": data_batch1, "index": index1, "info": info1}, {"data": data_batch2, "index": index2, "info": info2} - + if ( + (self.mode == "TM") + and (data_batch2.shape[1] != data_batch1.shape[1]) + and (data_batch2.shape[1] % data_batch1.shape[1] == 0) + ): + data_batch1 = data_batch1.repeat(1, data_batch2.shape[1] // data_batch1.shape[1], 1, 1) + + info_batch1 = {k: [x[k] for x in info1] for k in info1[0].keys()} + info_batch2 = {k: [x[k] for x in info2] for k in info2[0].keys()} + if "shift_index" in info_batch1: + info_batch1["shift_index"] = torch.tensor( + np.stack(info_batch1["shift_index"]), dtype=torch.int64 + ) + if "shift_index" in info_batch2: + info_batch2["shift_index"] = torch.tensor( + np.stack(info_batch2["shift_index"]), dtype=torch.int64 + ) + yield {"data": data_batch1, "index": index1, "info": info_batch1}, { + "data": data_batch2, + "index": index2, + "info": info_batch2, + } + num = 0 data1, index1, info1, data2, index2, info2 = [], [], [], [], [], [] @@ -252,19 +302,29 @@ def sample(self, block_index): if num > 0: data_batch1 = torch.stack(data1) data_batch2 = torch.stack(data2) - # info_batch1 = {k: [x[k] for x in info1] for k in info1[0].keys()} - # info_batch2 = {k: [x[k] for x in info2] for k in info2[0].keys()} - - # yield {"data": data_batch1, "info": info_batch1}, {"data": data_batch2, "info": info_batch2} - yield {"data": data_batch1,"index": index1, "info": info1}, {"data": data_batch2, "index": index2, "info": info2} + if ( + (self.mode == "TM") + and (data_batch2.shape[1] != data_batch1.shape[1]) + and (data_batch2.shape[1] % data_batch1.shape[1] == 0) + ): + data_batch1 = data_batch1.repeat(1, data_batch2.shape[1] // data_batch1.shape[1], 1, 1) + info_batch1 = {k: [x[k] for x in info1] for k in info1[0].keys()} + info_batch2 = {k: [x[k] for x in info2] for k in info2[0].keys()} + if "shift_index" in info_batch1: + info_batch1["shift_index"] = torch.tensor(np.stack(info_batch1["shift_index"]), dtype=torch.int64) + if "shift_index" in info_batch2: + info_batch2["shift_index"] = torch.tensor(np.stack(info_batch2["shift_index"]), dtype=torch.int64) + yield {"data": data_batch1, "index": index1, "info": info_batch1}, { + "data": data_batch2, + "index": index2, + "info": info_batch2, + } def sample_ambient_noise(self, data_list): - for fd in data_list: - - meta = read_data(fd, self.data_path, self.data_format, mode=self.mode) # (nch, nt) + meta = read_data(fd, self.data_path1, self.data_format1, mode=self.mode) # (nch, nt) data = meta["data"].float().unsqueeze(0).unsqueeze(0) # (1, 1, nx, nt) - + if (self.config.transform_on_file) and (self.transforms is not None): data = self.transforms(data) @@ -293,7 +353,7 @@ def sample_ambient_noise(self, data_list): ) else: ## using delta_channel to down-sample channels needed for ambient noise - ## using min_channel and max_channel to selected channels that are within a range + ## using min_channel and max_channel to selected channels that are within a range lists_1 = range(min_channel, max_channel, self.config.delta_channel) lists_2 = range(min_channel, max_channel, self.config.delta_channel) block_num1 = int(np.ceil(len(lists_1) / self.block_size1)) @@ -316,18 +376,19 @@ def sample_ambient_noise(self, data_list): data_i = data[:, :, index_i, :].to(self.device) data_j = data[:, :, index_j, :].to(self.device) - + if (self.config.transform_on_batch) and (self.transforms is not None): data_i = self.transforms(data_i) data_j = self.transforms(data_j) - yield {"data": data_i, "index": [index_i], "info": {},}, {"data": data_j, "index": [index_j], "info": {}} - + yield { + "data": data_i, + "index": [index_i], + "info": {}, + }, {"data": data_j, "index": [index_j], "info": {}} def __len__(self): - if self.num_batch is None: - worker_info = torch.utils.data.get_worker_info() if worker_info is None: num_workers = 1 @@ -344,15 +405,13 @@ def __len__(self): return self.num_batch - - def count_sample(self, num_workers, worker_id): num_samples = 0 for i, j in tqdm(self.block_index[worker_id::num_workers], desc="Counting batches"): event1, event2 = self.group1[i], self.group2[j] - pairs = generate_pairs(event1, event2, self.config.auto_xcorr) + pairs = generate_pairs(event1, event2, self.config.auto_xcorr, self.symmetric) if self.pair_list is None: - num_samples += (len(pairs)-1)//self.batch_size + 1 + num_samples += (len(pairs) - 1) // self.batch_size + 1 else: tmp = 0 for pair in pairs: @@ -365,13 +424,10 @@ def count_sample(self, num_workers, worker_id): num_samples += 1 return num_samples - def count_sample_ambient_noise(self, num_workers, worker_id): - num_samples = 0 for fd in self.data_list1: - - nx, nt = get_shape_das_continous_data_h5(self.data_path / fd) # (nch, nt) + nx, nt = get_shape_das_continuous_data_h5(self.data_path1 / fd) # (nch, nt) ## cut blocks min_channel = self.config.min_channel if self.config.min_channel is not None else 0 @@ -388,7 +444,7 @@ def count_sample_ambient_noise(self, num_workers, worker_id): ) else: ## using delta_channel to down-sample channels needed for ambient noise - ## using min_channel and max_channel to selected channels that are within a range + ## using min_channel and max_channel to selected channels that are within a range lists_1 = range(min_channel, max_channel, self.config.delta_channel) lists_2 = range(min_channel, max_channel, self.config.delta_channel) block_num1 = int(np.ceil(len(lists_1) / self.block_size1)) @@ -396,7 +452,7 @@ def count_sample_ambient_noise(self, num_workers, worker_id): group_1 = [list(x) for x in np.array_split(lists_1, block_num1) if len(x) > 0] group_2 = [list(x) for x in np.array_split(lists_2, block_num2) if len(x) > 0] block_index = list(itertools.product(range(len(group_1)), range(len(group_2)))) - + ## loop each block for i, j in block_index: num_samples += 1 @@ -405,7 +461,6 @@ def count_sample_ambient_noise(self, num_workers, worker_id): def generate_pairs(event1, event2, auto_xcorr=False, symmetric=True): - xcor_offset = 0 if auto_xcorr else 1 event1 = set(event1) @@ -431,7 +486,7 @@ def generate_pairs(event1, event2, auto_xcorr=False, symmetric=True): pairs += [(evt1, evt2) for evt1 in event_inner for evt2 in event_outer2 if condition(evt1, evt2)] if len(event_outer1) > 0 and len(event_outer2) > 0: pairs += [(evt1, evt2) for evt1 in event_outer1 for evt2 in event_outer2 if condition(evt1, evt2)] - + # print(f"Total number of pairs: {len(pairs)}") return pairs @@ -445,13 +500,13 @@ def read_pair_list(file_pair_list): return pair_list, data_list1, data_list2 -def generate_block_index(group1, group2, pair_list=None, auto_xcorr=False, min_sample_per_block=1): +def generate_block_index(group1, group2, pair_list=None, auto_xcorr=False, symmetric=False, min_sample_per_block=1): block_index = [(i, j) for i in range(len(group1)) for j in range(len(group2))] num_empty_index = [] for i, j in tqdm(block_index, desc="Generating blocks"): num_samples = 0 event1, event2 = group1[i], group2[j] - pairs = generate_pairs(event1, event2, auto_xcorr=auto_xcorr) + pairs = generate_pairs(event1, event2, auto_xcorr=auto_xcorr, symmetric=symmetric) if pair_list is None: num_samples = len(pairs) else: @@ -463,8 +518,7 @@ def generate_block_index(group1, group2, pair_list=None, auto_xcorr=False, min_s return num_empty_index -def read_data(file_name, data_path, format="h5", mode="CC"): - +def read_data(file_name, data_path, format="h5", mode="CC", config={}): if mode == "CC": if format == "h5": data_list, info_list = read_das_eventphase_data_h5( @@ -473,12 +527,48 @@ def read_data(file_name, data_path, format="h5", mode="CC"): ## TODO: check with Jiaxuan; why do we need to read a list but return the first one data = data_list[0] info = info_list[0] - + elif mode == "AN": - if format == "h5": + if format == "h5": data, info = read_das_continuous_data_h5(data_path / file_name, dataset_keys=[]) - return {"data": torch.tensor(data), "info": info} + elif mode == "TM": + if format == "mseed": + data, info = read_mseed(file_name, config.stations, config) + else: + raise ValueError(f"Unknown mode: {mode}") + + return {"data": data, "info": info} + + +def read_mseed(file_name, stations, config): + meta = obspy.read(file_name) + meta.merge(fill_value="latest") + for tr in meta: + if tr.stats.sampling_rate != config.sampling_rate: + tr.resample(config.sampling_rate) + begin_time = min([tr.stats.starttime for tr in meta]) + end_time = max([tr.stats.endtime for tr in meta]) + meta.detrend("constant") + meta.trim(begin_time, end_time, pad=True, fill_value=0) + nt = meta[0].stats.npts + data = np.zeros([3, len(stations), nt]) + component_mapping = {"1": 2, "2": 1, "3": 0, "E": 0, "N": 1, "Z": 2} + for i, sta in stations.iterrows(): + if len(sta["component"]) == 3: + for j, c in enumerate(sta["component"]): + st = meta.select(id=f"{sta['station_id']}{c}") + data[j, i, :] = st[0].data + else: + j = component_mapping[sta["component"]] + st = meta.select(id=f"{sta['station_id']}{c}") + data[j, i, :] = st[0].data + + return data, { + "begin_time": begin_time.datetime, + "end_time": end_time.datetime, + "station_id": stations["station_id"].tolist(), + } def read_das_continuous_data_h5(fn, dataset_keys=[]): @@ -494,7 +584,8 @@ def read_das_continuous_data_h5(fn, dataset_keys=[]): info[key] = f[key][:] return data, info -def get_shape_das_continous_data_h5(file): + +def get_shape_das_continuous_data_h5(file): with h5py.File(file, "r") as f: if "Data" in f: data_shape = f["Data"].shape diff --git a/cctorch/model.py b/cctorch/model.py index ee4d535..a67cd00 100644 --- a/cctorch/model.py +++ b/cctorch/model.py @@ -20,6 +20,7 @@ def __init__( self.nlag = config.nlag self.nma = config.nma self.channel_shift = config.channel_shift + self.reduce_t = config.reduce_t self.reduce_x = config.reduce_x self.domain = config.domain @@ -32,6 +33,10 @@ def __init__( self.to_device = to_device self.device = device + # TM + self.shift_t = config.shift_t + self.normalize = config.normalize + def forward(self, x): """Perform cross-correlation on input data Args: @@ -52,7 +57,6 @@ def forward(self, x): data1 = x1["data"] data2 = x2["data"] - if self.domain == "frequency": # xcorr with fft in frequency domain nfast = (data1.shape[-1] - 1) * 2 @@ -70,17 +74,38 @@ def forward(self, x): elif self.domain == "time": ## using conv1d in time domain nb1, nc1, nx1, nt1 = data1.shape - data1 = data1.view(1, nb1 * nc1 * nx1, nt1) nb2, nc2, nx2, nt2 = data2.shape - data2 = data2.view(nb2 * nc2 * nx2, 1, nt2) - if self.channel_shift != 0: - xcor = F.conv1d( - data1, torch.roll(data2, self.channel_shift, dims=-2), padding=self.nlag, groups=nb1 * nc1 * nx1 + + if self.shift_t: + nt_index = torch.arange(nt1).unsqueeze(0).unsqueeze(0).unsqueeze(0) + + shift_index = x2["info"]["shift_index"] + shift_index = shift_index.repeat_interleave( + 3, dim=1 + ) # repeat to match three channels for P and S wave templates + shift_index = (nt_index + shift_index.unsqueeze(-1)) % nt1 + data1 = data1.gather(-1, shift_index) + + if self.normalize: + data2 = (data2 - torch.mean(data2, dim=-1, keepdim=True)) / ( + torch.std(data2, dim=-1, keepdim=True) + torch.finfo(data1.dtype).eps ) - else: - xcor = F.conv1d(data1, data2, padding=self.nlag, groups=nb1 * nc1 * nx1) + + data1 = data1.view(1, nb1 * nc1 * nx1, nt1) + data2 = data2.view(nb2 * nc2 * nx2, 1, nt2) + xcor = F.conv1d(data1, data2, padding=self.nlag, stride=1, groups=nb1 * nc1 * nx1) + + if self.normalize: + data1_ = F.pad(data1, (nt2 // 2, nt2 - 1 - nt2 // 2), mode="reflect") + local_mean = F.avg_pool1d(data1_, nt2, stride=1) + local_std = F.lp_pool1d(data1 - local_mean, norm_type=2, kernel_size=nt2, stride=1) * np.sqrt(nt2) + xcor = xcor / (local_std + torch.finfo(data1.dtype).eps) + xcor = xcor.view(nb1, nc1, nx1, -1) + if self.reduce_x: + xcor = torch.sum(xcor, dim=(-3, -2), keepdim=True) + elif self.domain == "stft": nb1, nc1, nx1, nt1 = data1.shape # nb2, nc2, nx2, nt2 = data2.shape @@ -88,8 +113,8 @@ def forward(self, x): # data2 = data2.view(nb2 * nc2 * nx2, nt2) data2 = data2.view(nb1 * nc1 * nx1, nt1) if not self.pre_fft: - data1 = torch.stft(data1, n_fft=self.nlag*2, hop_length = self.nlag, center=True, return_complex=True) - data2 = torch.stft(data2, n_fft=self.nlag*2, hop_length = self.nlag, center=True, return_complex=True) + data1 = torch.stft(data1, n_fft=self.nlag * 2, hop_length=self.nlag, center=True, return_complex=True) + data2 = torch.stft(data2, n_fft=self.nlag * 2, hop_length=self.nlag, center=True, return_complex=True) if self.spectral_whitening: # freqs = np.fft.fftfreq(self.nlag*2, d=self.dt) # data1 = data1 / torch.clip(torch.abs(data1), min=1e-7) #float32 eps @@ -100,30 +125,28 @@ def forward(self, x): xcor = torch.fft.irfft(torch.sum(data1 * torch.conj(data2), dim=-1), dim=-1) xcor = torch.roll(xcor, self.nlag, dims=-1) xcor = xcor.view(nb1, nc1, nx1, -1) - + else: raise ValueError("domain should be frequency or time or stft") - # pair_index = [(i.item(), j.item()) for i, j in zip(x1["info"]["index"], x2["info"]["index"])] pair_index = [(i, j) for i, j in zip(x1["index"], x2["index"])] meta = { - "xcorr": xcor, - "pair_index": pair_index, + "xcorr": xcor, + "pair_index": pair_index, "nlag": self.nlag, - "data1": x1["data"], + "data1": x1["data"], "data2": x2["data"], "info1": x1["info"], "info2": x2["info"], - } - + } + if self.transforms is not None: meta = self.transforms(meta) return meta - def forward_map(self, x): """Perform cross-correlation on input data (dataset_type == map) Args: @@ -141,7 +164,6 @@ def forward_map(self, x): num_pairs = pair_index.shape[0] for i in tqdm(range(0, num_pairs, self.batch_size)): - c1 = pair_index[i : i + self.batch_size, 0] c2 = pair_index[i : i + self.batch_size, 1] if len(c1) == 1: ## returns a view of the original tensor diff --git a/cctorch/transforms.py b/cctorch/transforms.py index 2ceeb04..d994489 100644 --- a/cctorch/transforms.py +++ b/cctorch/transforms.py @@ -7,16 +7,72 @@ from scipy.sparse.linalg import lsmr from tqdm import tqdm import torchaudio +from datetime import datetime, timezone, timedelta + + +#### Common #### +class Filtering(torch.nn.Module): + def __init__(self, fmin, fmax, fs, ftype="bandpass", alpha=0.01, dtype=torch.float32, device="cpu"): + super().__init__() + self.f1 = fmin + self.f2 = fmax + self.fs = fs + self.alpha = alpha + if ftype == "bandpass": + b, a = scipy.signal.butter(2, [fmin, fmax], ftype, fs=fs) + elif ftype == "highpass": + b, a = scipy.signal.butter(2, fmin, ftype, fs=fs) + elif ftype == "lowpass": + b, a = scipy.signal.butter(2, fmax, ftype, fs=fs) + else: + raise ValueError("Unknown filter type") + self.a = torch.tensor(a, dtype=dtype).to(device) + self.b = torch.tensor(b, dtype=dtype).to(device) + + def forward(self, data): + data = data - torch.mean(data, dim=-1, keepdim=True) + max_, _ = torch.max(torch.abs(data), dim=-1, keepdim=True) + max_[max_ == 0.0] = 1.0 + data = data / max_ + + # data = data - (torch.linspace(0, 1, data.shape[-1], device=data.device, dtype=data.dtype) + # * (data[..., -1, None] - data[..., 0, None]) + data[..., 0, None]) + + taper = tukey(data.shape[-1], self.alpha * 3000 / data.shape[-1]) ## relative to 3000 samples + # taper = tukey(data.shape[-1], self.alpha) + data = data * torch.tensor(taper, device=data.device, dtype=data.dtype) + + data = torchaudio.functional.filtfilt(data, a_coeffs=self.a, b_coeffs=self.b, clamp=False) * max_ + + return data + + +class Reduction(torch.nn.Module): + def __init__(self, mode="reduce_x"): + super().__init__() + self.mode = mode + + def forward(self, meta): + if self.mode == "reduce_x": + ccmean = torch.mean(torch.max(torch.abs(meta["xcorr"]), dim=-1).values, dim=-1) + meta["cc_mean"] = ccmean + else: + raise NotImplementedError + + return meta ##### Ambient Noise ###### + def remove_temporal_mean(data): return data - torch.mean(data, dim=-1, keepdim=True) + def remove_spatial_median(data): return data - torch.median(data, dim=-2, keepdim=True)[0] + # def temporal_gradient(data): # return torch.gradient(data, dim=-1)[0] class TemporalGradient(torch.nn.Module): @@ -27,6 +83,7 @@ def __init__(self, fs=100.0): def forward(self, data): return torch.gradient(data, dim=-1)[0] * self.fs + class Decimation(torch.nn.Module): def __init__(self, decimation=2): super().__init__() @@ -35,13 +92,13 @@ def __init__(self, decimation=2): def forward(self, data): return data[..., :: self.decimation] + class TemporalMovingNormalization(torch.nn.Module): def __init__(self, window_size=64): super().__init__() self.window_size = window_size def forward(self, data): - nb, nc, nx, nt = data.shape moving_mean = F.avg_pool2d( data, @@ -77,8 +134,6 @@ def forward(self, data): ##### Cross-Correlation ###### - - class DetectPeaks(torch.nn.Module): def __init__(self, vmin=0.3, kernel=3, stride=1, K=3): super().__init__() @@ -88,7 +143,6 @@ def __init__(self, vmin=0.3, kernel=3, stride=1, K=3): self.K = K def forward(self, meta): - xcorr = meta["xcorr"] if "nlag" in meta: nlag = meta["nlag"] @@ -116,61 +170,64 @@ def forward(self, meta): return meta -class Filtering(torch.nn.Module): - def __init__(self, fmin, fmax, fs, ftype="bandpass", alpha=0.01, dtype=torch.float32, device="cpu"): +## Template Matching +class DetectTM(torch.nn.Module): + def __init__(self, ratio=10, maxpool_kernel=101, median_kernel=6000, K=100, sampling_rate=100.0): super().__init__() - self.f1 = fmin - self.f2 = fmax - self.fs = fs - self.alpha = alpha - if ftype == "bandpass": - b, a = scipy.signal.butter(2, [fmin, fmax], ftype, fs=fs) - elif ftype == "highpass": - b, a = scipy.signal.butter(2, fmin, ftype, fs=fs) - elif ftype == "lowpass": - b, a = scipy.signal.butter(2, fmax, ftype, fs=fs) - else: - raise ValueError("Unknown filter type") - self.a = torch.tensor(a, dtype=dtype).to(device) - self.b = torch.tensor(b, dtype=dtype).to(device) - - def forward(self, data): - - data = data - torch.mean(data, dim=-1, keepdim=True) - max_, _ = torch.max(torch.abs(data), dim=-1, keepdim=True) - max_[max_ == 0.0] = 1.0 - data = data / max_ - - # data = data - (torch.linspace(0, 1, data.shape[-1], device=data.device, dtype=data.dtype) - # * (data[..., -1, None] - data[..., 0, None]) + data[..., 0, None]) - - taper = tukey(data.shape[-1], self.alpha*3000/data.shape[-1]) ## relative to 3000 samples - # taper = tukey(data.shape[-1], self.alpha) - data = data * torch.tensor(taper, device=data.device, dtype=data.dtype) - - data = torchaudio.functional.filtfilt(data, a_coeffs=self.a, b_coeffs=self.b, clamp=False) * max_ - - return data - - -class Reduction(torch.nn.Module): - def __init__(self, mode="reduce_x"): - - super().__init__() - self.mode = mode + self.ratio = ratio + self.maxpool_kernel = maxpool_kernel + self.maxpool_stride = 1 + self.median_kernel = median_kernel + self.median_stride = median_kernel // 2 + self.K = K + self.sampling_rate = sampling_rate + + def convert(self, topk_score, topk_inds, begin_time): + nb, nc, nx, nk = topk_score.shape + event_time = [] + event_score = [] + for i in range(nb): + for j in range(nc): + for k in range(nx): + for l in range(nk): + if topk_score[i, j, k, l] > 0: + event_time.append( + begin_time[i] + timedelta(seconds=topk_inds[i, j, k, l].item() / self.sampling_rate) + ) + event_score.append(topk_score[i, j, k, l].item()) + return event_time, event_score def forward(self, meta): + scores = meta["xcorr"] + + nb, nc, nx, nt = scores.shape + smax = F.max_pool2d( + scores, (1, self.maxpool_kernel), stride=(1, self.maxpool_stride), padding=(0, self.maxpool_kernel // 2) + )[:, :, :, :nt] + scores_ = F.pad(scores, (0, self.median_kernel, 0, 0), mode="reflect", value=0) + ## MAD = median(|x_i - median(x)|) + unfolded = scores_.unfold(-1, self.median_kernel, self.median_stride) + mad = (unfolded - unfolded.median(dim=-1, keepdim=True).values).abs().median(dim=-1).values + mad = F.interpolate(mad, scale_factor=(1, self.median_stride), mode="bilinear", align_corners=False)[ + :, :, :, :nt + ] + keep = (smax == scores).float() * (scores > self.ratio * mad).float() + scores = scores * keep - if self.mode == "reduce_x": - ccmean = torch.mean(torch.max(torch.abs(meta["xcorr"]), dim=-1).values, dim=-1) - meta["cc_mean"] = ccmean + if self.K == 0: + K = max(round(nt * 10.0 / 3000.0), 3) else: - raise NotImplementedError + K = self.K + topk_scores, topk_inds = torch.topk(scores, K, dim=-1, sorted=True) + + event_time, event_score = self.convert(topk_scores, topk_inds, meta["info1"]["begin_time"]) + meta["event_time"] = event_time + meta["event_score"] = event_score return meta -############################################## +############################################## Old Func ############################################## def xcorr_lag(nt): diff --git a/cctorch/utils.py b/cctorch/utils.py index 0e77053..f4eed2e 100644 --- a/cctorch/utils.py +++ b/cctorch/utils.py @@ -21,41 +21,54 @@ from tqdm.auto import tqdm import json + # %% def write_results(results, result_path, ccconfig, rank=0, world_size=1): if ccconfig.mode == "CC": - ## TODO: add writting for CC write_cc_pairs(results, result_path, ccconfig, rank=rank, world_size=world_size) elif ccconfig.mode == "TM": - ## TODO: add writting for CC - pass + write_tm_events(results, result_path, ccconfig, rank=rank, world_size=world_size) elif ccconfig.mode == "AN": write_ambient_noise(results, result_path, ccconfig, rank=rank, world_size=world_size) else: raise ValueError(f"{ccconfig.mode} not supported") +def write_tm_events(results, result_path, ccconfig, rank=0, world_size=1): + if not isinstance(result_path, Path): + result_path = Path(result_path) + + events = [] + for meta in results: + for event_time, event_score in zip(meta["event_time"], meta["event_score"]): + events.append({"event_time": event_time.isoformat(), "event_score": round(event_score, 3)}) + if len(events) > 0: + events = pd.DataFrame(events) + events = events.sort_values(by="event_time", ascending=True) + events.to_csv(result_path / f"cctorch_events_{rank:03d}_{world_size:03d}.csv", index=False) + + def write_cc_pairs(results, result_path, ccconfig, rank=0, world_size=1, plot_figure=False): """ Write cross-correlation results to disk. Parameters ---------- results : list of dict - List of results from cross-correlation. + List of results from cross-correlation. e.g. [{ - "topk_index": topk_index, - "topk_score": topk_score, - "neighbor_score": neighbor_score, + "topk_index": topk_index, + "topk_score": topk_score, + "neighbor_score": neighbor_score, "pair_index": pair_index}] """ - + if not isinstance(result_path, Path): result_path = Path(result_path) min_cc_score = ccconfig.min_cc_score - min_cc_ratio = ccconfig.min_cc_ratio - min_cc_weight = ccconfig.min_cc_weight - + min_cc_ratio = ccconfig.min_cc_ratio + min_cc_weight = ccconfig.min_cc_weight + if "cc_mean" in results[0]: with open(result_path / f"{ccconfig.mode}_{rank:03d}_{world_size:03d}.txt", "a") as fp: for meta in results: @@ -65,7 +78,7 @@ def write_cc_pairs(results, result_path, ccconfig, rank=0, world_size=1, plot_fi for i in range(nb): pair_id = pair_index[i] id1, id2 = pair_id - score = ','.join([f"{x.item():.3f}" for x in cc_mean[i]]) + score = ",".join([f"{x.item():.3f}" for x in cc_mean[i]]) fp.write(f"{id1},{id2},{score}\n") with h5py.File(result_path / f"{ccconfig.mode}_{rank:03d}_{world_size:03d}.h5", "a") as fp: @@ -78,24 +91,25 @@ def write_cc_pairs(results, result_path, ccconfig, rank=0, world_size=1, plot_fi nb, nch, nx, nk = topk_index.shape for i in range(nb): - cc_score = topk_score[i, :, :, 0] cc_weight = topk_score[i, :, :, 0] - topk_score[i, :, :, 1] - if ((cc_score.max() >= min_cc_score) and (cc_weight.max() >= min_cc_weight) and - (torch.sum((cc_score > min_cc_score) & (cc_weight > min_cc_weight)) >= nch * nx * min_cc_ratio)): - + if ( + (cc_score.max() >= min_cc_score) + and (cc_weight.max() >= min_cc_weight) + and (torch.sum((cc_score > min_cc_score) & (cc_weight > min_cc_weight)) >= nch * nx * min_cc_ratio) + ): pair_id = pair_index[i] id1, id2 = pair_id if int(id1) > int(id2): id1, id2 = id2, id1 - topk_index = - topk_index - + topk_index = -topk_index + if f"{id1}/{id2}" not in fp: gp = fp.create_group(f"{id1}/{id2}") else: gp = fp[f"{id1}/{id2}"] - + if f"cc_index" in gp: del gp["cc_index"] gp.create_dataset(f"cc_index", data=topk_index[i].cpu()) @@ -108,17 +122,17 @@ def write_cc_pairs(results, result_path, ccconfig, rank=0, world_size=1, plot_fi if f"neighbor_score" in gp: del gp["neighbor_score"] gp.create_dataset(f"neighbor_score", data=neighbor_score[i].cpu()) - + if id2 != id1: if f"{id2}/{id1}" not in fp: # fp[f"{id2}/{id1}"] = h5py.SoftLink(f"/{id1}/{id2}") gp = fp.create_group(f"{id2}/{id1}") else: gp = fp[f"{id2}/{id1}"] - + if f"cc_index" in gp: del gp["cc_index"] - gp.create_dataset(f"cc_index", data= - topk_index[i].cpu()) + gp.create_dataset(f"cc_index", data=-topk_index[i].cpu()) if f"neighbor_score" in gp: del gp["neighbor_score"] gp.create_dataset(f"neighbor_score", data=neighbor_score[i].cpu().flip(-1)) @@ -128,15 +142,29 @@ def write_cc_pairs(results, result_path, ccconfig, rank=0, world_size=1, plot_fi if f"cc_weight" in gp: del gp["cc_weight"] gp["cc_weight"] = fp[f"{id1}/{id2}/cc_weight"] - + if plot_figure: for j in range(nch): fig, ax = plt.subplots(nrows=1, ncols=3, squeeze=False, figsize=(10, 20), sharey=True) - ax[0, 0].imshow(meta["xcorr"][i, j, :, :].cpu().numpy(), cmap="seismic", vmax=1, vmin=-1, aspect="auto") + ax[0, 0].imshow( + meta["xcorr"][i, j, :, :].cpu().numpy(), cmap="seismic", vmax=1, vmin=-1, aspect="auto" + ) for k in range(nx): - ax[0, 1].plot(meta["data1"][i, j, k, :].cpu().numpy()/np.max(np.abs(meta["data1"][i, j, k, :].cpu().numpy()))+k, linewidth=1, color="k") - ax[0, 2].plot(meta["data2"][i, j, k, :].cpu().numpy()/np.max(np.abs(meta["data2"][i, j, k, :].cpu().numpy()))+k, linewidth=1, color="k") - + ax[0, 1].plot( + meta["data1"][i, j, k, :].cpu().numpy() + / np.max(np.abs(meta["data1"][i, j, k, :].cpu().numpy())) + + k, + linewidth=1, + color="k", + ) + ax[0, 2].plot( + meta["data2"][i, j, k, :].cpu().numpy() + / np.max(np.abs(meta["data2"][i, j, k, :].cpu().numpy())) + + k, + linewidth=1, + color="k", + ) + try: fig.savefig(f"debug/test_{pair_id[0]}_{pair_id[1]}_{j}.png", dpi=300) except: @@ -144,10 +172,9 @@ def write_cc_pairs(results, result_path, ccconfig, rank=0, world_size=1, plot_fi fig.savefig(f"debug/test_{pair_id[0]}_{pair_id[1]}_{j}.png", dpi=300) print(f"debug/test_{pair_id[0]}_{pair_id[1]}_{j}.png") plt.close(fig) - -def write_ambient_noise(results, result_path, ccconfig, rank=0, world_size=1): +def write_ambient_noise(results, result_path, ccconfig, rank=0, world_size=1): if not isinstance(result_path, Path): result_path = Path(result_path) @@ -161,7 +188,7 @@ def write_ambient_noise(results, result_path, ccconfig, rank=0, world_size=1): # for j, pair_id in enumerate(meta["pair_index"]): for pair_id in meta["pair_index"]: list1, list2 = pair_id - + for j, (id1, id2) in enumerate(zip(list1, list2)): if f"{id1}/{id2}" not in fp: gp = fp.create_group(f"{id1}/{id2}") @@ -173,7 +200,7 @@ def write_ambient_noise(results, result_path, ccconfig, rank=0, world_size=1): count = ds.attrs["count"] ds[:] = count / (count + 1) * ds[:] + data[..., j, :] / (count + 1) ds.attrs["count"] = count + 1 - + if f"{id2}/{id1}" not in fp: gp = fp.create_group(f"{id2}/{id1}") ds = gp.create_dataset("xcorr", data=np.flip(data[..., j, :], axis=-1)) diff --git a/run.py b/run.py index 1b7b3f5..e098a4d 100644 --- a/run.py +++ b/run.py @@ -21,7 +21,8 @@ from cctorch import ( CCDataset, CCIterableDataset, - CCModel,) + CCModel, +) from cctorch.transforms import * from cctorch.utils import write_results @@ -39,13 +40,19 @@ def get_args_parser(add_help=True): parser.add_argument("--pair_list", default=None, type=str, help="pair list") parser.add_argument("--data_list1", default=None, type=str, help="data list 1") parser.add_argument("--data_list2", default=None, type=str, help="data list 1") - parser.add_argument("--data_path", default="./", type=str, help="data path") - parser.add_argument("--data_format", default="h5", type=str, help="data type in {h5, memmap}") + parser.add_argument("--data_path1", default="./", type=str, help="data path") + parser.add_argument("--data_path2", default="./", type=str, help="data path") + parser.add_argument("--data_format1", default="h5", type=str, help="data type in {h5, memmap}") + parser.add_argument("--data_format2", default="h5", type=str, help="data type in {h5, memmap}") parser.add_argument("--config", default=None, type=str, help="config file") parser.add_argument("--result_path", default="./results", type=str, help="results path") parser.add_argument("--dataset_type", default="iterable", type=str, help="data loader type in {map, iterable}") - parser.add_argument("--block_size1", default=1000, type=int, help="Number of sample for the 1st data pair dimension") - parser.add_argument("--block_size2", default=1000, type=int, help="Number of sample for the 2nd data pair dimension") + parser.add_argument( + "--block_size1", default=1000, type=int, help="Number of sample for the 1st data pair dimension" + ) + parser.add_argument( + "--block_size2", default=1000, type=int, help="Number of sample for the 2nd data pair dimension" + ) parser.add_argument("--auto_xcorr", action="store_true", help="do auto-correlation for data list") ## common @@ -57,14 +64,22 @@ def get_args_parser(add_help=True): parser.add_argument("--buffer_size", default=10, type=int, help="buffer size for writing to h5 file") parser.add_argument("--workers", default=4, type=int, help="data loading workers") parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)") - parser.add_argument("--dtype", default="float32", type=str, help="data type (Use float32 or float64, Default: float32)") + parser.add_argument( + "--dtype", default="float32", type=str, help="data type (Use float32 or float64, Default: float32)" + ) + parser.add_argument("--normalize", action="store_true", help="normalized cross-correlation (pearson correlation)") + + ## template matching parameters + parser.add_argument("--shift_t", action="store_true", help="shift continuous waveform to align with template time") ## ambient noise parameters parser.add_argument("--min_channel", default=0, type=int, help="minimum channel index") parser.add_argument("--max_channel", default=None, type=int, help="maximum channel index") parser.add_argument("--delta_channel", default=1, type=int, help="channel interval") parser.add_argument("--left_channel", default=None, type=int, help="channel index of the left end from the source") - parser.add_argument("--right_channel", default=None, type=int, help="channel index of the right end from the source") + parser.add_argument( + "--right_channel", default=None, type=int, help="channel index of the right end from the source" + ) parser.add_argument( "--fixed_channels", nargs="+", @@ -110,12 +125,11 @@ def get_args_parser(add_help=True): def main(args): - logging.basicConfig(filename="cctorch.log", level=logging.INFO) utils.init_distributed_mode(args) rank = utils.get_rank() if args.distributed else 0 world_size = utils.get_world_size() if args.distributed else 1 - + if args.config is not None: with open(args.config, "r") as f: config = json.load(f) @@ -137,7 +151,7 @@ class CCConfig: dt = 1 / fs maxlag = args.maxlag nlag = int(maxlag / dt) - pre_fft = False ## if true, do fft in dataloader + pre_fft = False ## if true, do fft in dataloader auto_xcorr = args.auto_xcorr ## ambinet noise @@ -157,7 +171,7 @@ class CCConfig: fmin = 0.1 fmax = 10 ftype = "bandpass" - alpha = 0.05 # tukey window parameter + alpha = 0.05 # tukey window parameter order = 2 #### Decimate decimate_factor = 2 @@ -170,8 +184,13 @@ class CCConfig: mccc = args.mccc use_pair_index = True if args.dataset_type == "map" else False min_cc_score = 0.6 - min_cc_ratio = 0.0 ## ratio is defined as the portion of channels with cc score larger than min_cc_score - min_cc_weight = 0.0 ## the weight is defined as the difference between largest and second largest cc score + min_cc_ratio = 0.0 ## ratio is defined as the portion of channels with cc score larger than min_cc_score + min_cc_weight = 0.0 ## the weight is defined as the difference between largest and second largest cc score + + ## template matching + shift_t = args.shift_t + reduce_x = args.reduce_x + normalize = args.normalize def __init__(self, config): if config is not None: @@ -180,13 +199,18 @@ def __init__(self, config): ccconfig = CCConfig(config) - if (rank == 0): - if os.path.exists(args.result_path): - print(f"Remove existing result path: {args.result_path}") - if os.path.exists(args.result_path.rstrip("/") + "_backup"): - shutil.rmtree(args.result_path.rstrip("/") + "_backup") - shutil.move(args.result_path.rstrip("/"), args.result_path.rstrip("/") + "_backup") - os.makedirs(args.result_path) + ## Sanity check + if args.mode == "TM": + assert ccconfig.shift_t + assert ccconfig.nlag == 0 + + # if rank == 0: + # if os.path.exists(args.result_path): + # print(f"Remove existing result path: {args.result_path}") + # if os.path.exists(args.result_path.rstrip("/") + "_backup"): + # shutil.rmtree(args.result_path.rstrip("/") + "_backup") + # shutil.move(args.result_path.rstrip("/"), args.result_path.rstrip("/") + "_backup") + # os.makedirs(args.result_path) preprocess = [] if args.mode == "CC": @@ -204,13 +228,23 @@ def __init__(self, config): pass elif args.mode == "AN": ## TODO add preprocess for ambient noise - if args.temporal_gradient: ## convert to strain rate + if args.temporal_gradient: ## convert to strain rate preprocess.append(TemporalGradient(ccconfig.fs)) - preprocess.append(TemporalMovingNormalization(int(30*ccconfig.fs))) #30s for 25Hz - preprocess.append(Filtering(ccconfig.fmin, ccconfig.fmax, ccconfig.fs, ccconfig.ftype, ccconfig.alpha, ccconfig.dtype, ccconfig.transform_device)) #50Hz - preprocess.append(Decimation(ccconfig.decimate_factor)) #25Hz + preprocess.append(TemporalMovingNormalization(int(30 * ccconfig.fs))) # 30s for 25Hz + preprocess.append( + Filtering( + ccconfig.fmin, + ccconfig.fmax, + ccconfig.fs, + ccconfig.ftype, + ccconfig.alpha, + ccconfig.dtype, + ccconfig.transform_device, + ) + ) # 50Hz + preprocess.append(Decimation(ccconfig.decimate_factor)) # 25Hz preprocess.append(T.Lambda(remove_spatial_median)) - preprocess.append(TemporalMovingNormalization(int(2*ccconfig.fs//ccconfig.decimate_factor))) #2s for 25Hz + preprocess.append(TemporalMovingNormalization(int(2 * ccconfig.fs // ccconfig.decimate_factor))) # 2s for 25Hz preprocess = T.Compose(preprocess) @@ -220,8 +254,7 @@ def __init__(self, config): postprocess.append(DetectPeaks()) postprocess.append(Reduction()) elif args.mode == "TM": - ## TODO add postprocess for template matching - pass + postprocess.append(DetectTM()) elif args.mode == "AN": ## TODO add postprocess for ambient noise pass @@ -235,14 +268,16 @@ def __init__(self, config): data_list2=args.data_list2, block_size1=args.block_size1, block_size2=args.block_size2, - data_path=args.data_path, - data_format=args.data_format, + data_path1=args.data_path1, + data_path2=args.data_path2, + data_format1=args.data_format1, + data_format2=args.data_format2, device="cpu" if args.workers > 0 else args.device, transforms=preprocess, rank=rank, world_size=world_size, ) - elif args.dataset_type == "iterable": ## prefered + elif args.dataset_type == "iterable": ## prefered dataset = CCIterableDataset( config=ccconfig, pair_list=args.pair_list, @@ -250,8 +285,10 @@ def __init__(self, config): data_list2=args.data_list2, block_size1=args.block_size1, block_size2=args.block_size2, - data_path=args.data_path, - data_format=args.data_format, + data_path1=args.data_path1, + data_path2=args.data_path2, + data_format1=args.data_format1, + data_format2=args.data_format2, device=args.device, transforms=preprocess, batch_size=args.batch_size, @@ -288,7 +325,7 @@ def __init__(self, config): num = 0 results = [] metric_logger = utils.MetricLogger(delimiter=" ") - log_freq = max(1, 10240//args.batch_size) if args.mode == "CC" else 1 + log_freq = max(1, 10240 // args.batch_size) if args.mode == "CC" else 1 for data in metric_logger.log_every(dataloader, log_freq, ""): result = ccmodel(data) results.append(result) diff --git a/tests/test_qtm.py b/tests/test_qtm.py index 0669ac9..36a67a2 100644 --- a/tests/test_qtm.py +++ b/tests/test_qtm.py @@ -73,7 +73,7 @@ plt.plot(normalize(template[ieve, ich + 3, ista, :]) / 3 + ista + shift, "r", linewidth=0.5) plt.show() -# %% Load continous waveform +# %% Load continuous waveform year = "2019" jday = "185" hour = "17" @@ -152,9 +152,9 @@ def detect_peaks(scores, ratio=10, maxpool_kernel=101, median_kernel=1000, K=100 nb1, nc1, nx1, nt1 = data1.shape nb2, nc2, nx2, nt2 = data2.shape -## shift continous waveform to align with template time +## shift continuous waveform to align with template time nt_index = torch.arange(nt1).unsqueeze(0).unsqueeze(0).unsqueeze(0) -adjusted_index = (nt_index + shift_index.unsqueeze(-1)) % nt +adjusted_index = (nt_index + shift_index.unsqueeze(-1)) % nt1 data1 = data1.gather(-1, adjusted_index) device = "cpu"