
Commit

update CC pairs
zhuwq0 committed Nov 6, 2023
1 parent 8aa9626 commit 64be6b7
Showing 5 changed files with 275 additions and 169 deletions.
13 changes: 9 additions & 4 deletions cctorch/data.py
@@ -1,17 +1,17 @@
import itertools
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
import obspy
import pandas as pd
import scipy.signal
import torch
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, IterableDataset
import itertools
import matplotlib.pyplot as plt
from tqdm import tqdm
import obspy


class CCDataset(Dataset):
@@ -146,7 +146,12 @@ def __init__(
self.device = device
self.dtype = dtype
self.num_batch = None
self.symmetric = True if self.mode == "CC" else False

if self.mode == "CC":
self.symmetric = True
self.data_list2 = self.data_list1
self.data_format2 = self.data_format1
self.data_path2 = self.data_path1

if self.mode == "AN":
## For ambient noise, we split chunks in the sampling function
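For cross-correlation between events (mode == "CC"), the second data list, format, and path now simply mirror the first, and the dataset is flagged symmetric. Since CC(i, j) and CC(j, i) carry the same information (one is the time-reverse of the other), a symmetric pair set only needs the upper triangle of the i-j matrix. A minimal sketch of that idea, using a hypothetical enumerate_symmetric_pairs helper rather than CCTorch's actual pair sampler:

import itertools

def enumerate_symmetric_pairs(data_list):
    # Illustrative only: with both sides of the correlation drawn from the same
    # list, it suffices to enumerate each unordered pair once.
    return list(itertools.combinations(range(len(data_list)), 2))

pairs = enumerate_symmetric_pairs(["evt_a", "evt_b", "evt_c"])
# -> [(0, 1), (0, 2), (1, 2)]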
15 changes: 12 additions & 3 deletions cctorch/model.py
@@ -1,9 +1,10 @@
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from tqdm import tqdm
import numpy as np


class CCModel(nn.Module):
@@ -145,7 +146,15 @@ def forward(self, x):
if self.transforms is not None:
meta = self.transforms(meta)

return meta
output = {}
for key in meta:
if key not in ["data1", "data2", "info1", "info2"]:
if isinstance(meta[key], torch.Tensor):
output[key] = meta[key].cpu()
else:
output[key] = meta[key]

return output

def forward_map(self, x):
"""Perform cross-correlation on input data (dataset_type == map)
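The forward() return path now filters out the raw input fields (data1, data2, info1, info2) and moves every remaining tensor to the CPU before returning, so bulky waveforms stay out of the result and what remains no longer holds GPU memory. A compact equivalent of the new loop, shown for illustration only:

import torch

def to_cpu_output(meta):
    # Illustration only: mirrors the loop added to forward() above. Drop the
    # raw input fields and move any remaining tensors to CPU memory.
    return {
        key: value.cpu() if isinstance(value, torch.Tensor) else value
        for key, value in meta.items()
        if key not in ("data1", "data2", "info1", "info2")
    }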
60 changes: 33 additions & 27 deletions cctorch/transforms.py
@@ -1,13 +1,14 @@
from datetime import datetime, timedelta, timezone

import numpy as np
import scipy
import torch
import torch.nn.functional as F
import scipy
import torchaudio
from scipy import sparse
from scipy.signal import tukey
from scipy.sparse.linalg import lsmr
from tqdm import tqdm
import torchaudio
from datetime import datetime, timezone, timedelta


#### Common ####
@@ -48,14 +49,18 @@ def forward(self, data):


class Reduction(torch.nn.Module):
def __init__(self, mode="reduce_x"):
def __init__(self, mode="reduce_x", threshold=0.5):
super().__init__()
self.mode = mode
self.threshold = threshold

def forward(self, meta):
if self.mode == "reduce_x":
ccmean = torch.mean(torch.max(torch.abs(meta["xcorr"]), dim=-1).values, dim=-1)
meta["cc_mean"] = ccmean
# ccmean = torch.mean(torch.max(torch.abs(meta["xcorr"]), dim=-1).values, dim=-1)
# meta["cc_mean"] = ccmean
cc_quality = torch.max(torch.abs(meta["xcorr"]), dim=-1).values # nb, nc, nx
cc_quality = cc_quality * (cc_quality > self.threshold)
meta["cc_sum"] = torch.sum(cc_quality, dim=-1) # nb, nc
else:
raise NotImplementedError
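The reduce_x branch now replaces the plain mean of peak correlations with a thresholded sum: the peak absolute correlation along the last axis is zeroed when it falls below threshold, and the survivors are summed into cc_sum, so weak traces no longer dilute a pair's score. A small sketch of the new behavior on dummy data (shapes assumed to follow the nb, nc, nx comment in the diff):

import torch

# Dummy cross-correlation: 1 pair, 1 channel, 4 stations, 100 lags.
xcorr = torch.zeros(1, 1, 4, 100)
xcorr[0, 0, 0, 10] = 0.9  # confident station
xcorr[0, 0, 1, 20] = 0.6  # above threshold
xcorr[0, 0, 2, 30] = 0.4  # below threshold -> zeroed
xcorr[0, 0, 3, 40] = 0.2  # below threshold -> zeroed

threshold = 0.5
cc_quality = torch.max(torch.abs(xcorr), dim=-1).values  # shape (1, 1, 4)
cc_quality = cc_quality * (cc_quality > threshold)       # keep only confident stations
cc_sum = torch.sum(cc_quality, dim=-1)                   # shape (1, 1); here 0.9 + 0.6 = 1.5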

@@ -134,6 +139,28 @@ def forward(self, data):


##### Cross-Correlation ######


def taper_time(data, alpha=0.8):
taper = tukey(data.shape[-1], alpha)
return data * torch.tensor(taper, device=data.device)


def normalize(x):
x -= torch.mean(x, dim=-1, keepdims=True)
norm = x.square().sum(dim=-1, keepdims=True).sqrt()
norm[norm == 0] = 1
x /= norm
return x


def fft_real_normalize(x):
""""""
x -= torch.mean(x, dim=-1, keepdims=True)
x /= x.square().sum(dim=-1, keepdims=True).sqrt()
return fft_real(x)


class DetectPeaks(torch.nn.Module):
def __init__(self, vmin=0.3, kernel=3, stride=1, K=3):
super().__init__()
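The taper_time, normalize, and fft_real_normalize helpers are moved up here, next to the cross-correlation transforms, and their old copies further down are removed. normalize demeans each trace and scales it to unit energy, guarding against division by zero for all-zero traces. A quick usage sketch (illustrative; keepdim is spelled in its documented form here):

import torch

def normalize(x):
    # Same logic as the helper above: demean, scale to unit energy, and leave
    # all-zero traces untouched instead of dividing by zero.
    x = x - torch.mean(x, dim=-1, keepdim=True)
    norm = x.square().sum(dim=-1, keepdim=True).sqrt()
    norm[norm == 0] = 1
    return x / norm

traces = torch.stack([torch.sin(torch.linspace(0, 6.28, 100)), torch.zeros(100)])
out = normalize(traces)
# out[0] has zero mean and unit energy; out[1] stays all zeros (no NaNs).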
@@ -251,22 +278,6 @@ def fft_real(x):
return torch.fft.rfft(x, n=nfast, dim=-1)


def fft_real_normalize(x):
""""""
x -= torch.mean(x, dim=-1, keepdims=True)
x /= x.square().sum(dim=-1, keepdims=True).sqrt()
return fft_real(x)


def normalize(x):
x -= torch.mean(x, dim=-1, keepdims=True)
# x /= x.square().sum(dim=-1, keepdims=True).sqrt()
norm = x.square().sum(dim=-1, keepdims=True).sqrt()
norm[norm == 0] = 1
x /= norm
return x


# torch helper functions
def count_tensor_byte(*args):
total_byte_size = 0
@@ -305,11 +316,6 @@ def gather_roll(data, shift_index):
return torch.gather(data, 1, index)


def taper_time(data, alpha=0.8):
taper = tukey(data.shape[-1], alpha)
return data * torch.tensor(taper, device=data.device)


def h_poly(t):
tt = t[None, :] ** torch.arange(4, device=t.device)[:, None]
A = torch.tensor([[1, 0, -3, 2], [0, 1, -2, 1], [0, 0, 3, -2], [0, 0, -1, 1]], dtype=t.dtype, device=t.device)
