Commit

fix ambient noise
zhuwq0 committed Dec 15, 2024
1 parent a0f3f45 commit a8963c6
Showing 7 changed files with 193 additions and 278 deletions.
275 changes: 69 additions & 206 deletions cctorch/data.py
@@ -178,20 +178,20 @@ def __init__(
self.data_list2 = None

if self.mode == "AN":
## For ambient noise, we split chunks in the sampling function
self.data_list1 = self.data_list1[rank::world_size]
else:
block_num1 = int(np.ceil(len(unique_row) / block_size1))
block_num2 = int(np.ceil(len(unique_col) / block_size2))
self.group1 = [list(x) for x in np.array_split(unique_row, block_num1) if len(x) > 0]
self.group2 = [list(x) for x in np.array_split(unique_col, block_num2) if len(x) > 0]
self.data_list1 = pd.read_csv(data_list1)
self.data_list2 = self.data_list1

blocks = list(itertools.product(range(len(self.group1)), range(len(self.group2))))[rank::world_size]
self.block_index, self.num_batch = self.count_blocks(blocks)
block_num1 = int(np.ceil(len(unique_row) / block_size1))
block_num2 = int(np.ceil(len(unique_col) / block_size2))
self.group1 = [list(x) for x in np.array_split(unique_row, block_num1) if len(x) > 0]
self.group2 = [list(x) for x in np.array_split(unique_col, block_num2) if len(x) > 0]

print(
f"pair_matrix: {self.pair_matrix.shape}, blocks: {len(self.block_index)}, block_size: {self.block_size1} x {self.block_size2}"
)
blocks = list(itertools.product(range(len(self.group1)), range(len(self.group2))))[rank::world_size]
self.block_index, self.num_batch = self.count_blocks(blocks)

print(
f"pair_matrix: {self.pair_matrix.shape}, blocks: {len(self.block_index)}, block_size: {self.block_size1} x {self.block_size2}"
)

if (self.data_format1 == "memmap") or (self.data_format2 == "memmap"):
self.templates = np.memmap(
@@ -272,10 +272,7 @@ def __iter__(self):
num_workers = worker_info.num_workers
worker_id = worker_info.id

if self.mode == "AN":
return iter(self.sample_ambient_noise(self.data_list1[worker_id::num_workers]))
else:
return iter(self.sample(self.block_index[worker_id::num_workers]))
return iter(self.sample(self.block_index[worker_id::num_workers]))
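Both the rank-level split in __init__ and the worker-level split in __iter__ rely on plain extended slicing, so each (rank, worker) pair sees a disjoint subset of blocks. A minimal, self-contained sketch with toy values (not part of the diff) of how the two splits compose:

# Toy illustration of the two-level sharding used above (values are made up).
blocks = [(i, j) for i in range(4) for j in range(4)]  # 16 (group1, group2) blocks
world_size, num_workers = 2, 2

for rank in range(world_size):
    per_rank = blocks[rank::world_size]  # split across ranks, as in __init__
    for worker_id in range(num_workers):
        per_worker = per_rank[worker_id::num_workers]  # split across workers, as in __iter__
        print(rank, worker_id, per_worker)  # 4 blocks each, no overlap, full coverage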

def sample(self, block_index):
for i, j in block_index:
@@ -288,6 +285,7 @@ def sample(self, block_index):
num = 0

for ii, jj in zip(row_matrix.data, col_matrix.data):

if self.data_format1 == "memmap":
if ii not in local_dict:
meta1 = {
@@ -311,18 +309,28 @@
else:
meta1 = local_dict[ii]
else:
if self.data_list1[ii] not in local_dict:
meta1 = read_data(
self.data_list1[ii], self.data_path1, self.data_format1, mode=self.mode, config=self.config
if self.data_list1.loc[ii, "file_name"] not in local_dict:
data, info = read_data(
self.data_list1.loc[ii, "file_name"],
self.data_path1,
self.data_format1,
mode=self.mode,
config=self.config,
)
meta1["index"] = ii
data = torch.tensor(meta1["data"], dtype=self.dtype).to(self.device)
data = torch.tensor(data, dtype=self.dtype).to(self.device)
if self.transforms is not None:
data = self.transforms(data)
meta1["data"] = data
local_dict[self.data_list1[ii]] = meta1
meta1 = {
"data": data,
"index": ii,
"info": {
"file_name": self.data_list1.loc[ii, "file_name"],
"channel_index": self.data_list1.loc[ii, "channel_index"],
},
}
local_dict[self.data_list1.loc[ii, "file_name"]] = meta1
else:
meta1 = local_dict[self.data_list1[ii]]
meta1 = local_dict[self.data_list1.loc[ii, "file_name"]]

if self.data_format2 == "memmap":
if jj not in local_dict:
@@ -347,25 +355,43 @@
else:
meta2 = local_dict[jj]
else:
if self.data_list2[jj] not in local_dict:
meta2 = read_data(
self.data_list2[jj], self.data_path2, self.data_format2, mode=self.mode, config=self.config
if self.data_list2.loc[jj, "file_name"] not in local_dict:
data, info = read_data(
self.data_list2.loc[jj, "file_name"],
self.data_path2,
self.data_format2,
mode=self.mode,
config=self.config,
)
meta2["index"] = jj
data = torch.tensor(meta2["data"], dtype=self.dtype).to(self.device)
data = torch.tensor(data, dtype=self.dtype).to(self.device)
if self.transforms is not None:
data = self.transforms(data)
meta2["data"] = data
local_dict[self.data_list2[jj]] = meta2
meta2 = {
"data": data,
"index": jj,
"info": {
"file_name": self.data_list2.loc[jj, "file_name"],
"channel_index": self.data_list2.loc[jj, "channel_index"],
},
}
local_dict[self.data_list2.loc[jj, "file_name"]] = meta2
else:
meta2 = local_dict[self.data_list2[jj]]

data1.append(meta1["data"])
index1.append(meta1["index"])
info1.append(meta1["info"])
data2.append(meta2["data"])
index2.append(meta2["index"])
info2.append(meta2["info"])
meta2 = local_dict[self.data_list2.loc[jj, "file_name"]]

if self.mode == "AN":
data1.append(meta1["data"][:, :, self.data_list1.loc[ii, "channel_index"]])
index1.append(self.data_list1.loc[ii, "channel_index"])
info1.append({"file_name": self.data_list1.loc[ii, "file_name"]})
data2.append(meta2["data"][:, :, self.data_list2.loc[jj, "channel_index"]])
index2.append(self.data_list2.loc[jj, "channel_index"])
info2.append({"file_name": self.data_list2.loc[jj, "file_name"]})
else:
data1.append(meta1["data"])
index1.append(meta1["index"])
info1.append(meta1["info"])
data2.append(meta2["data"])
index2.append(meta2["index"])
info2.append(meta2["info"])

num += 1
if num == self.batch_size:
@@ -441,164 +467,11 @@ def sample(self, block_index):
"info": info_batch2,
}

def sample_ambient_noise(self, data_list):
for fd in data_list:
meta = read_data(fd, self.data_path1, self.data_format1, mode=self.mode) # (nch, nt)
data = meta["data"].float().unsqueeze(0).unsqueeze(0) # (1, 1, nx, nt)

if (self.config.transform_on_file) and (self.transforms is not None):
data = self.transforms(data)

# plt.figure()
# tmp = data[0, 0, :, :].cpu().numpy()
# vmax = np.std(tmp[:, -1000:]) * 5
# plt.imshow(tmp[:, -1000:], aspect="auto", vmin=-vmax, vmax=vmax, cmap="seismic")
# plt.colorbar()
# plt.savefig(f"cctorch_step1_{fd.split('/')[-1]}.png", dpi=300)
# raise

nb, nc, nx, nt = data.shape

## cut blocks
min_channel = self.config.min_channel if self.config.min_channel is not None else 0
max_channel = self.config.max_channel if self.config.max_channel is not None else nx
left_channel = self.config.left_channel if self.config.left_channel is not None else -nx
right_channel = self.config.right_channel if self.config.right_channel is not None else nx

if self.config.fixed_channels is not None:
## only process channels passed by "--fixed-channels" as source
lists_1 = (
self.config.fixed_channels
if isinstance(self.config.fixed_channels, list)
else [self.config.fixed_channels]
)
else:
## use delta_channel to down-sample the channels needed for ambient noise
## use min_channel and max_channel to select channels within a range
lists_1 = range(min_channel, max_channel, self.config.delta_channel)
lists_2 = range(min_channel, max_channel, self.config.delta_channel)
block_num1 = int(np.ceil(len(lists_1) / self.block_size1))
block_num2 = int(np.ceil(len(lists_2) / self.block_size2))
group_1 = [list(x) for x in np.array_split(lists_1, block_num1) if len(x) > 0]
group_2 = [list(x) for x in np.array_split(lists_2, block_num2) if len(x) > 0]
block_index = list(itertools.product(range(len(group_1)), range(len(group_2))))

## loop each block
for i, j in block_index:
block1 = group_1[i]
block2 = group_2[j]
index_i = []
index_j = []
for ii, jj in itertools.product(block1, block2):
if (jj < (ii + left_channel)) or (jj > (ii + right_channel)):
continue
index_i.append(ii)
index_j.append(jj)

data_i = data[:, :, index_i, :].to(self.device)
data_j = data[:, :, index_j, :].to(self.device)

if (self.config.transform_on_batch) and (self.transforms is not None):
data_i = self.transforms(data_i)
data_j = self.transforms(data_j)

yield {
"data": data_i,
"index": [index_i],
"info": {},
}, {"data": data_j, "index": [index_j], "info": {}}

def count_sample_ambient_noise(self, num_workers, worker_id):
num_samples = 0
for fd in self.data_list1:
nx, nt = get_shape_das_continuous_data_h5(self.data_path1 / fd) # (nch, nt)

## cut blocks
min_channel = self.config.min_channel if self.config.min_channel is not None else 0
max_channel = self.config.max_channel if self.config.max_channel is not None else nx
left_channel = self.config.left_channel if self.config.left_channel is not None else -nx
right_channel = self.config.right_channel if self.config.right_channel is not None else nx

if self.config.fixed_channels is not None:
## only process channels passed by "--fixed-channels" as source
lists_1 = (
self.config.fixed_channels
if isinstance(self.config.fixed_channels, list)
else [self.config.fixed_channels]
)
else:
## use delta_channel to down-sample the channels needed for ambient noise
## use min_channel and max_channel to select channels within a range
lists_1 = range(min_channel, max_channel, self.config.delta_channel)
lists_2 = range(min_channel, max_channel, self.config.delta_channel)
block_num1 = int(np.ceil(len(lists_1) / self.block_size1))
block_num2 = int(np.ceil(len(lists_2) / self.block_size2))
group_1 = [list(x) for x in np.array_split(lists_1, block_num1) if len(x) > 0]
group_2 = [list(x) for x in np.array_split(lists_2, block_num2) if len(x) > 0]
block_index = list(itertools.product(range(len(group_1)), range(len(group_2))))

## loop each block
for i, j in block_index:
num_samples += 1

return num_samples


def generate_pairs(event1, event2, auto_xcorr=False, symmetric=False):
event1 = set(event1)
event2 = set(event2)
event_inner = event1 & event2
event_outer1 = event1 - event_inner
event_outer2 = event2 - event_inner
event_inner = list(event_inner)
event_outer1 = list(event_outer1)
event_outer2 = list(event_outer2)

if symmetric:
if auto_xcorr:
condition = lambda evt1, evt2: evt1 <= evt2
else:
condition = lambda evt1, evt2: evt1 < evt2
else:
condition = lambda evt1, evt2: True

pairs = []
if len(event_inner) > 0:
# pairs += [(evt1, evt2) for i1, evt1 in enumerate(event_inner) for evt2 in event_inner[i1 + xcor_offset :]]
pairs += [(evt1, evt2) for evt1 in event_inner for evt2 in event_inner if condition(evt1, evt2)]
if len(event_outer1) > 0:
pairs += [(evt1, evt2) for evt1 in event_outer1 for evt2 in event_inner] # if condition(evt1, evt2)]
if len(event_outer2) > 0:
pairs += [(evt1, evt2) for evt1 in event_inner for evt2 in event_outer2] # if condition(evt1, evt2)]
if len(event_outer1) > 0 and len(event_outer2) > 0:
pairs += [(evt1, evt2) for evt1 in event_outer1 for evt2 in event_outer2] # if condition(evt1, evt2)]

# print(f"Total number of pairs: {len(pairs)}")
return pairs
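For reference, a hypothetical call (not in the commit) showing what generate_pairs produces for two small event-id lists; note that the symmetric/auto_xcorr condition only filters pairs drawn from the shared events, since it is commented out for the outer combinations:

event1 = [1, 2, 3]
event2 = [3, 4]

pairs = generate_pairs(event1, event2)
# 6 pairs: (3, 3), (1, 3), (2, 3), (3, 4), (1, 4), (2, 4) -- ordering depends on set iteration

pairs_sym = generate_pairs(event1, event2, symmetric=True)
# the self-pair (3, 3) among the shared events is dropped; the cross pairs are kept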


def generate_block_index(group1, group2, pair_list=None, auto_xcorr=False, symmetric=False, min_sample_per_block=1):
block_index = [(i, j) for i in range(len(group1)) for j in range(len(group2))]
num_empty_index = []
for i, j in tqdm(block_index, desc="Generating blocks"):
num_samples = 0
event1, event2 = group1[i], group2[j]
pairs = generate_pairs(event1, event2, auto_xcorr=auto_xcorr, symmetric=symmetric)
if pair_list is None:
num_samples = len(pairs)
else:
for pair in pairs:
if pair in pair_list:
num_samples += 1
if num_samples >= min_sample_per_block:
num_empty_index.append((i, j))
return num_empty_index
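A hypothetical usage (not in the commit, assuming the module's imports are available) of generate_block_index; despite its name, num_empty_index collects the (i, j) blocks that contain at least min_sample_per_block pairs, so sparse blocks are skipped up front:

group1 = [[1, 2], [3]]
group2 = [[2, 3], [4]]

blocks = generate_block_index(group1, group2, min_sample_per_block=2)
# [(0, 0), (0, 1), (1, 0)] -- block (1, 1) only yields the single pair (3, 4) and is dropped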


def read_data(file_name, data_path, format="h5", mode="CC", config={}):
if mode == "CC":
if format == "h5":
data_list, info_list = read_das_eventphase_data_h5(
data_list, info_list = read_das_phase_data_h5(
data_path / file_name, phase="P", event=True, dataset_keys=["shift_index"]
)
## TODO: check with Jiaxuan; why do we need to read a list but return the first one
@@ -616,7 +489,8 @@ def read_data(file_name, data_path, format="h5", mode="CC", config={}):
else:
raise ValueError(f"Unknown mode: {mode}")

return {"data": data, "info": info}
# return {"data": data, "info": info}
return data, info


def read_mseed(fname, highpass_filter=False, sampling_rate=100, config=None):
@@ -783,19 +657,8 @@ def read_das_continuous_data_h5(fn, dataset_keys=[]):
return data, info


def get_shape_das_continuous_data_h5(file):
with h5py.File(file, "r") as f:
if "Data" in f:
data_shape = f["Data"].shape
elif "data" in f:
data_shape = f["data"].shape
else:
raise ValueError("Cannot find data in the file")
return data_shape


# helper reading functions
def read_das_eventphase_data_h5(fn, phase=None, event=False, dataset_keys=None, attrs_only=False):
def read_das_phase_data_h5(fn, phase=None, event=False, dataset_keys=None, attrs_only=False):
"""
read event phase data from hdf5 file
Args: