From d2f6626beacb0b3ab5f0e28a935d9a8a49ad4a6e Mon Sep 17 00:00:00 2001
From: zhuwq
Date: Fri, 22 Dec 2023 21:44:04 -0800
Subject: [PATCH] Update phasenet and phasenet_plus training

---
 .gitignore                    |  2 -
 eqnet/data/das.py             |  3 +-
 eqnet/data/seismic_trace.py   | 47 ++++++++++--------
 eqnet/models/__init__.py      |  3 +-
 eqnet/models/phasenet.py      | 83 +++++++++++--------------------
 eqnet/models/phasenet_plus.py | 21 ++++++++
 eqnet/utils/visualization.py  | 86 +++++++++++++++++++++++---------
 train.py                      | 93 +++++++++++++++++++++--------------
 train.yaml                    | 86 ++++++++++++++++++++++++++++++++
 utils.py                      | 11 ++---
 10 files changed, 291 insertions(+), 144 deletions(-)
 create mode 100644 eqnet/models/phasenet_plus.py
 create mode 100644 train.yaml

diff --git a/.gitignore b/.gitignore
index bbedf69..462a50b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,3 @@
-
 .DS_Store
 *.pyc
 *.pth
@@ -8,7 +7,6 @@
 output*/
 Trash/
 results/
-autoencoder*
 model_phasenet*
 EQNet.egg-info/
 test_data/
diff --git a/eqnet/data/das.py b/eqnet/data/das.py
index 473413f..67f9343 100644
--- a/eqnet/data/das.py
+++ b/eqnet/data/das.py
@@ -17,8 +17,7 @@ from scipy.interpolate import interp1d
 from torch.utils.data import Dataset, IterableDataset
 
-mp.set_start_method("spawn", force=True)
-
+# mp.set_start_method("spawn", force=True)
 
 
 def normalize(data: torch.Tensor):
     """channel-wise normalization
diff --git a/eqnet/data/seismic_trace.py b/eqnet/data/seismic_trace.py
index 843e269..f1c12cb 100644
--- a/eqnet/data/seismic_trace.py
+++ b/eqnet/data/seismic_trace.py
@@ -15,6 +15,7 @@
 import torch.nn.functional as F
 from scipy import signal
 from torch.utils.data import Dataset, IterableDataset
+from tqdm import tqdm
 
 # import warnings
 # warnings.filterwarnings("error")
@@ -298,17 +299,19 @@ def __init__(
         super().__init__()
         self.rank = rank
         self.world_size = world_size
-        self.hdf5_fp = None
         if hdf5_file is not None:
-            fp = h5py.File(hdf5_file, "r")
-            self.hdf5_fp = fp
             tmp_hdf5_keys = f"/tmp/{hdf5_file.split('/')[-1]}.txt"
             if not os.path.exists(tmp_hdf5_keys):
-                self.data_list = [event + "/" + station for event in fp.keys() for station in list(fp[event].keys())]
-                with open(tmp_hdf5_keys, "w") as f:
-                    for x in self.data_list:
-                        f.write(x + "\n")
+                with h5py.File(hdf5_file, "r", libver="latest", swmr=True) as fp:
+                    self.data_list = []
+                    for event in tqdm(list(fp.keys()), desc="Caching HDF5 keys"):
+                        for station in list(fp[event].keys()):
+                            self.data_list.append(event + "/" + station)
+                with open(tmp_hdf5_keys, "w") as f:
+                    f.write("\n".join(self.data_list))
+                print(f"Saved {tmp_hdf5_keys}")
             else:
+                print(f"Reading {tmp_hdf5_keys}")
                 self.data_list = pd.read_csv(tmp_hdf5_keys, header=None, names=["trace_id"])["trace_id"].values.tolist()
         elif data_list is not None:
             with open(data_list, "r") as f:
@@ -450,8 +453,9 @@ def calc_snr(self, waveform, picks, noise_window=300, signal_window=300, gap_win
     #         picks_.append(tmp)
     #     return data_, picks_, noise_
 
-    def _read_training_h5(self, trace_id):
-        if self.hdf5_fp is None:
+    def read_training_h5(self, trace_id, hdf5_fp=None):
+        close_hdf5 = False
+        if hdf5_fp is None:
             hdf5_fp = h5py.File(os.path.join(self.data_path, trace_id), "r")
             event_id = "data"
             sta_ids = list(hdf5_fp["data"].keys())
@@ -464,14 +468,13 @@
                 tmp_max = np.max(np.abs(waveform), axis=1)
                 if np.all(tmp_max > 0):  ## three component data
                     break
+            close_hdf5 = True
         else:
-            hdf5_fp = self.hdf5_fp
             event_id, sta_id = trace_id.split("/")
             waveform = hdf5_fp[trace_id][:, :]
             if waveform.shape[1] == 3:
                 waveform = waveform.T  # [3, Nt]
-        # waveform = hdf5_fp[trace_id][:, :].T # [3, Nt]
 
         waveform = normalize(waveform)
         nch, nt = waveform.shape
@@ -495,11 +498,12 @@
         ## phase polarity
         up = attrs["phase_index"][attrs["phase_polarity"] == "U"]
         dn = attrs["phase_index"][attrs["phase_polarity"] == "D"]
-        ## assuming having both P and S picks
-        mask_width = (
-            attrs["phase_index"][attrs["phase_type"] == "S"] - attrs["phase_index"][attrs["phase_type"] == "P"]
-        ) // 2
-        mask_width = int(min(mask_width))
+        ## mask width from the minimum S-P separation across all pick pairs
+        mask_width = np.min(
+            attrs["phase_index"][attrs["phase_type"] == "S"][:, np.newaxis]
+            - attrs["phase_index"][attrs["phase_type"] == "P"][np.newaxis, :]
+        )
+        mask_width = max(100, int(mask_width / 2.0))
         phase_up, mask_up = generate_label(
             [up], nt=nt, label_width=self.polarity_width, mask_width=mask_width, return_mask=True
         )
@@ -555,7 +559,7 @@
         # event_location[0, :] = np.arange(nt) - hdf5_fp[event_id].attrs["time_index"]
         event_location[1:, event_mask >= 1.0] = np.array([dx, dy, dz])[:, np.newaxis]
 
-        if self.hdf5_fp is None:
+        if close_hdf5:
             hdf5_fp.close()
 
         return {
@@ -577,11 +581,11 @@
     def sample_train(self, data_list):
+        hdf5_fp = h5py.File(self.hdf5_file, "r", libver="latest", swmr=True)
         while True:
             trace_id = np.random.choice(data_list)
-            # if True:
             try:
-                meta = self._read_training_h5(trace_id)
+                meta = self.read_training_h5(trace_id, hdf5_fp)
             except Exception as e:
                 print(f"Error reading {trace_id}:\n{e}")
                 continue
@@ -591,10 +595,9 @@
             # if self.stack_event and (random.random() < 0.6):
             if self.stack_event:
-                # if True:
                 try:
                     trace_id2 = np.random.choice(self.data_list)
-                    meta2 = self._read_training_h5(trace_id2)
+                    meta2 = self.read_training_h5(trace_id2, hdf5_fp)
                     if meta2 is not None:
                         meta = stack_event(meta, meta2)
                 except Exception as e:
@@ -639,6 +642,8 @@
                 "polarity_mask": torch.from_numpy(polarity_mask).float(),
             }
 
+        hdf5_fp.close()
+
 
 def taper(stream):
     for tr in stream:
         tr.taper(max_percentage=0.05, type="cosine")
diff --git a/eqnet/models/__init__.py b/eqnet/models/__init__.py
index 4607d1d..81f1593 100644
--- a/eqnet/models/__init__.py
+++ b/eqnet/models/__init__.py
@@ -1,4 +1,5 @@
+from .autoencoder import *
 from .eqnet import *
 from .phasenet import *
 from .phasenet_das import *
-from .autoencoder import *
+from .phasenet_plus import *
diff --git a/eqnet/models/phasenet.py b/eqnet/models/phasenet.py
index 51a9f7b..b916b55 100644
--- a/eqnet/models/phasenet.py
+++ b/eqnet/models/phasenet.py
@@ -229,14 +229,18 @@ def __init__(
         self,
         backbone="unet",
         log_scale=True,
-        add_polarity=True,
-        add_event=True,
+        add_polarity=False,
+        add_event=False,
         event_loss_weight=1.0,
         polarity_loss_weight=1.0,
     ) -> None:
         super().__init__()
         self.backbone_name = backbone
+        self.add_event = add_event
         self.add_polarity = add_polarity
+        self.event_loss_weight = event_loss_weight
+        self.polarity_loss_weight = polarity_loss_weight
+
         if backbone == "resnet18":
             self.backbone = ResNet(BasicBlock, [2, 2, 2, 2])  # ResNet18
         elif backbone == "resnet50":
@@ -253,14 +257,10 @@
             self.polarity_picker = UNetHead(16, 1, feature_names="polarity")
         else:
             self.phase_picker = DeepLabHead(128, 3, scale_factor=32)
-            self.event_detector = DeepLabHead(128, 1, scale_factor=2)
+            if self.add_event:
+                self.event_detector = DeepLabHead(128, 1, scale_factor=2)
             if self.add_polarity:
                 self.polarity_picker = DeepLabHead(128, 1, scale_factor=32)
-        # self.phase_picker = FCNHead(128, 3)
-        # self.event_detector = FCNHead(128, 1)
-
-        self.event_loss_weight = event_loss_weight
-        self.polarity_loss_weight = polarity_loss_weight
 
     @property
     def device(self):
@@ -269,21 +269,14 @@
     def forward(self, batched_inputs: Tensor) -> Dict[str, Tensor]:
         data = batched_inputs["data"].to(self.device)
 
-        if self.training:
-            phase_pick = batched_inputs["phase_pick"].to(self.device)
-            event_center = batched_inputs["event_center"].to(self.device)
-            event_location = batched_inputs["event_location"].to(self.device)
-            event_mask = batched_inputs["event_mask"].to(self.device)
-            if self.add_polarity:
-                polarity = batched_inputs["polarity"].to(self.device)
-                polarity_mask = batched_inputs["polarity_mask"].to(self.device)
-        else:
-            phase_pick = None
-            event_center = None
-            event_location = None
-            event_mask = None
-            polarity = None
-            polarity_mask = None
+        phase_pick = batched_inputs["phase_pick"].to(self.device) if "phase_pick" in batched_inputs else None
+        event_center = batched_inputs["event_center"].to(self.device) if "event_center" in batched_inputs else None
+        event_location = (
+            batched_inputs["event_location"].to(self.device) if "event_location" in batched_inputs else None
+        )
+        event_mask = batched_inputs["event_mask"].to(self.device) if "event_mask" in batched_inputs else None
+        polarity = batched_inputs["polarity"].to(self.device) if "polarity" in batched_inputs else None
+        polarity_mask = batched_inputs["polarity_mask"].to(self.device) if "polarity_mask" in batched_inputs else None
 
         if self.backbone_name == "swin2":
             station_location = batched_inputs["station_location"].to(self.device)
@@ -292,45 +285,29 @@ def forward(self, batched_inputs: Tensor) -> Dict[str, Tensor]:
         features = self.backbone(data)
         # features: (batch, station, channel, time)
 
+        output = {"loss": 0.0}
         output_phase, loss_phase = self.phase_picker(features, phase_pick)
-        output_event, loss_event = self.event_detector(features, event_center)
+        output["phase"] = output_phase
+        output["loss_phase"] = loss_phase
+        output["loss"] += loss_phase
+        if self.add_event:
+            output_event, loss_event = self.event_detector(features, event_center)
+            output["event"] = output_event
+            output["loss_event"] = loss_event
+            output["loss"] += loss_event * self.event_loss_weight
         if self.add_polarity:
             output_polarity, loss_polarity = self.polarity_picker(features, polarity, mask=polarity_mask)
-        else:
-            output_polarity, loss_polarity = None, 0.0
-
-        # print(f"{data.shape = }")
-        # print(f"{phase_pick.shape = }")
-        # print(f"{event_center.shape = }")
-        # print(f"{output_phase.shape = }")
-        # print(f"{output_event.shape = }")
+            output["polarity"] = output_polarity
+            output["loss_polarity"] = loss_polarity
+            output["loss"] += loss_polarity * self.polarity_loss_weight
 
-        return {
-            "loss": loss_phase + loss_event * self.event_loss_weight + loss_polarity * self.polarity_loss_weight,
-            "loss_phase": loss_phase,
-            "loss_event": loss_event,
-            "loss_polarity": loss_polarity,
-            "phase": output_phase,
-            "event": output_event,
-            "polarity": output_polarity,
-        }
+        return output
 
 
 def build_model(
     backbone="unet",
     log_scale=True,
-    add_polarity=True,
-    add_event=True,
-    event_loss_weight=1.0,
-    polarity_loss_weight=1.0,
     *args,
     **kwargs,
 ) -> PhaseNet:
-    return PhaseNet(
-        backbone=backbone,
-        log_scale=log_scale,
-        add_event=add_event,
-        add_polarity=add_polarity,
-        event_loss_weight=event_loss_weight,
-        polarity_loss_weight=polarity_loss_weight,
-    )
+    return PhaseNet(backbone=backbone, log_scale=log_scale)
diff --git a/eqnet/models/phasenet_plus.py b/eqnet/models/phasenet_plus.py
new file mode 100644
index 0000000..7b14937
--- /dev/null
+++ b/eqnet/models/phasenet_plus.py
@@ -0,0 +1,21 @@
+from .phasenet import PhaseNet
+
+
+def build_model(
+    backbone="unet",
+    log_scale=True,
+    add_polarity=True,
+    add_event=True,
+    event_loss_weight=1.0,
+    polarity_loss_weight=1.0,
+    *args,
+    **kwargs,
+) -> PhaseNet:
+    return PhaseNet(
+        backbone=backbone,
+        log_scale=log_scale,
+        add_event=add_event,
+        add_polarity=add_polarity,
+        event_loss_weight=event_loss_weight,
+        polarity_loss_weight=polarity_loss_weight,
+    )
diff --git a/eqnet/utils/visualization.py b/eqnet/utils/visualization.py
index 4d6768b..99c603c 100644
--- a/eqnet/utils/visualization.py
+++ b/eqnet/utils/visualization.py
@@ -18,7 +18,7 @@ def normalize(x):
     return x
 
 
-def visualize_autoencoder_das_train(meta, preds, epoch, figure_dir="figures"):
+def plot_autoencoder_das_train(meta, preds, epoch, figure_dir="figures"):
     meta_data = meta["data"]
     raw_data = meta_data.clone().permute(0, 2, 3, 1).numpy()
     # data = normalize_local(meta_data.clone()).permute(0, 2, 3, 1).numpy()
@@ -61,7 +61,7 @@
     plt.close(fig)
 
 
-def visualize_das_train(meta, preds, epoch, figure_dir="figures", dt=0.01, dx=10, prefix=""):
+def plot_das_train(meta, preds, epoch, figure_dir="figures", dt=0.01, dx=10, prefix=""):
     meta_data = meta["data"].cpu()
     raw_data = meta_data.clone().permute(0, 2, 3, 1).numpy()
     # data = normalize_local(meta_data.clone()).permute(0, 2, 3, 1).numpy()
@@ -148,31 +148,71 @@
     plt.close(fig)
 
 
-def visualize_phasenet_train(meta, phase, event, polarity=None, epoch=0, figure_dir="figures"):
+def plot_phasenet_train(meta, phase, epoch=0, figure_dir="figures", prefix=""):
     for i in range(meta["data"].shape[0]):
         plt.close("all")
-        fig, axes = plt.subplots(9, 1, figsize=(10, 10))
+
         chn_name = ["E", "N", "Z"]
-        # chn_id = list(range(meta["waveform_raw"].shape[1]))
-        # random.shuffle(chn_id)
-        # for j in chn_id:
-        #     if torch.max(torch.abs(meta["waveform_raw"][i, j, :, 0])) > 0.1:
-        #         axes[0].plot(meta["waveform_raw"][i, j, :, 0], linewidth=0.5, color=f"C{j}", label=f"{chn_name[j]}")
-        #         axes[0].legend(loc="upper right")
-        #         axes[1].plot(meta["data"][i, j, :, 0], linewidth=0.5, color=f"C{j}", label=f"{chn_name[j]}")
-        #         axes[1].legend(loc="upper right")
-        #         break
+
+        if "raw_data" in meta:
+            shift = 3
+            fig, axes = plt.subplots(7, 1, figsize=(10, 10))
+            for j in range(3):
+                axes[j].plot(meta["raw_data"][i, j, :, 0], linewidth=0.5, color="k", label=f"{chn_name[j]}")
+                axes[j].set_xticklabels([])
+                axes[j].grid("on")
+        else:
+            fig, axes = plt.subplots(4, 1, figsize=(10, 10))
+            shift = 0
 
         for j in range(3):
-            axes[j].plot(meta["data_raw"][i, j, :, 0], linewidth=0.5, color="k", label=f"{chn_name[j]}")
-            axes[j].set_xticklabels([])
-            axes[j].grid("on")
+            axes[j + shift].plot(meta["data"][i, j, :, 0], linewidth=0.5, color="k", label=f"{chn_name[j]}")
+            axes[j + shift].set_xticklabels([])
+            axes[j + shift].grid("on")
+
+        k = 3 + shift
+        axes[k].plot(phase[i, 1, :, 0], "b")
+        axes[k].plot(phase[i, 2, :, 0], "r")
+        axes[k].plot(meta["phase_pick"][i, 1, :, 0], "--C0")
+        axes[k].plot(meta["phase_pick"][i, 2, :, 0], "--C3")
+        axes[k].plot(meta["phase_mask"][i, 0, :, 0], ":", color="gray")
+        axes[k].set_ylim(-0.05, 1.05)
+        axes[k].set_xticklabels([])
+        axes[k].grid("on")
+
+        if "RANK" in os.environ:
+            rank = int(os.environ["RANK"])
+            fig.savefig(f"{figure_dir}/{prefix}{epoch:02d}_{rank:02d}_{i:02d}.png", dpi=300)
+        else:
+            fig.savefig(f"{figure_dir}/{prefix}{epoch:02d}_{i:02d}.png", dpi=300)
+
+        if i >= 20:
+            break
+
+
+def plot_phasenet_plus_train(meta, phase, event, polarity=None, epoch=0, figure_dir="figures", prefix=""):
+    for i in range(meta["data"].shape[0]):
+        plt.close("all")
+
+        chn_name = ["E", "N", "Z"]
+
+        if "raw_data" in meta:
+            shift = 3
+            fig, axes = plt.subplots(9, 1, figsize=(10, 10))
+            for j in range(3):
+                axes[j].plot(meta["raw_data"][i, j, :, 0], linewidth=0.5, color="k", label=f"{chn_name[j]}")
+                axes[j].set_xticklabels([])
+                axes[j].grid("on")
+        else:
+            fig, axes = plt.subplots(6, 1, figsize=(10, 10))
+            shift = 0
 
         for j in range(3):
-            axes[j + 3].plot(meta["data"][i, j, :, 0], linewidth=0.5, color="k", label=f"{chn_name[j]}")
-            axes[j + 3].set_xticklabels([])
-            axes[j + 3].grid("on")
+            axes[j + shift].plot(meta["data"][i, j, :, 0], linewidth=0.5, color="k", label=f"{chn_name[j]}")
+            axes[j + shift].set_xticklabels([])
+            axes[j + shift].grid("on")
 
-        k = 6
+        k = 3 + shift
         axes[k].plot(phase[i, 1, :, 0], "b")
         axes[k].plot(phase[i, 2, :, 0], "r")
         axes[k].plot(meta["phase_pick"][i, 1, :, 0], "--C0")
@@ -198,9 +238,9 @@
         if "RANK" in os.environ:
             rank = int(os.environ["RANK"])
-            fig.savefig(f"{figure_dir}/{epoch:02d}_{rank:02d}_{i:02d}.png", dpi=300)
+            fig.savefig(f"{figure_dir}/{prefix}{epoch:02d}_{rank:02d}_{i:02d}.png", dpi=300)
         else:
-            fig.savefig(f"{figure_dir}/{epoch:02d}_{i:02d}.png", dpi=300)
+            fig.savefig(f"{figure_dir}/{prefix}{epoch:02d}_{i:02d}.png", dpi=300)
 
         if i >= 20:
             break
@@ -314,7 +354,7 @@
     plt.close(fig)
 
 
-def visualize_eqnet_train(meta, phase, event, epoch, figure_dir="figures"):
+def plot_eqnet_train(meta, phase, event, epoch, figure_dir="figures"):
     for i in range(meta["data"].shape[0]):
         plt.close("all")
         fig, axes = plt.subplots(3, 1, figsize=(10, 10))
diff --git a/train.py b/train.py
index 9c00720..1a49f75 100644
--- a/train.py
+++ b/train.py
@@ -19,7 +19,6 @@
 import utils
 from eqnet.data import (
     AutoEncoderIterableDataset,
-    DASDataset,
     DASIterableDataset,
     SeismicNetworkIterableDataset,
     SeismicTraceIterableDataset,
@@ -31,30 +30,30 @@
 logger = logging.getLogger("EQNet")
 
 
-def evaluate(model, data_loader, scaler, args, epoch=0, total_sample=1):
+def evaluate(model, data_loader, scaler, args, epoch=0, total_samples=1):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = f"Test: "
-    num_processed_samples = 0
+    processed_samples = 0
     with torch.inference_mode():
         for meta in metric_logger.log_every(data_loader, args.print_freq, header):
             output = model(meta)
             loss = output["loss"]
             batch_size = meta["data"].shape[0]
             metric_logger.meters["loss"].update(loss.item(), n=batch_size)
-            num_processed_samples += batch_size
-            if num_processed_samples > total_sample:
+            processed_samples += batch_size
+            if processed_samples > total_samples:
                 break
 
-            plot_results(meta, model, output, args, epoch, "test_")
-            del meta, output, loss
-
     metric_logger.synchronize_between_processes()
     print(f"Test loss = {metric_logger.loss.global_avg:.3e}")
     if args.wandb and utils.is_main_process():
         wandb.log({"test/test_loss": metric_logger.loss.global_avg, "test/epoch": epoch})
 
+    plot_results(meta, model, output, args, epoch, "test_")
+    del meta, output, loss
+
     return metric_logger
 
 
@@ -67,25 +66,29 @@
     scaler,
     args,
     epoch,
-    total_sample,
+    total_samples,
 ):
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
-    if args.model == "phasenet":
+    if args.model == "phasenet_plus":
         metric_logger.add_meter("loss_phase", utils.SmoothedValue(window_size=1, fmt="{value}"))
         metric_logger.add_meter("loss_event", utils.SmoothedValue(window_size=1, fmt="{value}"))
         metric_logger.add_meter("loss_polarity", utils.SmoothedValue(window_size=1, fmt="{value}"))
     header = f"Epoch: [{epoch}]"
 
-    # ctx = nullcontext() if scaler is None else torch.amp.autocast(device_type=args.device, dtype=args.ptdtype)
-    ctx = nullcontext() if args.device == "cpu" else torch.amp.autocast(device_type=args.device, dtype=args.ptdtype)
+    ctx = (
+        nullcontext()
+        if args.device in ["cpu", "mps"]
+        else torch.amp.autocast(device_type=args.device, dtype=args.ptdtype)
+    )
 
     model.train()
-    num_processed_samples = 0
+    processed_samples = 0
     for i, meta in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
         with ctx:
             output = model(meta)
             loss = output["loss"]
+
         optimizer.zero_grad()
         if scaler is not None:
             scaler.scale(loss).backward()
@@ -107,13 +110,15 @@
             model_ema.n_averaged.fill_(0)
 
         batch_size = meta["data"].shape[0]
-        num_processed_samples += batch_size
-        if num_processed_samples >= total_sample:
+        processed_samples += batch_size
+        if processed_samples >= total_samples:
             break
-        # break
-
-        metric_logger.update(lr=optimizer.param_groups[0]["lr"], loss=loss.item())
+        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
+        if args.model == "phasenet_plus":
+            metric_logger.update(loss_phase=output["loss_phase"].item())
+            metric_logger.update(loss_event=output["loss_event"].item())
+            metric_logger.update(loss_polarity=output["loss_polarity"].item())
         if args.wandb and utils.is_main_process():
             wandb.log(
                 {
@@ -132,36 +137,43 @@
 def plot_results(meta, model, output, args, epoch, prefix=""):
     with torch.inference_mode():
         if args.model == "phasenet":
+            phase = torch.softmax(output["phase"], dim=1).cpu().float()
+            meta["data"] = moving_normalize(meta["data"])
+            print("Plotting...")
+            eqnet.utils.plot_phasenet_train(meta, phase, epoch=epoch, figure_dir=args.figure_dir, prefix=prefix)
+            del phase
+
+        elif args.model == "phasenet_plus":
             phase = torch.softmax(output["phase"], dim=1).cpu().float()
             event = torch.sigmoid(output["event"]).cpu().float()
             polarity = torch.sigmoid(output["polarity"]).cpu().float()
-            # meta["raw"] = meta["data"].clone()
             meta["data"] = moving_normalize(meta["data"])
             print("Plotting...")
-            eqnet.utils.visualize_phasenet_train(meta, phase, event, polarity, epoch=epoch, figure_dir=args.figure_dir)
+            eqnet.utils.plot_phasenet_plus_train(
+                meta, phase, event, polarity, epoch=epoch, figure_dir=args.figure_dir, prefix=prefix
+            )
             del phase, event, polarity
 
-        if args.model == "deepdenoiser":
+        elif args.model == "deepdenoiser":
             pass
 
         elif args.model == "phasenet_das":
             phase = torch.softmax(output["phase"], dim=1).cpu().float()
             meta["data"] = moving_normalize(meta["data"], filter=2048, stride=256)
             print("Plotting...")
-            eqnet.utils.visualize_das_train(meta, phase, epoch=epoch, figure_dir=args.figure_dir, prefix=prefix)
+            eqnet.utils.plot_das_train(meta, phase, epoch=epoch, figure_dir=args.figure_dir, prefix=prefix)
             del phase
 
         elif args.model == "autoencoder":
             preds = model(meta)
             print("Plotting...")
-            eqnet.utils.visualize_autoencoder_das_train(meta, preds, epoch=epoch, figure_dir=args.figure_dir)
+            eqnet.utils.plot_autoencoder_das_train(meta, preds, epoch=epoch, figure_dir=args.figure_dir)
             del preds
 
         elif args.model == "eqnet":
             phase = F.softmax(output["phase"], dim=1).cpu().float()
             event = torch.sigmoid(output["event"]).cpu().float()
             print("Plotting...")
-            eqnet.utils.visualize_eqnet_train(meta, phase, event, epoch=epoch, figure_dir=args.figure_dir)
+            eqnet.utils.plot_eqnet_train(meta, phase, event, epoch=epoch, figure_dir=args.figure_dir)
             del phase, event
 
@@ -186,8 +198,9 @@
     np.random.seed(1337 + rank)
 
     device = torch.device(args.device)
-    dtype = "bfloat16" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "float16"
+    dtype = "bfloat16" if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else "float16"
     ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
+    scaler = torch.cuda.amp.GradScaler(enabled=((dtype == "float16") and torch.cuda.is_available()))
     args.dtype, args.ptdtype = dtype, ptdtype
     torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
@@ -197,7 +210,7 @@
     else:
         torch.backends.cudnn.benchmark = True
 
-    if args.model == "phasenet":
+    if args.model in ["phasenet", "phasenet_plus"]:
         dataset = SeismicTraceIterableDataset(
             data_path=args.data_path,
             data_list=args.data_list,
@@ -211,7 +224,18 @@
             world_size=world_size,
         )
         train_sampler = None
-        dataset_test = dataset
+        dataset_test = SeismicTraceIterableDataset(
+            data_path=args.test_data_path,
+            data_list=args.test_data_list,
+            hdf5_file=args.test_hdf5_file,
+            format="h5",
+            training=True,
+            stack_event=False,
+            flip_polarity=False,
+            drop_channel=False,
+            rank=rank,
+            world_size=world_size,
+        )
         test_sampler = None
     elif args.model == "phasenet_das":
         dataset = DASIterableDataset(
@@ -292,8 +316,8 @@
         else:
             dataset = SeismicNetworkIterableDataset(args.dataset)
             train_sampler = None
-            test_sampler = None
             dataset_test = dataset
+            test_sampler = None
     else:
         raise ("Unknown model")
 
@@ -303,7 +327,6 @@
     data_loader = torch.utils.data.DataLoader(
         dataset,
         batch_sampler=train_batch_sampler,
         num_workers=args.workers,
     )
-
     data_loader_test = torch.utils.data.DataLoader(
         dataset_test,
         batch_size=1,
@@ -337,9 +360,9 @@
         backbone=args.backbone,
         in_channels=1,
         out_channels=(len(args.phases) + 1),
-        ## phasenet-das
+        ## phasenet_das
         reg=args.reg,
-        ## phasenet
+        ## phasenet_plus
         polarity_loss_weight=args.polarity_loss_weight,
     )
     logger.info("Model:\n{}".format(model))
@@ -410,9 +433,6 @@
     else:
         raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
 
-    # scaler = torch.cuda.amp.GradScaler() if args.amp else None
-    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))
-
     iters_per_epoch = len(data_loader)
     args.lr_scheduler = args.lr_scheduler.lower()
     if args.lr_scheduler == "steplr":
@@ -496,13 +516,13 @@
             epoch,
             len(dataset),
         )
-        print(f"Training time of epoch {epoch} of rank {args.rank}: {time.time() - tmp_time:.3f}")
+        print(f"Training time of epoch {epoch} of rank {rank}: {time.time() - tmp_time:.3f}")
 
         tmp_time = time.time()
         metric = evaluate(model, data_loader_test, scaler, args, epoch, len(dataset_test))
         if model_ema:
             metric = evaluate(model_ema, data_loader_test, scaler, args, epoch, len(dataset_test))
-        print(f"Testing time of epoch {epoch} of rank {args.rank}: {time.time() - tmp_time:.3f}")
+        print(f"Testing time of epoch {epoch} of rank {rank}: {time.time() - tmp_time:.3f}")
 
         tmp_time = time.time()
         checkpoint = {
@@ -553,6 +573,7 @@
     parser.add_argument("--test-label-path", default="+", type=None, help="test label path")
     parser.add_argument("--test-label-list", default="+", type=None, help="test label path")
     parser.add_argument("--test-noise-list", default="+", type=None, help="test noise list")
+    parser.add_argument("--test-hdf5-file", default=None, type=str, help="hdf5 file for testing")
     parser.add_argument("--dataset", default="", type=str, help="dataset name")
     parser.add_argument("--model", default="phasenet_das", type=str, help="model name")
     parser.add_argument("--backbone", default="unet", type=str, help="model backbone")
diff --git a/train.yaml b/train.yaml
new file mode 100644
index 0000000..16c71ea
--- /dev/null
+++ b/train.yaml
@@ -0,0 +1,86 @@
+name: quakeflow
+
+workdir: .
+
+num_nodes: 1
+
+resources:
+
+  cloud: gcp
+
+  region: us-west1
+
+  zone: us-west1-b
+
+  # instance_type: n2-highmem-16
+
+  accelerators: V100:1
+
+  cpus: 8+
+
+  disk_size: 500
+
+  use_spot: True
+
+  # image_id: docker:zhuwq0/quakeflow:latest
+
+envs:
+  JOB: quakeflow
+  NCPU: 1
+  ROOT_PATH: /data
+  MODEL_NAME: phasenet
+
+file_mounts:
+
+  # /data:
+  #   # source: s3://scedc-pds/
+  #   # source: s3://ncedc-pds/
+  #   source: gs://quakeflow_dataset/
+  #   # source: gs://quakeflow_share/
+  #   # source: gs://das_arcata/
+  #   mode: MOUNT
+
+  # /dataset:
+  #   source: gs://quakeflow_dataset
+  #   mode: MOUNT
+
+  # /dataset/waveform_ps:
+  #   source: gs://quakeflow_dataset/NC/waveform_ps
+  #   mode: COPY
+
+  /dataset/train.h5:
+    source: gs://quakeflow_dataset/NC/waveform_ps/2021.h5
+    mode: COPY
+
+  /dataset/test.h5:
+    source: gs://quakeflow_dataset/NC/waveform_ps/2022.h5
+    mode: COPY
+
+  ~/.ssh/id_rsa.pub: ~/.ssh/id_rsa.pub
+  ~/.ssh/id_rsa: ~/.ssh/id_rsa
+  ~/.config/rclone/rclone.conf: ~/.config/rclone/rclone.conf
+
+setup: |
+  echo "Begin setup."
+  # sudo apt install rclone
+  pip install fsspec gcsfs
+  pip install obspy pyproj
+  pip install h5py tqdm wandb
+  # pip install torch torchvision torchaudio
+  pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
+
+run: |
+  num_nodes=`echo "$SKYPILOT_NODE_IPS" | wc -l`
+  master_addr=`echo "$SKYPILOT_NODE_IPS" | head -n1`
+  [[ ${SKYPILOT_NUM_GPUS_PER_NODE} -gt $NCPU ]] && nproc_per_node=${SKYPILOT_NUM_GPUS_PER_NODE} || nproc_per_node=$NCPU
+  if [ "${SKYPILOT_NODE_RANK}" == "0" ]; then
+    ls -al /dataset
+  fi
+  torchrun \
+    --nproc_per_node=${nproc_per_node} \
+    --node_rank=${SKYPILOT_NODE_RANK} \
+    --nnodes=$num_nodes \
+    --master_addr=$master_addr \
+    --master_port=8008 \
+    train.py --model $MODEL_NAME --batch-size=256 --hdf5-file /dataset/train.h5 --test-hdf5-file /dataset/test.h5 \
+    --stack-event --flip-polarity --drop-channel --output model_$MODEL_NAME
\ No newline at end of file
diff --git a/utils.py b/utils.py
index 026f41e..75c9278 100644
--- a/utils.py
+++ b/utils.py
@@ -4,7 +4,7 @@
 import hashlib
 import os
 import time
-from collections import defaultdict, deque, OrderedDict
+from collections import OrderedDict, defaultdict, deque
 from typing import List, Optional, Tuple
 
 import torch
@@ -150,7 +150,7 @@ def log_every(self, iterable, print_freq, header=None):
                             time=str(iter_time),
                             data=str(data_time),
                             memory=torch.cuda.max_memory_allocated() / MB,
-                            datetime=time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime()),
+                            datetime=datetime.datetime.now().strftime("%m-%dT%H:%M:%S"),
                         )
                     )
                 else:
@@ -162,7 +162,7 @@
                             meters=str(self),
                             time=str(iter_time),
                             data=str(data_time),
-                            datetime=time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime()),
+                            datetime=datetime.datetime.now().strftime("%m-%dT%H:%M:%S"),
                         )
                     )
             i += 1
@@ -306,8 +306,7 @@ def average_checkpoints(inputs):
     for fpath in inputs:
         with open(fpath, "rb") as f:
             state = torch.load(
-                f,
-                map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")),
+                f, map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")), weights_only=True
             )
         # Copies over the settings from the first checkpoint
         if new_state is None:
@@ -386,7 +385,7 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T
     # Deep copy to avoid side effects on the model object.
     model = copy.deepcopy(model)
 
-    checkpoint = torch.load(checkpoint_path, map_location="cpu")
+    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
 
     # Load the weights to the model to validate that everything works
     # and remove unnecessary weights (such as auxiliaries, etc.)
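
Usage sketch (not part of the patch): a minimal smoke test of the new eqnet.models.phasenet_plus entry point. The (batch, channel, time, station) layout follows the indexing used in plot_phasenet_plus_train; the batch size and nt = 3072 are illustrative, and it is assumed the picker heads return a zero loss when their label tensors are absent, as the optional-key handling in PhaseNet.forward() suggests.

    import torch
    from eqnet.models import phasenet_plus

    # PhaseNet+ turns on the event and polarity heads by default.
    model = phasenet_plus.build_model(backbone="unet")
    model.eval()

    # Only "data" is required at inference time; forward() maps missing
    # label keys ("phase_pick", "event_center", ...) to None.
    batch = {"data": torch.randn(2, 3, 3072, 1)}  # (batch, channel, time, station)
    with torch.inference_mode():
        out = model(batch)

    # out["phase"] holds logits; train.py applies softmax/sigmoid before plotting.
    print(out["phase"].shape, out["event"].shape, out["polarity"].shape)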
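
The reworked mask_width in seismic_trace.py is half the minimum over all pairwise S - P index differences, floored at 100 samples. On traces with more than one event the pairwise minimum includes cross-event pairs and can go negative, in which case the floor takes over. A toy check with illustrative pick indices:

    import numpy as np

    p_idx = np.array([1000, 5000])  # P pick indices (illustrative)
    s_idx = np.array([1400, 5600])  # S pick indices (illustrative)

    # Pairwise S - P differences; the min includes cross-event pairs
    # such as 1400 - 5000 = -3600, so it can be negative.
    mask_width = np.min(s_idx[:, np.newaxis] - p_idx[np.newaxis, :])
    mask_width = max(100, int(mask_width / 2.0))
    print(mask_width)  # 100 here; a single 400-sample S-P pair would give 200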
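
The HDF5 key caching added to SeismicTraceIterableDataset.__init__ can also be exercised standalone. A sketch assuming an HDF5 file laid out as event/station groups (the file name is illustrative):

    import os
    import h5py
    from tqdm import tqdm

    hdf5_file = "train.h5"  # illustrative path
    cache = f"/tmp/{hdf5_file.split('/')[-1]}.txt"

    if not os.path.exists(cache):
        # SWMR read mode, matching how sample_train() reopens the file per worker.
        with h5py.File(hdf5_file, "r", libver="latest", swmr=True) as fp:
            keys = [f"{event}/{sta}" for event in tqdm(list(fp.keys())) for sta in fp[event].keys()]
        with open(cache, "w") as f:
            f.write("\n".join(keys))
    else:
        with open(cache) as f:
            keys = f.read().splitlines()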