Skip to content

Commit

Permalink
update phasenet phasenet_plus training
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuwq0 committed Dec 23, 2023
1 parent 3607432 commit d2f6626
Show file tree
Hide file tree
Showing 10 changed files with 291 additions and 144 deletions.
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

.DS_Store
*.pyc
*.pth
Expand All @@ -8,7 +7,6 @@ output*/
Trash/
results/

autoencoder*
model_phasenet*
EQNet.egg-info/
test_data/
Expand Down
3 changes: 1 addition & 2 deletions eqnet/data/das.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
from scipy.interpolate import interp1d
from torch.utils.data import Dataset, IterableDataset

mp.set_start_method("spawn", force=True)

# mp.set_start_method("spawn", force=True)

def normalize(data: torch.Tensor):
"""channel-wise normalization
Expand Down
47 changes: 26 additions & 21 deletions eqnet/data/seismic_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import torch.nn.functional as F
from scipy import signal
from torch.utils.data import Dataset, IterableDataset
from tqdm import tqdm

# import warnings
# warnings.filterwarnings("error")
Expand Down Expand Up @@ -298,17 +299,19 @@ def __init__(
super().__init__()
self.rank = rank
self.world_size = world_size
self.hdf5_fp = None
if hdf5_file is not None:
fp = h5py.File(hdf5_file, "r")
self.hdf5_fp = fp
tmp_hdf5_keys = f"/tmp/{hdf5_file.split('/')[-1]}.txt"
if not os.path.exists(tmp_hdf5_keys):
self.data_list = [event + "/" + station for event in fp.keys() for station in list(fp[event].keys())]
with open(tmp_hdf5_keys, "w") as f:
for x in self.data_list:
f.write(x + "\n")
with h5py.File(hdf5_file, "r", libver="latest", swmr=True) as fp:
self.data_list = []
for event in tqdm(list(fp.keys()), desc="Caching HDF5 keys"):
for station in list(fp[event].keys()):
self.data_list.append(event + "/" + station)
with open(tmp_hdf5_keys, "w") as f:
f.write("\n".join(self.data_list))
print(f"Saved {tmp_hdf5_keys}")
else:
print(f"Reading {tmp_hdf5_keys}")
self.data_list = pd.read_csv(tmp_hdf5_keys, header=None, names=["trace_id"])["trace_id"].values.tolist()
elif data_list is not None:
with open(data_list, "r") as f:
Expand Down Expand Up @@ -450,8 +453,9 @@ def calc_snr(self, waveform, picks, noise_window=300, signal_window=300, gap_win
# picks_.append(tmp)
# return data_, picks_, noise_

def _read_training_h5(self, trace_id):
if self.hdf5_fp is None:
def read_training_h5(self, trace_id, hdf5_fp=None):
close_hdf5 = False
if hdf5_fp is None:
hdf5_fp = h5py.File(os.path.join(self.data_path, trace_id), "r")
event_id = "data"
sta_ids = list(hdf5_fp["data"].keys())
Expand All @@ -464,14 +468,13 @@ def _read_training_h5(self, trace_id):
tmp_max = np.max(np.abs(waveform), axis=1)
if np.all(tmp_max > 0): ## three component data
break
close_hdf5 = True
else:
hdf5_fp = self.hdf5_fp
event_id, sta_id = trace_id.split("/")
waveform = hdf5_fp[trace_id][:, :]
if waveform.shape[1] == 3:
waveform = waveform.T # [3, Nt]

# waveform = hdf5_fp[trace_id][:, :].T # [3, Nt]
waveform = normalize(waveform)
nch, nt = waveform.shape

Expand All @@ -495,11 +498,12 @@ def _read_training_h5(self, trace_id):
## phase polarity
up = attrs["phase_index"][attrs["phase_polarity"] == "U"]
dn = attrs["phase_index"][attrs["phase_polarity"] == "D"]
## assuming having both P and S picks
mask_width = (
attrs["phase_index"][attrs["phase_type"] == "S"] - attrs["phase_index"][attrs["phase_type"] == "P"]
) // 2
mask_width = int(min(mask_width))
## using the minimum P-S
mask_width = np.min(
attrs["phase_index"][attrs["phase_type"] == "S"][:, np.newaxis]
- attrs["phase_index"][attrs["phase_type"] == "P"][np.newaxis, :]
)
mask_width = max(100, int(mask_width / 2.0))
phase_up, mask_up = generate_label(
[up], nt=nt, label_width=self.polarity_width, mask_width=mask_width, return_mask=True
)
Expand Down Expand Up @@ -555,7 +559,7 @@ def _read_training_h5(self, trace_id):
# event_location[0, :] = np.arange(nt) - hdf5_fp[event_id].attrs["time_index"]
event_location[1:, event_mask >= 1.0] = np.array([dx, dy, dz])[:, np.newaxis]

if self.hdf5_fp is None:
if close_hdf5:
hdf5_fp.close()

return {
Expand All @@ -577,11 +581,11 @@ def _read_training_h5(self, trace_id):
}

def sample_train(self, data_list):
hdf5_fp = h5py.File(self.hdf5_file, "r", libver="latest", swmr=True)
while True:
trace_id = np.random.choice(data_list)
# if True:
try:
meta = self._read_training_h5(trace_id)
meta = self.read_training_h5(trace_id, hdf5_fp)
except Exception as e:
print(f"Error reading {trace_id}:\n{e}")
continue
Expand All @@ -591,10 +595,9 @@ def sample_train(self, data_list):

# if self.stack_event and (random.random() < 0.6):
if self.stack_event:
# if True:
try:
trace_id2 = np.random.choice(self.data_list)
meta2 = self._read_training_h5(trace_id2)
meta2 = self.read_training_h5(trace_id2, hdf5_fp)
if meta2 is not None:
meta = stack_event(meta, meta2)
except Exception as e:
Expand Down Expand Up @@ -639,6 +642,8 @@ def sample_train(self, data_list):
"polarity_mask": torch.from_numpy(polarity_mask).float(),
}

hdf5_fp.close()

def taper(stream):
for tr in stream:
tr.taper(max_percentage=0.05, type="cosine")
Expand Down
3 changes: 2 additions & 1 deletion eqnet/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .autoencoder import *
from .eqnet import *
from .phasenet import *
from .phasenet_das import *
from .autoencoder import *
from .phasenet_plus import *
83 changes: 30 additions & 53 deletions eqnet/models/phasenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,14 +229,18 @@ def __init__(
self,
backbone="unet",
log_scale=True,
add_polarity=True,
add_event=True,
add_polarity=False,
add_event=False,
event_loss_weight=1.0,
polarity_loss_weight=1.0,
) -> None:
super().__init__()
self.backbone_name = backbone
self.add_event = add_event
self.add_polarity = add_polarity
self.event_loss_weight = event_loss_weight
self.polarity_loss_weight = polarity_loss_weight

if backbone == "resnet18":
self.backbone = ResNet(BasicBlock, [2, 2, 2, 2]) # ResNet18
elif backbone == "resnet50":
Expand All @@ -253,14 +257,10 @@ def __init__(
self.polarity_picker = UNetHead(16, 1, feature_names="polarity")
else:
self.phase_picker = DeepLabHead(128, 3, scale_factor=32)
self.event_detector = DeepLabHead(128, 1, scale_factor=2)
if self.add_event:
self.event_detector = DeepLabHead(128, 1, scale_factor=2)
if self.add_polarity:
self.polarity_picker = DeepLabHead(128, 1, scale_factor=32)
# self.phase_picker = FCNHead(128, 3)
# self.event_detector = FCNHead(128, 1)

self.event_loss_weight = event_loss_weight
self.polarity_loss_weight = polarity_loss_weight

@property
def device(self):
Expand All @@ -269,21 +269,14 @@ def device(self):
def forward(self, batched_inputs: Tensor) -> Dict[str, Tensor]:
data = batched_inputs["data"].to(self.device)

if self.training:
phase_pick = batched_inputs["phase_pick"].to(self.device)
event_center = batched_inputs["event_center"].to(self.device)
event_location = batched_inputs["event_location"].to(self.device)
event_mask = batched_inputs["event_mask"].to(self.device)
if self.add_polarity:
polarity = batched_inputs["polarity"].to(self.device)
polarity_mask = batched_inputs["polarity_mask"].to(self.device)
else:
phase_pick = None
event_center = None
event_location = None
event_mask = None
polarity = None
polarity_mask = None
phase_pick = batched_inputs["phase_pick"].to(self.device) if "phase_pick" in batched_inputs else None
event_center = batched_inputs["event_center"].to(self.device) if "event_center" in batched_inputs else None
event_location = (
batched_inputs["event_location"].to(self.device) if "event_location" in batched_inputs else None
)
event_mask = batched_inputs["event_mask"].to(self.device) if "event_mask" in batched_inputs else None
polarity = batched_inputs["polarity"].to(self.device) if "polarity" in batched_inputs else None
polarity_mask = batched_inputs["polarity_mask"].to(self.device) if "polarity_mask" in batched_inputs else None

if self.backbone_name == "swin2":
station_location = batched_inputs["station_location"].to(self.device)
Expand All @@ -292,45 +285,29 @@ def forward(self, batched_inputs: Tensor) -> Dict[str, Tensor]:
features = self.backbone(data)
# features: (batch, station, channel, time)

output = {"loss": 0.0}
output_phase, loss_phase = self.phase_picker(features, phase_pick)
output_event, loss_event = self.event_detector(features, event_center)
output["phase"] = output_phase
output["loss_phase"] = loss_phase
output["loss"] += loss_phase
if self.add_event:
output_event, loss_event = self.event_detector(features, event_center)
output["event"] = output_event
output["loss_event"] = loss_event
output["loss"] += loss_event * self.event_loss_weight
if self.add_polarity:
output_polarity, loss_polarity = self.polarity_picker(features, polarity, mask=polarity_mask)
else:
output_polarity, loss_polarity = None, 0.0

# print(f"{data.shape = }")
# print(f"{phase_pick.shape = }")
# print(f"{event_center.shape = }")
# print(f"{output_phase.shape = }")
# print(f"{output_event.shape = }")
output["polarity"] = output_polarity
output["loss_polarity"] = loss_polarity
output["loss"] += loss_polarity * self.polarity_loss_weight

return {
"loss": loss_phase + loss_event * self.event_loss_weight + loss_polarity * self.polarity_loss_weight,
"loss_phase": loss_phase,
"loss_event": loss_event,
"loss_polarity": loss_polarity,
"phase": output_phase,
"event": output_event,
"polarity": output_polarity,
}
return output


def build_model(
backbone="unet",
log_scale=True,
add_polarity=True,
add_event=True,
event_loss_weight=1.0,
polarity_loss_weight=1.0,
*args,
**kwargs,
) -> PhaseNet:
return PhaseNet(
backbone=backbone,
log_scale=log_scale,
add_event=add_event,
add_polarity=add_polarity,
event_loss_weight=event_loss_weight,
polarity_loss_weight=polarity_loss_weight,
)
return PhaseNet(backbone=backbone, log_scale=log_scale)
21 changes: 21 additions & 0 deletions eqnet/models/phasenet_plus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from .phasenet import PhaseNet


def build_model(
backbone="unet",
log_scale=True,
add_polarity=True,
add_event=True,
event_loss_weight=1.0,
polarity_loss_weight=1.0,
*args,
**kwargs,
) -> PhaseNet:
return PhaseNet(
backbone=backbone,
log_scale=log_scale,
add_event=add_event,
add_polarity=add_polarity,
event_loss_weight=event_loss_weight,
polarity_loss_weight=polarity_loss_weight,
)
Loading

0 comments on commit d2f6626

Please sign in to comment.