Commit

fix relative location

zhuwq0 committed Oct 17, 2024
1 parent 871cf52 commit a0f9dd0
Showing 6 changed files with 147 additions and 203 deletions.
251 changes: 77 additions & 174 deletions scripts/convert_dtcc.py
@@ -1,192 +1,95 @@
# %%
import json
import multiprocessing as mp
import os
import sys
from datetime import datetime
from glob import glob
from pathlib import Path
import pickle

import h5py
import numpy as np
import pandas as pd
import scipy
from args import parse_args
from tqdm import tqdm

os.environ["OMP_NUM_THREADS"] = "8"


# %%
def extract_picks(pair, data, config, tt_memmap, station_df):
tt_memmap = np.memmap(
tt_memmap,
dtype=np.float32,
mode="r",
shape=tuple(config["traveltime_shape"]),
)

h5, id1 = pair

x = config["interp"]["x"]
x_interp = config["interp"]["x_interp"]
dt = config["interp"]["dt"]
dt_interp = config["interp"]["dt_interp"]
min_cc_score = config["min_cc_score"]
min_cc_diff = config["min_cc_diff"]
num_channel = config["num_channel"]
phase_list = config["phase_list"]

with h5py.File(h5, "r") as fp:
gp = fp[id1]
id1 = int(id1)

for id2 in gp:
ds = gp[id2]
id2 = int(id2)
if id1 > id2:
continue

# TODO: save only the best cc score
cc_score = ds["cc_score"][:] # [nch, nsta, 3]
cc_index = ds["cc_index"][:] # [nch, nsta, 3]
cc_diff = ds["cc_diff"][:] # [nch, nsta]
neighbor_score = ds["neighbor_score"][:] # [nch, nsta, 3]
# print(f"{cc_score.shape = }, {cc_index.shape = }, {cc_diff.shape = }, {neighbor_score.shape = }")

if np.max(cc_score) < min_cc_score or (np.max(cc_diff) < min_cc_diff):
continue

# cubic_score = scipy.interpolate.interp1d(x, neighbor_score, axis=-1, kind="quadratic")(x_interp)
# cubic_index = np.argmax(cubic_score, axis=-1, keepdims=True) - len(x_interp) // 2
# dt_cc = cc_index * dt + cubic_index * dt_interp

key = (id1, id2)
nch, nsta, npick = cc_score.shape
records = []
for i in range(nch // num_channel):
for j in range(nsta):
dt_ct = tt_memmap[id1][i, j] - tt_memmap[id2][i, j]
best = np.argmax(cc_score[i * num_channel : (i + 1) * num_channel, j, 0]) + i * num_channel
if cc_score[best, j, 0] >= min_cc_score:
cubic_score = scipy.interpolate.interp1d(x, neighbor_score[best, j, :], kind="quadratic")(
x_interp
)
cubic_index = np.argmax(cubic_score) - len(x_interp) // 2
dt_cc = cc_index[best, j, 0] * dt + cubic_index * dt_interp

# Shelly (2016) Fluid-faulting evolution in high definition: Connecting fault structure and
# frequency-magnitude variations during the 2014 Long Valley Caldera, California, earthquake swarm
weight = (0.1 + 3 * cc_diff[best, j]) * cc_score[best, j, 0] ** 2
records.append(
[
f"{station_df.loc[j]['station']:<4}",
# dt_ct + dt_cc[best, j, 0],
dt_ct + dt_cc,
weight,
phase_list[i],
]
)

if len(records) > 0:
data[key] = records

return 0


if __name__ == "__main__":
# %%
root_path = "local"
region = "demo"
if len(sys.argv) > 1:
root_path = sys.argv[1]
region = sys.argv[2]
args = parse_args()
root_path = args.root_path
region = args.region

# %%
cctorch_path = f"{region}/cctorch"
with open(f"{root_path}/{region}/config.json", "r") as fp:
config = json.load(fp)

# %%
with open(f"{root_path}/{cctorch_path}/config.json", "r") as fp:
config = json.load(fp)
config["min_cc_score"] = 0.6
config["min_cc_diff"] = 0.0

# %%
event_df = pd.read_csv(f"{root_path}/{cctorch_path}/events.csv", index_col=0)

# %%
station_df = pd.read_csv(f"{root_path}/{cctorch_path}/stations.csv", index_col=0)

# %%
tt_memmap = f"{root_path}/{cctorch_path}/traveltime.dat"

# %%
lines = []
for i, row in station_df.iterrows():
# tmp = f"{row['network']}{row['station']}"
tmp = f"{row['station']}"
line = f"{tmp:<4} {row['latitude']:.4f} {row['longitude']:.4f}\n"
lines.append(line)

with open(f"{root_path}/{cctorch_path}/stlist.txt", "w") as fp:
fp.writelines(lines)

h5_list = sorted(list(glob(f"{root_path}/{cctorch_path}/ccpairs/*.h5")))

# %%
dt = 0.01
dt_interp = dt / 100
x = np.linspace(0, 1, 2 + 1)
x_interp = np.linspace(0, 1, 2 * int(dt / dt_interp) + 1)
num_channel = 3
phase_list = ["P", "S"]
# %%
data_path = f"{region}/cctorch"
result_path = f"{region}/adloc_dd"
if not os.path.exists(f"{result_path}"):
os.makedirs(f"{result_path}")

config["interp"] = {"x": x, "x_interp": x_interp, "dt": dt, "dt_interp": dt_interp}
config["num_channel"] = num_channel
config["phase_list"] = phase_list
# %%
stations = pd.read_csv(f"{root_path}/{data_path}/cctorch_stations.csv")
stations["station_id"] = stations["station"]
stations = stations.groupby("station_id").first().reset_index()

# %%
ctx = mp.get_context("spawn")
with ctx.Manager() as manager:
data = manager.dict()
pair_list = []
num_pair = 0
for h5 in h5_list:
with h5py.File(h5, "r") as fp:
for id1 in tqdm(fp, desc=f"Loading {h5.split('/')[-1]}", leave=True):
gp1 = fp[id1]
# for id2 in gp1:
# pair_list.append((h5, id1, id2))
# pair_list.append([h5, id1, list(gp1.keys())])
pair_list.append([h5, id1])
num_pair += len(gp1.keys())
# %%
events = pd.read_csv(f"{root_path}/{data_path}/cctorch_events.csv", dtype={"event_index": str})
events["time"] = pd.to_datetime(events["event_time"], format="mixed")

ncpu = max(1, min(32, mp.cpu_count() - 1))
pbar = tqdm(total=len(pair_list), desc="Extracting pairs")
print(f"Total pairs: {num_pair}. Using {ncpu} cores.")
# %%
stations["idx_sta"] = np.arange(len(stations)) # reindex in case the index does not start from 0 or is not continuous
events["idx_eve"] = np.arange(len(events)) # reindex in case the index does not start from 0 or is not continuous
mapping_phase_type_int = {"P": 0, "S": 1}

## Debug
# for pair in pair_list:
# extract_picks(pair, data, config, tt_memmap, station_df)
# pbar.update()
# %%
with open(f"{root_path}/{data_path}/dt.cc", "r") as f:
lines = f.readlines()

with ctx.Pool(processes=ncpu) as pool:
# with mp.Pool(processes=ncpu) as pool:
for pair in pair_list:
pool.apply_async(
extract_picks, args=(pair, data, config, tt_memmap, station_df), callback=lambda x: pbar.update()
)
pool.close()
pool.join()
pbar.close()
# %%
event_index1 = []
event_index2 = []
station_index = []
phase_type = []
phase_score = []
phase_dtime = []

stations.set_index("station_id", inplace=True)
events.set_index("event_index", inplace=True)

for line in tqdm(lines):
if line[0] == "#":
evid1, evid2, _ = line[1:].split()
else:
stid, dt, weight, phase = line.split()
event_index1.append(events.loc[evid1, "idx_eve"])
event_index2.append(events.loc[evid2, "idx_eve"])
station_index.append(stations.loc[stid, "idx_sta"])
phase_type.append(mapping_phase_type_int[phase])
phase_score.append(weight)
phase_dtime.append(dt)


dtypes = np.dtype(
[
("idx_eve1", np.int32),
("idx_eve2", np.int32),
("idx_sta", np.int32),
("phase_type", np.int32),
("phase_score", np.float32),
("phase_dtime", np.float32),
]
)
pairs_array = np.memmap(
f"{root_path}/{result_path}/pair_dt.dat",
mode="w+",
shape=(len(phase_dtime),),
dtype=dtypes,
)
pairs_array["idx_eve1"] = event_index1
pairs_array["idx_eve2"] = event_index2
pairs_array["idx_sta"] = station_index
pairs_array["phase_type"] = phase_type
pairs_array["phase_score"] = phase_score
pairs_array["phase_dtime"] = phase_dtime
with open(f"{root_path}/{result_path}/pair_dtypes.pkl", "wb") as f:
pickle.dump(dtypes, f)

data = dict(data)
print(f"Valid pairs: {len(data)}")

# %%
with open(f"{root_path}/{cctorch_path}/dt.cc", "w") as fp:
for key in tqdm(sorted(data.keys()), desc="Writing dt.cc"):
event_index0 = event_df.loc[key[0]]["event_index"]
event_index1 = event_df.loc[key[1]]["event_index"]
fp.write(f"# {event_index0} {event_index1} 0.000\n")
for record in data[key]:
fp.write(f"{record[0]} {record[1]: .4f} {record[2]:.4f} {record[3]}\n")
# %%
events.to_csv(f"{root_path}/{result_path}/pair_events.csv", index=True, index_label="event_index")
stations.to_csv(f"{root_path}/{result_path}/pair_stations.csv", index=True, index_label="station_id")
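
Note: the script above emits two relocation inputs: a plain-text dt.cc file (a "# event_index1 event_index2 0.000" header line per pair, followed by "STA  dt  weight  phase" records) and a structured memmap pair_dt.dat whose numpy dtype is pickled next to it. A minimal read-back sketch for the memmap, assuming the default root_path="local" and region="demo" paths used in the script (illustrative only, not part of the commit):

import pickle
import numpy as np

# Paths below assume root_path="local", region="demo" as in the script above
result_path = "local/demo/adloc_dd"
with open(f"{result_path}/pair_dtypes.pkl", "rb") as f:
    dtypes = pickle.load(f)

# Without an explicit shape, np.memmap infers the record count from the file size
pairs = np.memmap(f"{result_path}/pair_dt.dat", mode="r", dtype=dtypes)
print(pairs["idx_eve1"][:5], pairs["idx_sta"][:5], pairs["phase_dtime"][:5])
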
12 changes: 6 additions & 6 deletions scripts/generate_pairs.py
@@ -194,9 +194,9 @@ def pairing_picks(event_pairs, picks, config):

dtypes = np.dtype(
[
("event_index1", np.int32),
("event_index2", np.int32),
("station_index", np.int32),
("idx_eve1", np.int32),
("idx_eve2", np.int32),
("idx_sta", np.int32),
("phase_type", np.int32),
("phase_score", np.float32),
("phase_dtime", np.float32),
@@ -208,9 +208,9 @@ def pairing_picks(event_pairs, picks, config):
shape=(len(event_pairs),),
dtype=dtypes,
)
pairs_array["event_index1"] = event_pairs["idx_eve1"].values
pairs_array["event_index2"] = event_pairs["idx_eve2"].values
pairs_array["station_index"] = event_pairs["idx_sta"].values
pairs_array["idx_eve1"] = event_pairs["idx_eve1"].values
pairs_array["idx_eve2"] = event_pairs["idx_eve2"].values
pairs_array["idx_sta"] = event_pairs["idx_sta"].values
pairs_array["phase_type"] = event_pairs["phase_type"].values
pairs_array["phase_score"] = event_pairs["phase_score"].values
pairs_array["phase_dtime"] = event_pairs["phase_dtime"].values
13 changes: 13 additions & 0 deletions scripts/run_adloc_cc.py
@@ -229,6 +229,9 @@
"station_index": pairs["idx_sta"],
}
)

pairs_df["time_error"] = pred_time - pairs["dt"]

pairs_df = pairs_df[valid_index]
config["MIN_OBS"] = 8
pairs_df = pairs_df.groupby(["event_index1", "event_index2"], as_index=False, group_keys=False).filter(
@@ -239,6 +242,16 @@

phase_dataset.valid_index = valid_index

## correct origin time
time_shift = np.zeros(len(travel_time.event_time.weight))
time_count = np.zeros(len(travel_time.event_time.weight))
np.add.at(time_shift, pairs_df["event_index1"].values, pairs_df["time_error"].values)
np.add.at(time_shift, pairs_df["event_index2"].values, -pairs_df["time_error"].values)
np.add.at(time_count, pairs_df["event_index1"].values, 1)
np.add.at(time_count, pairs_df["event_index2"].values, 1)
time_shift[time_count > 0] /= time_count[time_count > 0]
travel_time.event_time.weight.data -= torch.tensor(time_shift[:, None], dtype=torch.float32)

invert_event_loc = raw_travel_time.event_loc.weight.clone().detach().numpy()
invert_event_time = raw_travel_time.event_time.weight.clone().detach().numpy()
valid_event_index = np.unique(pairs["idx_eve1"][valid_index])
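
The origin-time correction added in this hunk averages each event's signed pair residuals with np.add.at. A tiny standalone sketch of the same accumulation pattern, using hypothetical toy values:

import numpy as np

# Hypothetical pair residuals: pairs (idx1[i], idx2[i]) with time errors err[i]
idx1 = np.array([0, 0, 1])
idx2 = np.array([1, 2, 2])
err = np.array([0.10, -0.05, 0.02])

num_events = 3
time_shift = np.zeros(num_events)
time_count = np.zeros(num_events)
# Each pair pushes +err onto its first event and -err onto its second event
np.add.at(time_shift, idx1, err)
np.add.at(time_shift, idx2, -err)
np.add.at(time_count, idx1, 1)
np.add.at(time_count, idx2, 1)
time_shift[time_count > 0] /= time_count[time_count > 0]
print(time_shift)  # per-event mean shift, e.g. [ 0.025 -0.04   0.015]
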
24 changes: 19 additions & 5 deletions scripts/run_adloc_ct.py
@@ -223,11 +223,14 @@

pairs_df = pd.DataFrame(
{
"event_index1": pairs["event_index1"],
"event_index2": pairs["event_index2"],
"station_index": pairs["station_index"],
"event_index1": pairs["idx_eve1"],
"event_index2": pairs["idx_eve2"],
"station_index": pairs["idx_sta"],
}
)

pairs_df["time_error"] = pred_time - pairs["phase_dtime"]

pairs_df = pairs_df[valid_index]
config["MIN_OBS"] = 8
pairs_df = pairs_df.groupby(["event_index1", "event_index2"], as_index=False, group_keys=False).filter(
@@ -238,11 +241,22 @@

phase_dataset.valid_index = valid_index

## correct origin time
time_shift = np.zeros(len(travel_time.event_time.weight))
time_count = np.zeros(len(travel_time.event_time.weight))
np.add.at(time_shift, pairs_df["event_index1"].values, pairs_df["time_error"].values)
np.add.at(time_shift, pairs_df["event_index2"].values, -pairs_df["time_error"].values)
np.add.at(time_count, pairs_df["event_index1"].values, 1)
np.add.at(time_count, pairs_df["event_index2"].values, 1)
time_shift[time_count > 0] /= time_count[time_count > 0]
print(f"{np.mean(time_shift):.3f} {np.std(time_shift):.3f}")
travel_time.event_time.weight.data -= torch.tensor(time_shift[:, None], dtype=torch.float32)

invert_event_loc = raw_travel_time.event_loc.weight.clone().detach().numpy()
invert_event_time = raw_travel_time.event_time.weight.clone().detach().numpy()
valid_event_index = np.unique(pairs["event_index1"][valid_index])
# valid_event_index = np.unique(pairs["event_index1"][valid_index])
valid_event_index = np.concatenate(
[np.unique(pairs["event_index1"][valid_index]), np.unique(pairs["event_index2"][valid_index])]
[np.unique(pairs["idx_eve1"][valid_index]), np.unique(pairs["idx_eve2"][valid_index])]
)
valid_event_index = np.sort(np.unique(valid_event_index))
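
For reference, the valid-event bookkeeping above collects every event that appears on either side of a surviving pair; the same pattern on hypothetical toy arrays:

import numpy as np

# Hypothetical filtered pairs
idx_eve1 = np.array([3, 3, 7])
idx_eve2 = np.array([7, 9, 9])
valid_index = np.array([True, True, False])

valid_event_index = np.concatenate(
    [np.unique(idx_eve1[valid_index]), np.unique(idx_eve2[valid_index])]
)
valid_event_index = np.sort(np.unique(valid_event_index))
print(valid_event_index)  # [3 7 9]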
