Update train.yaml and train.py, and add phasenet_plus.ipynb
zhuwq0 committed Jan 2, 2024
1 parent 343d3e2 commit 5bec875
Showing 8 changed files with 533 additions and 141 deletions.
149 changes: 149 additions & 0 deletions docs/phasenet_plus.ipynb
@@ -0,0 +1,149 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import obspy\n",
"import torch\n",
"import os\n",
"from glob import glob"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Download event data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"!wget -q https://github.com/AI4EPS/PhaseNet/releases/download/test_data/test_data.zip -O test_data.zip\n",
"!unzip -q -o test_data.zip\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run PhaseNet-Plus"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"mseed_list = glob('test_data/mseed/*.mseed')\n",
"with open(\"mseed_list.txt\", \"w\") as f:\n",
" f.write(\"\\n\".join(mseed_list))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"ngpu = torch.cuda.device_count()\n",
"base_cmd = \"../predict.py --model phasenet_plus --data_list mseed_list.txt --result_path ./results --format=mseed --batch_size 1 --workers 1\"\n",
"\n",
"plot_figure = True\n",
"if plot_figure:\n",
" base_cmd += \" --plot_figure\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"python ../predict.py --model phasenet_plus --data_list mseed_list.txt --result_path ./results --format=mseed --batch_size 1 --workers 1 --plot_figure --device cpu\n",
"Not using distributed mode\n",
"Namespace(model='phasenet_plus', resume='', backbone='unet', phases=['P', 'S'], device='cpu', workers=1, batch_size=1, use_deterministic_algorithms=False, amp=False, world_size=1, dist_url='env://', data_path='./', data_list='mseed_list.txt', hdf5_file=None, prefix='', format='mseed', dataset='das', result_path='./results', plot_figure=True, min_prob=0.3, add_polarity=False, add_event=False, highpass_filter=0.0, response_xml=None, folder_depth=0, cut_patch=False, nt=20480, nx=5120, resample_time=False, resample_space=False, system=None, location=None, skip_existing=False, distributed=False)\n",
"Total samples: ./.mseed : 16 files\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Predicting: 0%| | 0/16 [00:00<?, ?it/s]"
]
}
],
"source": [
"if ngpu == 0:\n",
" cmd = f\"python {base_cmd} --device cpu\"\n",
"elif ngpu == 1:\n",
" cmd = f\"python {base_cmd}\"\n",
"else:\n",
" cmd = f\"torchrun --nproc_per_node {ngpu} {base_cmd}\"\n",
"\n",
"print(cmd)\n",
"os.system(cmd);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n",
"fig, axes = plt.subplots(4, 4, figsize=(12, 12))\n",
"axes = axes.flatten()\n",
"\n",
"for i, f in enumerate(glob('results/figures_phasenet_plus/*.png')):\n",
" img = Image.open(f) \n",
" axes[i].imshow(img)\n",
" axes[i].axis('off')\n",
" if i >= 15:\n",
" break\n",
"plt.tight_layout()\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
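The notebook imports obspy but never calls it in the cells above. A small sanity check along the following lines (not part of the committed notebook; the path assumes the test_data download succeeded) is one way to use it: read a downloaded miniSEED file and list its traces before running the picker.

import obspy
from glob import glob

# Inspect one of the downloaded miniSEED files before picking (illustrative only).
mseed_list = sorted(glob("test_data/mseed/*.mseed"))
st = obspy.read(mseed_list[0])
print(st)  # one line per trace: net.sta.loc.chan, start/end times, sampling rate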
18 changes: 11 additions & 7 deletions eqnet/data/seismic_trace.py
@@ -288,8 +288,10 @@ def drop_channel(meta):
class SeismicTraceIterableDataset(IterableDataset):
degree2km = 111.32
nt = 4096 ## 8992
feature_scale = 16
feature_nt = nt // feature_scale
event_feature_scale = 16
polarity_feature_scale = 4
event_feature_nt = nt // event_feature_scale
polarity_feature_nt = nt // polarity_feature_scale

def __init__(
self,
@@ -552,6 +554,8 @@ def read_training_h5(self, trace_id, hdf5_fp=None):
tmp_max = np.max(attrs["phase_index"][attrs["event_id"] == e]).item()
duration.append([tmp_min, max(tmp_min + 3, tmp_max + 2 * (tmp_max - tmp_min))])

if e not in hdf5_fp:
continue
if len(attrs["phase_index"][attrs["event_id"] == e]) <= 1: # need both P and S
continue
c0.append(np.mean(attrs["phase_index"][attrs["event_id"] == e]).item())
@@ -661,11 +665,11 @@ def sample_train(self, data_list):
# waveform = normalize(waveform)
phase_pick = meta["phase_pick"]
phase_mask = meta["phase_mask"][np.newaxis, ::]
event_center = meta["event_center"][np.newaxis, :: self.feature_scale]
polarity = meta["polarity"][np.newaxis, ::]
polarity_mask = meta["polarity_mask"][np.newaxis, ::]
event_time = meta["event_time"][np.newaxis, :: self.feature_scale]
event_mask = meta["event_mask"][np.newaxis, :: self.feature_scale]
event_center = meta["event_center"][np.newaxis, :: self.event_feature_scale]
polarity = meta["polarity"][np.newaxis, :: self.polarity_feature_scale]
polarity_mask = meta["polarity_mask"][np.newaxis, :: self.polarity_feature_scale]
event_time = meta["event_time"][np.newaxis, :: self.event_feature_scale]
event_mask = meta["event_mask"][np.newaxis, :: self.event_feature_scale]
station_location = meta["station_location"]

yield {
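The net effect of splitting the old feature_scale into event_feature_scale = 16 and polarity_feature_scale = 4 is that the event targets are now downsampled 16x in time while the polarity targets are downsampled only 4x. A minimal numpy sketch of the strided slicing used in sample_train, with synthetic arrays for illustration only:

import numpy as np

nt = 4096                    # trace length used by SeismicTraceIterableDataset
event_feature_scale = 16     # event targets: nt // 16 = 256 samples
polarity_feature_scale = 4   # polarity targets: nt // 4 = 1024 samples

event_center = np.zeros(nt)
polarity = np.zeros(nt)

event_center_ds = event_center[np.newaxis, ::event_feature_scale]
polarity_ds = polarity[np.newaxis, ::polarity_feature_scale]

print(event_center_ds.shape)  # (1, 256)
print(polarity_ds.shape)      # (1, 1024)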
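The added `if e not in hdf5_fp: continue` guard in read_training_h5 simply skips event ids that have no corresponding entry in the open HDF5 file. A minimal sketch of that membership check with h5py, using a synthetic file and group names rather than the repository's actual layout:

import h5py

# Build a tiny file with one event group, then iterate over candidate event ids.
with h5py.File("tmp_events.h5", "w") as fp:
    fp.create_group("event_001")

with h5py.File("tmp_events.h5", "r") as hdf5_fp:
    for e in ["event_001", "event_002"]:
        if e not in hdf5_fp:   # `in` tests top-level group/dataset names
            continue           # event_002 is skipped
        print("processing", e)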
73 changes: 35 additions & 38 deletions eqnet/models/unet.py
@@ -125,38 +125,6 @@ def __init__(
self.add_stft = add_stft
self.moving_norm = moving_norm
self.log_scale = log_scale
if self.add_polarity:
self.encoder1_polarity = nn.Sequential(
self.encoder_block(
1, features, kernel_size=kernel_size, stride=init_stride, padding=padding, name="enc1_polarity"
),
self.encoder_block(
features,
features * 2,
kernel_size=kernel_size,
stride=stride,
padding=padding,
name="enc2_polarity",
),
)
self.decoder1_polarity = nn.Sequential(
self.decoder_block(
features * 6,
features * 2,
kernel_size=kernel_size,
stride=stride,
padding=padding,
name="dec1_polarity",
),
self.encoder_block(
features * 2,
features,
kernel_size=kernel_size,
stride=init_stride,
padding=padding,
name="dec2_polarity",
),
)

self.input_conv = self.encoder_block(
in_channels, features, kernel_size=kernel_size, stride=init_stride, padding=padding, name="enc1"
@@ -236,6 +204,30 @@ def __init__(
]
)
)

if self.add_polarity:
self.encoder_polarity = self.encoder_block(
1, features, kernel_size=kernel_size, stride=stride, padding=padding, name="enc1_polarity"
)
self.output_polarity = nn.Sequential(
OrderedDict(
[
(
"output_polarity_conv",
nn.Conv2d(
in_channels=features * 5,
out_channels=features,
kernel_size=kernel_size,
padding=padding,
bias=False,
),
),
("output_polarity_norm", nn.BatchNorm2d(num_features=features)),
("output_polarity_relu", nn.ReLU(inplace=True)),
]
)
)

if self.add_event:
self.output_event = nn.Sequential(
OrderedDict(
@@ -255,6 +247,7 @@ def __init__(
]
)
)

if (init_stride[0] > 1) or (init_stride[1] > 1):
self.output_upsample = nn.Upsample(scale_factor=init_stride, mode="bilinear", align_corners=False)
else:
@@ -276,7 +269,8 @@ def forward(self, x):
# x = sgram

if self.add_polarity:
enc1_polarity = self.encoder1_polarity(x[:, -1:, :, :]) ## last channel is vertical component
enc_polarity = self.encoder_polarity(x[:, -1:, :, :]) ## last channel is vertical component

enc1 = self.input_conv(x)
enc2 = self.encoder12(enc1)
enc3 = self.encoder23(enc2)
@@ -298,20 +292,23 @@ def forward(self, x):
out_phase = out_phase[:, :, :nt, :nx]

if self.add_polarity:
# dec1_polarity = torch.cat((dec1, enc1_polarity), dim=1)
dec1_polarity = torch.cat((dec2, enc1_polarity), dim=1)
out_polarity = self.decoder1_polarity(dec1_polarity)
dec_polarity = torch.cat((dec2, enc_polarity), dim=1)
out_polarity = self.output_polarity(dec_polarity)
if self.output_upsample is not None:
out_polarity = self.output_upsample(out_polarity)
out_polarity = out_polarity[:, :, :nt, :nx]
out_polarity = out_polarity[:, :, :nt, :nx]
else:
out_polarity = out_polarity[:, :, : nt // 4, :nx]
else:
out_polarity = None

if self.add_event:
out_event = self.output_event(dec3)
if self.output_upsample is not None:
out_event = self.output_upsample(out_event)
out_event = out_event[:, :, :nt, :nx]
out_event = out_event[:, :, : nt // 4, :nx]
else:
out_event = out_event[:, :, : nt // 16, :nx]
else:
out_event = None
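In the reorganized polarity branch, the vertical component is encoded by a single encoder_block and concatenated with the dec2 decoder features before a Conv2d -> BatchNorm2d -> ReLU head with features * 5 input channels; when no input-stride upsampling is applied the result is cropped to nt // 4 in time, mirroring polarity_feature_scale = 4 in the data loader (the event head is likewise cropped to nt // 16). A shape-level sketch of that head, with toy tensor sizes and an assumed kernel shape rather than the model's actual hyperparameters:

import torch
import torch.nn as nn
from collections import OrderedDict

features, nt, nx = 8, 4096, 3   # toy sizes; the real model configures these elsewhere

enc_polarity = torch.randn(2, features, nt // 4, nx)   # encoded vertical component
dec2 = torch.randn(2, features * 4, nt // 4, nx)       # decoder features at the same resolution

output_polarity = nn.Sequential(OrderedDict([
    ("output_polarity_conv", nn.Conv2d(features * 5, features, kernel_size=(7, 1), padding=(3, 0), bias=False)),
    ("output_polarity_norm", nn.BatchNorm2d(features)),
    ("output_polarity_relu", nn.ReLU(inplace=True)),
]))

dec_polarity = torch.cat((dec2, enc_polarity), dim=1)  # features * 5 channels
out_polarity = output_polarity(dec_polarity)[:, :, : nt // 4, :nx]
print(out_polarity.shape)  # torch.Size([2, 8, 1024, 3])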
