remove pymongo and save picks per file
zhuwq0 committed Oct 7, 2023
1 parent 460bfe1 commit 25b9617
Showing 6 changed files with 71 additions and 77 deletions.
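This commit removes the MongoDB/Kafka upload path (the pymongo, kafka-python, and minio dependencies) and, per the commit message, saves picks per input file instead. The actual writing code is not visible in the hunks below, so the following is only a minimal sketch of what a per-file dump could look like, assuming a `picks` list of dicts with the fields assembled in phasenet/postprocess.py (e.g. station_id, timestamp, type, phase_amplitude); the helper name and output layout are illustrative, not the repository's code:

import os

import pandas as pd


def save_picks_per_file(picks, fname, result_dir="results/picks"):
    # Illustrative sketch: write the picks extracted from one waveform file
    # to its own CSV, named after the input file.
    os.makedirs(result_dir, exist_ok=True)
    base = os.path.splitext(os.path.basename(fname))[0]
    csv_path = os.path.join(result_dir, f"{base}.csv")
    pd.DataFrame(picks).to_csv(csv_path, index=False)
    return csv_path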
1 change: 0 additions & 1 deletion Dockerfile
@@ -7,7 +7,6 @@ FROM tensorflow/tensorflow
# SHELL ["conda", "run", "-n", "cs329s", "/bin/bash", "-c"]

RUN pip install tqdm obspy pandas
RUN pip install minio pymongo kafka-python
RUN pip install uvicorn fastapi

WORKDIR /opt
4 changes: 0 additions & 4 deletions env.yml
@@ -13,9 +13,5 @@ dependencies:
- obspy
- uvicorn
- fastapi
- kafka-python
- tensorflow
- keras
- pymongo


64 changes: 33 additions & 31 deletions phasenet/data_reader.py
@@ -10,14 +10,14 @@

pd.options.mode.chained_assignment = None
import json
import random
from collections import defaultdict

# import s3fs
import h5py
import obspy
from scipy.interpolate import interp1d
from tqdm import tqdm
from collections import defaultdict
import random


def py_func_decorator(output_types=None, output_shapes=None, name=None):
@@ -148,7 +148,6 @@ def normalize_batch(data, window=3000):


class DataConfig:

seed = 123
use_seed = True
n_channel = 3
@@ -168,7 +167,9 @@ def __init__(self, **kwargs):


class DataReader:
def __init__(self, format="numpy", config=DataConfig(), response_xml=None, sampling_rate=100, highpass_filter=0, **kwargs):
def __init__(
self, format="numpy", config=DataConfig(), response_xml=None, sampling_rate=100, highpass_filter=0, **kwargs
):
self.buffer = {}
self.n_channel = config.n_channel
self.n_class = config.n_class
@@ -302,9 +303,7 @@ def read_s3(self, format, fname, bucket, key, secret, s3_url, use_ssl):
raise (f"Format {format} not supported")
return meta


def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=100, return_single_station=True):

try:
stream = obspy.read(fname)
stream = stream.merge(fill_value="latest")
@@ -316,7 +315,6 @@ def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=10
return {}
tmp_stream = obspy.Stream()
for trace in stream:

if len(trace.data) < 10:
continue

@@ -344,9 +342,18 @@ def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=10
end_time = max([st.stats.endtime for st in stream])
stream = stream.trim(begin_time, end_time, pad=True, fill_value=0)

comp = ["3", "2", "1", "E", "N", "Z"]
comp = ["3", "2", "1", "E", "N", "U", "V", "Z"]
order = {key: i for i, key in enumerate(comp)}
comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2} ## only for cases less than 3 components
comp2idx = {
"3": 0,
"2": 1,
"1": 2,
"E": 0,
"N": 1,
"Z": 2,
"U": 0,
"V": 1,
} ## only for cases less than 3 components

station_ids = defaultdict(list)
for tr in stream:
@@ -360,9 +367,7 @@ def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=10
nt = len(stream[0].data)
data = np.zeros([3, nt, nx], dtype=np.float32)
for i, sta in enumerate(station_keys):

for j, c in enumerate(sorted(station_ids[sta], key=lambda x: order[x])):

if len(station_ids[sta]) != 3: ## less than 3 component
j = comp2idx[c]

@@ -378,17 +383,20 @@ def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=10

tmp = trace.data.astype("float32")
data[j, : len(tmp), i] = tmp[:nt]

# if return_single_station and (len(station_keys) > 1):
# print(f"Warning: {fname} has multiple stations, returning only the first one {station_keys[0]}")
# data = data[:, :, 0:1]
# station_keys = station_keys[0:1]

meta = {"data": data.transpose([1, 2, 0]), "t0": begin_time.datetime.isoformat(timespec="milliseconds"), "station_id": station_keys}
meta = {
"data": data.transpose([1, 2, 0]),
"t0": begin_time.datetime.isoformat(timespec="milliseconds"),
"station_id": station_keys,
}
return meta

def read_sac(self, fname):

mseed = obspy.read(fname)
mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
mseed = mseed.merge(fill_value=0)
@@ -422,7 +430,6 @@ def read_sac(self, fname):
return meta

def read_mseed_array(self, fname, stations, amplitude=False, remove_resp=True):

data = []
station_id = []
t0 = []
@@ -485,7 +492,6 @@ def read_mseed_array(self, fname, stations, amplitude=False, remove_resp=True):
resp = stations[sta]["response"]

for j, c in enumerate(sorted(comp, key=lambda x: order[x[-1]])):

resp_j = resp[j]
if len(comp) != 3: ## less than 3 component
j = comp2idx[c]
@@ -625,7 +631,6 @@ def random_shift(self, sample, itp, its, itp_old=None, its_old=None, shift_range
return shifted_sample, shift_pick(itp, shift), shift_pick(its, shift), shift

def stack_events(self, sample_old, itp_old, its_old, shift_range=None, mask_old=None):

i = np.random.randint(self.num_data)
base_name = self.data_list[i]
if self.format == "numpy":
@@ -663,7 +668,6 @@ def cut_window(self, sample, target, itp, its, select_range):

class DataReader_train(DataReader):
def __init__(self, format="numpy", config=DataConfig(), **kwargs):

super().__init__(format=format, config=config, **kwargs)

self.min_event_gap = config.min_event_gap
@@ -672,7 +676,6 @@ def __init__(self, format="numpy", config=DataConfig(), **kwargs):
self.select_range = [5000, 8000]

def __getitem__(self, i):

base_name = self.data_list[i]
if self.format == "numpy":
meta = self.read_numpy(os.path.join(self.data_dir, base_name))
@@ -716,13 +719,11 @@ def dataset(self, batch_size, num_parallel_calls=2, shuffle=True, drop_remainder

class DataReader_test(DataReader):
def __init__(self, format="numpy", config=DataConfig(), **kwargs):

super().__init__(format=format, config=config, **kwargs)

self.select_range = [5000, 8000]

def __getitem__(self, i):

base_name = self.data_list[i]
if self.format == "numpy":
meta = self.read_numpy(os.path.join(self.data_dir, base_name))
@@ -756,7 +757,6 @@ def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainde

class DataReader_pred(DataReader):
def __init__(self, format="numpy", amplitude=True, config=DataConfig(), **kwargs):

super().__init__(format=format, config=config, **kwargs)

self.amplitude = amplitude
@@ -769,25 +769,30 @@ def adjust_missingchannels(self, data):
return data

def __getitem__(self, i):

base_name = self.data_list[i]

if self.format == "numpy":
meta = self.read_numpy(os.path.join(self.data_dir, base_name))
elif (self.format == "mseed") or (self.format == "sac"):
meta = self.read_mseed(os.path.join(self.data_dir, base_name), response=self.response, sampling_rate=self.sampling_rate, highpass_filter=self.highpass_filter, return_single_station=True)
meta = self.read_mseed(
os.path.join(self.data_dir, base_name),
response=self.response,
sampling_rate=self.sampling_rate,
highpass_filter=self.highpass_filter,
return_single_station=True,
)
elif self.format == "hdf5":
meta = self.read_hdf5(base_name)
else:
raise (f"{self.format} does not support!")

if "data" in meta:
raw_amp = meta["data"].copy()
sample = normalize_long(meta["data"])
else:
raw_amp = np.zeros([3000, 1, 3], dtype=np.float32)
sample = np.zeros([3000, 1, 3], dtype=np.float32)

if "t0" in meta:
t0 = meta["t0"]
else:
@@ -816,15 +821,15 @@ def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainde
dataset = dataset_map(
self,
output_types=(self.dtype, self.dtype, "string", "string", "string"),
output_shapes=([None,None,3], [None,None,3], None, None, None),
output_shapes=([None, None, 3], [None, None, 3], None, None, None),
num_parallel_calls=num_parallel_calls,
shuffle=shuffle,
)
else:
dataset = dataset_map(
self,
output_types=(self.dtype, "string", "string", "string"),
output_shapes=([None,None,3], None, None, None),
output_shapes=([None, None, 3], None, None, None),
num_parallel_calls=num_parallel_calls,
shuffle=shuffle,
)
@@ -834,7 +839,6 @@ def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainde

class DataReader_mseed_array(DataReader):
def __init__(self, stations, amplitude=True, remove_resp=True, config=DataConfig(), **kwargs):

super().__init__(format="mseed", config=config, **kwargs)

# self.stations = pd.read_json(stations)
@@ -852,7 +856,6 @@ def get_data_shape(self):
return meta["data"].shape

def __getitem__(self, i):

fp = os.path.join(self.data_dir, self.data_list[i])
# try:
meta = self.read_mseed_array(fp, self.stations, self.amplitude, self.remove_resp)
@@ -1004,5 +1007,4 @@ def read(data_reader, batch=1):


if __name__ == "__main__":

test_DataReader()
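For orientation, below is a minimal usage sketch of the prediction reader whose signature is reformatted in the hunks above. The data_dir/data_list keyword names and the form of data_list are assumptions based on how the attributes are referenced in __getitem__, not values taken from this diff:

# Illustrative only: build the prediction reader for miniSEED input and get
# the tf.data pipeline it exposes. Keyword names below are assumed.
reader = DataReader_pred(
    format="mseed",
    amplitude=True,
    data_dir="mseed/",            # assumed keyword: directory holding waveform files
    data_list=["example.mseed"],  # assumed keyword: file names relative to data_dir
    sampling_rate=100,
    highpass_filter=0.0,
)
pred_dataset = reader.dataset(batch_size=1, shuffle=False)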
6 changes: 0 additions & 6 deletions phasenet/postprocess.py
@@ -75,7 +75,6 @@ def extract_picks(
config=None,
waveforms=None,
use_amplitude=False,
upload_waveform=False,
):
"""Extract picks from prediction results.
Args:
@@ -94,7 +93,6 @@
for x in phases:
mph[x] = 0.3
mpd = 50
## upload waveform
pre_idx = int(1 / dt)
post_idx = int(4 / dt)
else:
@@ -123,7 +121,6 @@

picks = []
for i in range(Nb):

file_name = file_names[i]
begin_time = datetime.fromisoformat(begin_times[i])

@@ -162,9 +159,6 @@
if hi > Nt:
hi = Nt
tmp[insert_idx : insert_idx + hi - lo, :] = waveforms[i, lo:hi, j, :]
if upload_waveform:
pick["waveform"] = tmp.tolist()
pick["_id"] = f"{pick['station_id']}_{pick['timestamp']}_{pick['type']}"
if use_amplitude:
next_pick = idxs[l + 1] if l < len(idxs) - 1 else (phase_index + post_idx * 3)
pick["phase_amplitude"] = np.max(
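The extract_picks hunks above set a probability threshold of 0.3 (mph) and a minimum peak distance of 50 samples (mpd) for each phase. As a standalone illustration of that peak-picking step, the sketch below uses scipy.signal.find_peaks in place of the repository's own detect_peaks helper; the function name and return values here are illustrative:

import numpy as np
from scipy.signal import find_peaks


def pick_phase(prob, mph=0.3, mpd=50):
    # prob: 1-D phase probability trace from the network output.
    # Keep local maxima above the threshold that are at least mpd samples apart.
    idxs, props = find_peaks(prob, height=mph, distance=mpd)
    return idxs, props["peak_heights"]


# Example: a synthetic trace with two well-separated "picks".
prob = np.zeros(3000, dtype=np.float32)
prob[[500, 1500]] = 0.9
print(pick_phase(prob))  # indices [500, 1500], heights [0.9, 0.9]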
