
Commit

Merge remote-tracking branch 'origin/ncedc' into quakeflow
zhuwq0 committed Oct 23, 2024
2 parents d5a17ec + b6b9840 commit e20e1dd
Showing 3 changed files with 193 additions and 40 deletions.
6 changes: 4 additions & 2 deletions Dockerfile
@@ -1,4 +1,4 @@
FROM tensorflow/tensorflow
FROM tensorflow/tensorflow:2.14.0

# Create the environment:
# COPY env.yml /app
@@ -8,12 +8,14 @@ FROM tensorflow/tensorflow

RUN pip install tqdm obspy pandas
RUN pip install uvicorn fastapi
RUN pip install fsspec gcsfs s3fs

WORKDIR /opt

# Copy files
COPY phasenet /opt/phasenet
COPY model /opt/model
COPY application_default_credentials.json /opt/application_default_credentials.json

# Expose API port
EXPOSE 8000
@@ -22,4 +24,4 @@ ENV PYTHONUNBUFFERED=1

# Start API server
#ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "cs329s", "uvicorn", "--app-dir", "phasenet", "app:app", "--reload", "--port", "8000", "--host", "0.0.0.0"]
ENTRYPOINT ["uvicorn", "--app-dir", "phasenet", "app:app", "--reload", "--port", "8000", "--host", "0.0.0.0"]
# ENTRYPOINT ["uvicorn", "--app-dir", "phasenet", "app:app", "--reload", "--port", "8000", "--host", "0.0.0.0"]
140 changes: 126 additions & 14 deletions phasenet/data_reader.py
@@ -13,12 +13,22 @@
import random
from collections import defaultdict

# import s3fs
import fsspec
import h5py
import obspy
from obspy.clients.fdsn import Client
from scipy.interpolate import interp1d
from tqdm import tqdm

# token_json = f"{os.environ['HOME']}/.config/gcloud/application_default_credentials.json"
token_json = "application_default_credentials.json"
with open(token_json, "r") as fp:
token = json.load(fp)
fs_gs = fsspec.filesystem("gs", token=token)
# client = Client("SCEDC")
client = Client("NCEDC")
client_iris = Client("IRIS") ## HardCode: IRIS for response file


def py_func_decorator(output_types=None, output_shapes=None, name=None):
def decorator(func):
@@ -192,12 +202,51 @@ def __init__(
self.sampling_rate = sampling_rate
if format in ["numpy", "mseed", "sac"]:
self.data_dir = kwargs["data_dir"]
try:
csv = pd.read_csv(kwargs["data_list"], header=0, sep="[,|\s+]", engine="python")
except:
csv = pd.read_csv(kwargs["data_list"], header=0, sep="\t")
self.data_list = csv["fname"]
# try:
# csv = pd.read_csv(kwargs["data_list"], header=0, sep="[,|\s+]", engine="python")
# except:
# csv = pd.read_csv(kwargs["data_list"], header=0, sep="\t")

with open(kwargs["data_list"], "r") as fp:
lines = fp.read().splitlines()
print(f"Total sampel: {len(lines)}")

# filter = []
# processed = []

# # FIX: HardCode: check if picks exists
# bucket = "quakeflow_catalog"
# folder = "NC/phasenet"
# processed = fs_gs.glob(f"{bucket}/{folder}/**/*.csv")
# # networks = fs_gs.ls(f"{bucket}/{folder}/")
# # for i, network in enumerate(networks):
# # years = fs_gs.ls(f"{network}/")
# # for year in tqdm(years, desc=f"Check processed {network}"):
# # jdays = fs_gs.ls(f"{year}/")
# # for jday in jdays:
# # mseeds = fs_gs.glob(f"{jday}/*.{jday.split('/')[-1]}.csv")
# # # mseeds = fs_gs.ls(f"{jday}/")
# # processed.extend(mseeds)

# processed = set(processed)
# key_set = set()
# mapping_dit = {}
# for line in tqdm(lines, desc="Filter processed"):
# tmp = line.split(",")[0].lstrip("s3://").split("/")
# parant_dir = "/".join(tmp[2:-1])
# fname = tmp[-1].rstrip(".mseed") + ".csv"
# tmp_name = f"{bucket}/{folder}/{parant_dir}/{fname}"
# key_set.add(tmp_name)
# mapping_dit[tmp_name] = line
# key_set = list(key_set - processed)
# lines = sorted([mapping_dit[x] for x in key_set], reverse=True)
# print(f"Unprocessed sample {len(lines)}")

self.data_list = lines
self.num_data = len(self.data_list)

# del lines, filter, processed

elif format == "hdf5":
self.h5 = h5py.File(kwargs["hdf5_file"], "r", libver="latest", swmr=True)
self.h5_data = self.h5[kwargs["hdf5_group"]]
@@ -309,14 +358,64 @@ def read_s3(self, format, fname, bucket, key, secret, s3_url, use_ssl):

def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=100, return_single_station=True):
try:
stream = obspy.read(fname)
# stream = obspy.read(fname)
files = fname.rstrip("\n").split(",")
stream = obspy.Stream()
for file in files:
with fsspec.open(file, "rb", anon=True) as fp:
stream += obspy.read(fp)
stream = stream.merge(fill_value="latest")
if response is not None:
# response = obspy.read_inventory(response_xml)

## FIX: hard code for response file
## NCEDC
station, network, channel = files[0].split("/")[-1].split(".")[:3]
response_xml = f"gs://quakeflow_catalog/NC/FDSNstationXML/{network}/{network}.{station}.xml"
# response_xml = (
# f"gs://quakeflow_dataset/NC/FDSNstationXML/{network}.info/{network}.FDSN.xml/{network}.{station}.xml"
# )

## SCEDC
# fname = files[0].split("/")[-1]
# network = fname[:2]
# station = fname[2:7].rstrip("_")
# instrument = fname[7:9]
# channel = fname[9]
# location = fname[10:12].rstrip("_")
# year = fname[13:17]
# jday = fname[17:20]
# response_xml = f"gs://quakeflow_catalog/SC/FDSNstationXML/{network}/{network}_{station}.xml"

redownload = True
if fs_gs.exists(response_xml):
try:
with fs_gs.open(response_xml, "rb") as fp:
response = obspy.read_inventory(fp)
stream = stream.remove_sensitivity(response)
redownload = False
except Exception as e:
print(f"Error removing sensitivity: {e}")
else:
redownload = True
if redownload:
try:
response = client.get_stations(network=network, station=station, level="response")
except Exception as e:
print(f"Error downloading response: {e}")
print(f"Retry downloading response from IRIS...")
try:
response = client_iris.get_stations(network=network, station=station, level="response")
except Exception as e:
print(f"Error downloading response from IRIS: {e}")
raise
response.write(f"/tmp/{network}_{station}.xml", format="stationxml")
fs_gs.put(f"/tmp/{network}_{station}.xml", response_xml)
print(f"Update response file: {response_xml}")
stream = stream.remove_sensitivity(response)

except Exception as e:
print(f"Error reading {fname}:\n{e}")
return {}

tmp_stream = obspy.Stream()
for trace in stream:
if len(trace.data) < 10:
@@ -330,7 +429,8 @@ def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=10
except Exception as e:
print(f"Error resampling {trace.id}:\n{e}")

trace = trace.detrend("demean")
# trace = trace.detrend("demean")
trace = trace.detrend("spline", order=2, dspline=5 * trace.stats.sampling_rate)

## highpass filtering > 1Hz
if highpass_filter > 0.0:
@@ -361,18 +461,24 @@ def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=10

station_ids = defaultdict(list)
for tr in stream:
station_ids[tr.id[:-1]].append(tr.id[-1])
if tr.id[-1] not in comp:
print(f"Unknown component {tr.id[-1]}")
print(f"Unknown component {tr.id}")
continue
station_ids[tr.id[:-1]].append(tr.id[-1])

station_keys = sorted(list(station_ids.keys()))
if len(station_keys) == 0:
return {}

nx = len(station_ids)
nt = len(stream[0].data)
data = np.zeros([3, nt, nx], dtype=np.float32)
for i, sta in enumerate(station_keys):
for j, c in enumerate(sorted(station_ids[sta], key=lambda x: order[x])):
if len(station_ids[sta]) != 3: ## less than 3 component
if c not in comp2idx:
print(f"Unknown component {c}")
continue
j = comp2idx[c]

if len(stream.select(id=sta + c)) == 0:
@@ -838,9 +944,15 @@ def __getitem__(self, i):
if self.format == "numpy":
meta = self.read_numpy(os.path.join(self.data_dir, base_name))
elif (self.format == "mseed") or (self.format == "sac"):
# meta = self.read_mseed(
# os.path.join(self.data_dir, base_name),
# response=self.response,
# sampling_rate=self.sampling_rate,
# highpass_filter=self.highpass_filter,
# return_single_station=True,
# )
meta = self.read_mseed(
os.path.join(self.data_dir, base_name),
response=self.response,
base_name,
sampling_rate=self.sampling_rate,
highpass_filter=self.highpass_filter,
return_single_station=True,
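
As the read_mseed changes above imply, each line of data_list now holds one or more comma-separated cloud paths (typically the components of a single station-day) rather than a single local file name. A minimal sketch, with purely illustrative NCEDC-style S3 paths that are not taken from this commit, of how such a line becomes one merged obspy Stream:

import fsspec
import obspy

# Hypothetical data_list entry: three components of one station-day on the public
# ncedc-pds bucket (station/network/channel names here are placeholders).
line = (
    "s3://ncedc-pds/continuous_waveforms/NC/2023/2023.001/XXX.NC.HHE..D.2023.001,"
    "s3://ncedc-pds/continuous_waveforms/NC/2023/2023.001/XXX.NC.HHN..D.2023.001,"
    "s3://ncedc-pds/continuous_waveforms/NC/2023/2023.001/XXX.NC.HHZ..D.2023.001"
)

stream = obspy.Stream()
for file in line.rstrip("\n").split(","):
    # anon=True: the NCEDC archive bucket is publicly readable without AWS credentials.
    with fsspec.open(file, "rb", anon=True) as fp:
        stream += obspy.read(fp)

# Fill gaps between segments with the last valid sample, as read_mseed does.
stream = stream.merge(fill_value="latest")
print(stream)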
87 changes: 63 additions & 24 deletions phasenet/predict.py
@@ -1,17 +1,18 @@
import argparse
import json
import logging
import multiprocessing
import os
import pickle
import time
from functools import partial

import fsspec
import h5py
import numpy as np
import pandas as pd
import tensorflow as tf
from data_reader import DataReader_mseed_array, DataReader_pred
from model import ModelConfig, UNet
from postprocess import (
extract_amplitude,
extract_picks,
@@ -22,9 +23,17 @@
from tqdm import tqdm
from visulization import plot_waveform

from model import ModelConfig, UNet

tf.compat.v1.disable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# token_json = f"{os.environ['HOME']}/.config/gcloud/application_default_credentials.json"
token_json = "application_default_credentials.json"
with open(token_json, "r") as fp:
token = json.load(fp)
fs_gs = fsspec.filesystem("gs", token=token)


def read_args():
parser = argparse.ArgumentParser()
@@ -145,29 +154,59 @@ def pred_fn(args, data_reader, figure_dir=None, prob_dir=None, log_dir=None):
picks.extend(picks_)

# ## save pick per file
# if (len(fname_batch) == 1) & (len(picks_) > 0):
# df = pd.DataFrame(picks_)
# df = df[df["phase_index"] > 10]
# if not os.path.exists(os.path.join(args.result_dir, "picks")):
# os.makedirs(os.path.join(args.result_dir, "picks"))
# df = df[
# [
# "station_id",
# "begin_time",
# "phase_index",
# "phase_time",
# "phase_score",
# "phase_type",
# "phase_amplitude",
# "dt",
# ]
# ]
# df.to_csv(
# os.path.join(
# args.result_dir, "picks", fname_batch[0].decode().split("/")[-1].rstrip(".mseed") + ".csv"
# ),
# index=False,
# )

if len(fname_batch) == 1:
# ### FIX: Hard code for NCEDC and SCEDC
tmp = fname_batch[0].decode().split(",")[0].lstrip("s3://").split("/")
parant_dir = "/".join(tmp[2:-1]) # remove s3://ncedc-pds/continuous and mseed file name
fname = tmp[-1].rstrip("\n").rstrip(".mseed").rstrip(".ms") + ".csv"
csv_name = f"quakeflow_catalog/NC/phasenet/{parant_dir}/{fname}"
# csv_name = f"quakeflow_catalog/SC/phasenet/{parant_dir}/{fname}"
if not os.path.exists(os.path.join(args.result_dir, "picks", parant_dir)):
os.makedirs(os.path.join(args.result_dir, "picks", parant_dir), exist_ok=True)

if len(picks_) == 0:
with fs_gs.open(csv_name, "w") as fp:
fp.write("")
else:
df = pd.DataFrame(picks_)
df = df[df["phase_index"] > 10]
if len(df) == 0:
with fs_gs.open(csv_name, "w") as fp:
fp.write("")
else:
df["phase_amplitude"] = df["phase_amplitude"].apply(lambda x: f"{x:.3e}")
df = df[
[
"station_id",
"phase_time",
"phase_score",
"phase_type",
"phase_amplitude",
"begin_time",
"phase_index",
"dt",
]
]
df.sort_values(by=["phase_time"], inplace=True)
df.to_csv(
os.path.join(
args.result_dir,
"picks",
parant_dir,
fname,
),
index=False,
)
fs_gs.put(
os.path.join(
args.result_dir,
"picks",
parant_dir,
fname,
),
csv_name,
)

if args.plot_figure:
if not (isinstance(fname_batch, np.ndarray) or isinstance(fname_batch, list)):
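
For downstream use, the per-file pick CSVs that pred_fn uploads above can be read straight from the bucket. A brief usage sketch; the object path is hypothetical (real paths follow quakeflow_catalog/NC/phasenet/<parent dirs of the mseed>/<file>.csv), and the token is the same credentials file loaded at the top of predict.py:

import json

import pandas as pd

with open("application_default_credentials.json", "r") as fp:
    token = json.load(fp)

# Hypothetical pick file produced by pred_fn for one station-day.
csv_name = "gs://quakeflow_catalog/NC/phasenet/NC/2023/2023.001/XXX.NC.HHZ..D.2023.001.csv"

# pandas hands storage_options to fsspec/gcsfs; empty placeholder files written for
# pick-less waveforms would raise EmptyDataError and can simply be skipped.
picks = pd.read_csv(csv_name, storage_options={"token": token})
print(picks[["station_id", "phase_time", "phase_score", "phase_type", "phase_amplitude"]].head())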
