diff --git a/Dockerfile b/Dockerfile
index ed41902..4e25344 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -7,7 +7,6 @@ FROM tensorflow/tensorflow
 # SHELL ["conda", "run", "-n", "cs329s", "/bin/bash", "-c"]
 
 RUN pip install tqdm obspy pandas
-RUN pip install minio pymongo kafka-python
 RUN pip install uvicorn fastapi
 
 WORKDIR /opt
diff --git a/env.yml b/env.yml
index 6d77889..e29273a 100644
--- a/env.yml
+++ b/env.yml
@@ -13,9 +13,5 @@ dependencies:
   - obspy
   - uvicorn
   - fastapi
-  - kafka-python
   - tensorflow
   - keras
-  - pymongo
-
-
diff --git a/phasenet/data_reader.py b/phasenet/data_reader.py
index 8081faf..3de2074 100755
--- a/phasenet/data_reader.py
+++ b/phasenet/data_reader.py
@@ -10,14 +10,14 @@
 pd.options.mode.chained_assignment = None
 
 import json
+import random
+from collections import defaultdict
 
 # import s3fs
 import h5py
 import obspy
 from scipy.interpolate import interp1d
 from tqdm import tqdm
-from collections import defaultdict
-import random
 
 
 def py_func_decorator(output_types=None, output_shapes=None, name=None):
@@ -148,7 +148,6 @@ def normalize_batch(data, window=3000):
 
 
 class DataConfig:
-
     seed = 123
     use_seed = True
     n_channel = 3
@@ -168,7 +167,9 @@ def __init__(self, **kwargs):
 
 
 class DataReader:
-    def __init__(self, format="numpy", config=DataConfig(), response_xml=None, sampling_rate=100, highpass_filter=0, **kwargs):
+    def __init__(
+        self, format="numpy", config=DataConfig(), response_xml=None, sampling_rate=100, highpass_filter=0, **kwargs
+    ):
         self.buffer = {}
         self.n_channel = config.n_channel
         self.n_class = config.n_class
@@ -302,9 +303,7 @@ def read_s3(self, format, fname, bucket, key, secret, s3_url, use_ssl):
             raise (f"Format {format} not supported")
         return meta
 
-
     def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=100, return_single_station=True):
-
         try:
             stream = obspy.read(fname)
             stream = stream.merge(fill_value="latest")
@@ -316,7 +315,6 @@ def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=10
             return {}
 
         tmp_stream = obspy.Stream()
         for trace in stream:
-
             if len(trace.data) < 10:
                 continue
@@ -344,9 +342,18 @@ def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=10
         end_time = max([st.stats.endtime for st in stream])
         stream = stream.trim(begin_time, end_time, pad=True, fill_value=0)
 
-        comp = ["3", "2", "1", "E", "N", "Z"]
+        comp = ["3", "2", "1", "E", "N", "U", "V", "Z"]
         order = {key: i for i, key in enumerate(comp)}
-        comp2idx = {"3": 0, "2": 1, "1": 2, "E": 0, "N": 1, "Z": 2}  ## only for cases less than 3 components
+        comp2idx = {
+            "3": 0,
+            "2": 1,
+            "1": 2,
+            "E": 0,
+            "N": 1,
+            "Z": 2,
+            "U": 0,
+            "V": 1,
+        }  ## only for cases less than 3 components
 
         station_ids = defaultdict(list)
         for tr in stream:
@@ -360,9 +367,7 @@ def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=10
 
         nt = len(stream[0].data)
         data = np.zeros([3, nt, nx], dtype=np.float32)
         for i, sta in enumerate(station_keys):
-
             for j, c in enumerate(sorted(station_ids[sta], key=lambda x: order[x])):
-
                 if len(station_ids[sta]) != 3:  ## less than 3 component
                     j = comp2idx[c]
@@ -378,17 +383,20 @@ def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=10
                     tmp = trace.data.astype("float32")
                     data[j, : len(tmp), i] = tmp[:nt]
-
+
         # if return_single_station and (len(station_keys) > 1):
         #     print(f"Warning: {fname} has multiple stations, returning only the first one {station_keys[0]}")
         #     data = data[:, :, 0:1]
         #     station_keys = station_keys[0:1]
-        meta = {"data": data.transpose([1, 2, 0]), "t0": begin_time.datetime.isoformat(timespec="milliseconds"), "station_id": station_keys}
+        meta = {
+            "data": data.transpose([1, 2, 0]),
+            "t0": begin_time.datetime.isoformat(timespec="milliseconds"),
+            "station_id": station_keys,
+        }
         return meta
 
     def read_sac(self, fname):
-
         mseed = obspy.read(fname)
         mseed = mseed.detrend("spline", order=2, dspline=5 * mseed[0].stats.sampling_rate)
         mseed = mseed.merge(fill_value=0)
@@ -422,7 +430,6 @@ def read_sac(self, fname):
         return meta
 
     def read_mseed_array(self, fname, stations, amplitude=False, remove_resp=True):
-
         data = []
         station_id = []
         t0 = []
@@ -485,7 +492,6 @@ def read_mseed_array(self, fname, stations, amplitude=False, remove_resp=True):
 
             resp = stations[sta]["response"]
             for j, c in enumerate(sorted(comp, key=lambda x: order[x[-1]])):
-
                 resp_j = resp[j]
                 if len(comp) != 3:  ## less than 3 component
                     j = comp2idx[c]
@@ -625,7 +631,6 @@ def random_shift(self, sample, itp, its, itp_old=None, its_old=None, shift_range
         return shifted_sample, shift_pick(itp, shift), shift_pick(its, shift), shift
 
     def stack_events(self, sample_old, itp_old, its_old, shift_range=None, mask_old=None):
-
         i = np.random.randint(self.num_data)
         base_name = self.data_list[i]
         if self.format == "numpy":
@@ -663,7 +668,6 @@ def cut_window(self, sample, target, itp, its, select_range):
 
 class DataReader_train(DataReader):
     def __init__(self, format="numpy", config=DataConfig(), **kwargs):
-
         super().__init__(format=format, config=config, **kwargs)
 
         self.min_event_gap = config.min_event_gap
@@ -672,7 +676,6 @@ def __init__(self, format="numpy", config=DataConfig(), **kwargs):
         self.select_range = [5000, 8000]
 
     def __getitem__(self, i):
-
         base_name = self.data_list[i]
         if self.format == "numpy":
             meta = self.read_numpy(os.path.join(self.data_dir, base_name))
@@ -716,13 +719,11 @@ def dataset(self, batch_size, num_parallel_calls=2, shuffle=True, drop_remainder
 
 class DataReader_test(DataReader):
     def __init__(self, format="numpy", config=DataConfig(), **kwargs):
-
         super().__init__(format=format, config=config, **kwargs)
 
         self.select_range = [5000, 8000]
 
     def __getitem__(self, i):
-
         base_name = self.data_list[i]
         if self.format == "numpy":
             meta = self.read_numpy(os.path.join(self.data_dir, base_name))
@@ -756,7 +757,6 @@ def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainde
 
 class DataReader_pred(DataReader):
     def __init__(self, format="numpy", amplitude=True, config=DataConfig(), **kwargs):
-
         super().__init__(format=format, config=config, **kwargs)
 
         self.amplitude = amplitude
@@ -769,25 +769,30 @@ def adjust_missingchannels(self, data):
         return data
 
     def __getitem__(self, i):
-
         base_name = self.data_list[i]
         if self.format == "numpy":
             meta = self.read_numpy(os.path.join(self.data_dir, base_name))
         elif (self.format == "mseed") or (self.format == "sac"):
-            meta = self.read_mseed(os.path.join(self.data_dir, base_name), response=self.response, sampling_rate=self.sampling_rate, highpass_filter=self.highpass_filter, return_single_station=True)
+            meta = self.read_mseed(
+                os.path.join(self.data_dir, base_name),
+                response=self.response,
+                sampling_rate=self.sampling_rate,
+                highpass_filter=self.highpass_filter,
+                return_single_station=True,
+            )
         elif self.format == "hdf5":
             meta = self.read_hdf5(base_name)
         else:
             raise (f"{self.format} does not support!")
-
+
         if "data" in meta:
             raw_amp = meta["data"].copy()
             sample = normalize_long(meta["data"])
         else:
             raw_amp = np.zeros([3000, 1, 3], dtype=np.float32)
             sample = np.zeros([3000, 1, 3], dtype=np.float32)
-
+
         if "t0" in meta:
             t0 = meta["t0"]
         else:
@@ -816,7 +821,7 @@ def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainde
             dataset = dataset_map(
                 self,
                 output_types=(self.dtype, self.dtype, "string", "string", "string"),
-                output_shapes=([None,None,3], [None,None,3], None, None, None),
+                output_shapes=([None, None, 3], [None, None, 3], None, None, None),
                 num_parallel_calls=num_parallel_calls,
                 shuffle=shuffle,
             )
@@ -824,7 +829,7 @@ def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainde
             dataset = dataset_map(
                 self,
                 output_types=(self.dtype, "string", "string", "string"),
-                output_shapes=([None,None,3], None, None, None),
+                output_shapes=([None, None, 3], None, None, None),
                 num_parallel_calls=num_parallel_calls,
                 shuffle=shuffle,
             )
@@ -834,7 +839,6 @@ def dataset(self, batch_size, num_parallel_calls=2, shuffle=False, drop_remainde
 
 class DataReader_mseed_array(DataReader):
     def __init__(self, stations, amplitude=True, remove_resp=True, config=DataConfig(), **kwargs):
-
         super().__init__(format="mseed", config=config, **kwargs)
 
         # self.stations = pd.read_json(stations)
@@ -852,7 +856,6 @@ def get_data_shape(self):
         return meta["data"].shape
 
     def __getitem__(self, i):
-
         fp = os.path.join(self.data_dir, self.data_list[i])
         # try:
         meta = self.read_mseed_array(fp, self.stations, self.amplitude, self.remove_resp)
@@ -1004,5 +1007,4 @@ def read(data_reader, batch=1):
 
 
 if __name__ == "__main__":
-
     test_DataReader()
diff --git a/phasenet/postprocess.py b/phasenet/postprocess.py
index 281f06c..8ae9cb0 100644
--- a/phasenet/postprocess.py
+++ b/phasenet/postprocess.py
@@ -75,7 +75,6 @@ def extract_picks(
     config=None,
     waveforms=None,
     use_amplitude=False,
-    upload_waveform=False,
 ):
     """Extract picks from prediction results.
     Args:
@@ -94,7 +93,6 @@ def extract_picks(
         for x in phases:
             mph[x] = 0.3
         mpd = 50
-        ## upload waveform
         pre_idx = int(1 / dt)
         post_idx = int(4 / dt)
     else:
@@ -123,7 +121,6 @@ def extract_picks(
 
     picks = []
     for i in range(Nb):
-
         file_name = file_names[i]
         begin_time = datetime.fromisoformat(begin_times[i])
 
@@ -162,9 +159,6 @@ def extract_picks(
                     if hi > Nt:
                         hi = Nt
                     tmp[insert_idx : insert_idx + hi - lo, :] = waveforms[i, lo:hi, j, :]
-                    if upload_waveform:
-                        pick["waveform"] = tmp.tolist()
-                        pick["_id"] = f"{pick['station_id']}_{pick['timestamp']}_{pick['type']}"
                 if use_amplitude:
                     next_pick = idxs[l + 1] if l < len(idxs) - 1 else (phase_index + post_idx * 3)
                     pick["phase_amplitude"] = np.max(
diff --git a/phasenet/predict.py b/phasenet/predict.py
index 9c91cee..c322ebb 100755
--- a/phasenet/predict.py
+++ b/phasenet/predict.py
@@ -11,7 +11,6 @@
 import pandas as pd
 import tensorflow as tf
 from data_reader import DataReader_mseed_array, DataReader_pred
-from model import ModelConfig, UNet
 from postprocess import (
     extract_amplitude,
     extract_picks,
@@ -19,35 +18,16 @@
     save_picks_json,
     save_prob_h5,
 )
-from pymongo import MongoClient
 from tqdm import tqdm
 from visulization import plot_waveform
 
+from model import ModelConfig, UNet
+
 tf.compat.v1.disable_eager_execution()
 tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
 
-username = "root"
-password = "quakeflow123"
-# client = MongoClient(f"mongodb://{username}:{password}@127.0.0.1:27017")
-client = MongoClient(f"mongodb://{username}:{password}@quakeflow-mongodb-headless.default.svc.cluster.local:27017")
-
-# db = client["quakeflow"]
-# collection = db["waveform"]
-
-
-def upload_mongodb(picks):
-    db = client["quakeflow"]
-    collection = db["waveform"]
-    try:
-        collection.insert_many(picks)
-    except Exception as e:
-        print("Warning:", e)
-        collection.delete_many({"_id": {"$in": [p["_id"] for p in picks]}})
-        collection.insert_many(picks)
-
 
 def read_args():
-
     parser = argparse.ArgumentParser()
     parser.add_argument("--batch_size", default=20, type=int, help="batch size")
     parser.add_argument("--model_dir", help="Checkpoint directory (default: None)")
@@ -66,7 +46,6 @@ def read_args():
     parser.add_argument("--stations", default="", help="seismic station info")
     parser.add_argument("--plot_figure", action="store_true", help="If plot figure for test")
     parser.add_argument("--save_prob", action="store_true", help="If save result for test")
-    parser.add_argument("--upload_waveform", action="store_true", help="If upload waveform to mongodb")
     parser.add_argument("--pre_sec", default=1, type=float, help="Window length before pick")
     parser.add_argument("--post_sec", default=4, type=float, help="Window length after pick")
@@ -117,7 +96,6 @@ def pred_fn(args, data_reader, figure_dir=None, prob_dir=None, log_dir=None):
     # sess_config.log_device_placement = False
 
     with tf.compat.v1.Session(config=sess_config) as sess:
-
         saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), max_to_keep=5)
         init = tf.compat.v1.global_variables_initializer()
         sess.run(init)
@@ -151,8 +129,6 @@ def pred_fn(args, data_reader, figure_dir=None, prob_dir=None, log_dir=None):
             # pred_batch = np.vstack(pred_batch)
 
             waveforms = None
-            if args.upload_waveform:
-                waveforms = X_batch
             if args.amplitude:
                 waveforms = amp_batch
@@ -164,14 +140,36 @@ def pred_fn(args, data_reader, figure_dir=None, prob_dir=None, log_dir=None):
                 config=args,
                 waveforms=waveforms,
                 use_amplitude=args.amplitude,
-                upload_waveform=args.upload_waveform,
-                dt=1.0/args.sampling_rate
+                dt=1.0 / args.sampling_rate,
             )
-            if args.upload_waveform:
-                upload_mongodb(picks_)
             picks.extend(picks_)
 
+            ## save pick per file
+            if len(fname_batch) == 1:
+                df = pd.DataFrame(picks_)
+                df = df[df["phase_index"] > 10]
+                if not os.path.exists(os.path.join(args.result_dir, "picks")):
+                    os.makedirs(os.path.join(args.result_dir, "picks"))
+                df = df[
+                    [
+                        "station_id",
+                        "begin_time",
+                        "phase_index",
+                        "phase_time",
+                        "phase_score",
+                        "phase_type",
+                        "phase_amplitude",
+                        "dt",
+                    ]
+                ]
+                df.to_csv(
+                    os.path.join(
+                        args.result_dir, "picks", fname_batch[0].decode().split("/")[-1].rstrip(".mseed") + ".csv"
+                    ),
+                    index=False,
+                )
+
             if args.plot_figure:
                 if not (isinstance(fname_batch, np.ndarray) or isinstance(fname_batch, list)):
                     fname_batch = [fname_batch.decode().rstrip(".mseed") + "_" + x.decode() for x in station_batch]
@@ -204,12 +202,20 @@ def pred_fn(args, data_reader, figure_dir=None, prob_dir=None, log_dir=None):
         # df["prob"] = df["phase_prob"]
         # df["type"] = df["phase_type"]
 
-        base_columns = ["station_id", "begin_time", "phase_index", "phase_time", "phase_score", "phase_type", "file_name"]
+        base_columns = [
+            "station_id",
+            "begin_time",
+            "phase_index",
+            "phase_time",
+            "phase_score",
+            "phase_type",
+            "file_name",
+        ]
         if args.amplitude:
             base_columns.append("phase_amplitude")
             base_columns.append("phase_amp")
             df["phase_amp"] = df["phase_amplitude"]
-
+
         df = df[base_columns]
         df.to_csv(os.path.join(args.result_dir, args.result_fname + ".csv"), index=False)
@@ -222,11 +228,9 @@ def pred_fn(args, data_reader, figure_dir=None, prob_dir=None, log_dir=None):
 
 
 def main(args):
-
     logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
 
     with tf.compat.v1.name_scope("create_inputs"):
-
        if args.format == "mseed_array":
            data_reader = DataReader_mseed_array(
                data_dir=args.data_dir,
diff --git a/requirements.txt b/requirements.txt
index 1b4ddc3..55afc96 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,5 @@
 matplotlib
 pandas
 tqdm
 scipy
-pymongo
 obspy