Commit

for quakeflow batch processing
zhuwq0 committed Oct 23, 2024
1 parent e20e1dd commit 4f36f9a
Showing 2 changed files with 71 additions and 112 deletions.
64 changes: 15 additions & 49 deletions phasenet/data_reader.py
@@ -21,10 +21,10 @@
from tqdm import tqdm

# token_json = f"{os.environ['HOME']}/.config/gcloud/application_default_credentials.json"
token_json = "application_default_credentials.json"
with open(token_json, "r") as fp:
token = json.load(fp)
fs_gs = fsspec.filesystem("gs", token=token)
# token_json = "application_default_credentials.json"
# with open(token_json, "r") as fp:
# token = json.load(fp)
# fs_gs = fsspec.filesystem("gs", token=token)
# client = Client("SCEDC")
client = Client("NCEDC")
client_iris = Client("IRIS") ## HardCode: IRIS for response file
@@ -196,7 +196,8 @@ def __init__(
self.highpass_filter = highpass_filter
# self.response_xml = response_xml
if response_xml is not None:
self.response = obspy.read_inventory(response_xml)
# self.response = obspy.read_inventory(response_xml)
self.response = response_xml
else:
self.response = None
self.sampling_rate = sampling_rate
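After this change, `self.response` stores the `response_xml` string (a path or bucket prefix) rather than a parsed obspy Inventory; reading the StationXML is deferred to `read_mseed`, where the station is known. A small sketch of the two behaviors (the helper function is illustrative only):

import obspy

def resolve_response(response_xml, eager=False):
    # eager=True mirrors the old behavior: parse the StationXML immediately.
    # eager=False mirrors the new behavior: keep the path/prefix string and
    # let the reader open the per-station file later.
    if response_xml is None:
        return None
    if eager:
        return obspy.read_inventory(response_xml)
    return response_xml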
@@ -367,50 +368,14 @@ def read_mseed(self, fname, response=None, highpass_filter=0.0, sampling_rate=10
stream = stream.merge(fill_value="latest")

## FIX: hard code for response file
## NCEDC
station, network, channel = files[0].split("/")[-1].split(".")[:3]
response_xml = f"gs://quakeflow_catalog/NC/FDSNstationXML/{network}/{network}.{station}.xml"
# response_xml = (
# f"gs://quakeflow_dataset/NC/FDSNstationXML/{network}.info/{network}.FDSN.xml/{network}.{station}.xml"
# )

## SCEDC
# fname = files[0].split("/")[-1]
# network = fname[:2]
# station = fname[2:7].rstrip("_")
# instrument = fname[7:9]
# channel = fname[9]
# location = fname[10:12].rstrip("_")
# year = fname[13:17]
# jday = fname[17:20]
# response_xml = f"gs://quakeflow_catalog/SC/FDSNstationXML/{network}/{network}_{station}.xml"

redownload = True
if fs_gs.exists(response_xml):
try:
with fs_gs.open(response_xml, "rb") as fp:
response = obspy.read_inventory(fp)
stream = stream.remove_sensitivity(response)
redownload = False
except Exception as e:
print(f"Error removing sensitivity: {e}")
else:
redownload = True
if redownload:
try:
response = client.get_stations(network=network, station=station, level="response")
except Exception as e:
print(f"Error downloading response: {e}")
print(f"Retry downloading response from IRIS...")
try:
response = client_iris.get_stations(network=network, station=station, level="response")
except Exception as e:
print(f"Error downloading response from IRIS: {e}")
raise
response.write(f"/tmp/{network}_{station}.xml", format="stationxml")
fs_gs.put(f"/tmp/{network}_{station}.xml", response_xml)
print(f"Update response file: {response_xml}")
station_id = files[0].split("/")[-1].replace(".mseed", "")[:-1]
response_xml = f"{response.rstrip('/')}/{station_id}.xml"
try:
with fsspec.open(response_xml, "rb") as fp:
response = obspy.read_inventory(fp)
stream = stream.remove_sensitivity(response)
except Exception as e:
print(f"Error removing sensitivity: {e}")

except Exception as e:
print(f"Error reading {fname}:\n{e}")
@@ -539,7 +504,7 @@ def read_mseed_3c(self, fname, response=None, highpass_filter=0.0, sampling_rate
if len(station_ids) > 1:
print(f"{station_ids = }")
raise
assert (len(station_ids) == 1, f"Error: {fname} has multiple stations {station_ids}")
assert len(station_ids) == 1, f"Error: {fname} has multiple stations {station_ids}"

begin_time = min([st.stats.starttime for st in traces])
end_time = max([st.stats.endtime for st in traces])
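The replaced assertion illustrates a common pitfall: `assert (cond, msg)` builds a two-element tuple, which is always truthy, so the check never fails. A short example with hypothetical values:

station_ids = {"NC.ABC", "NC.XYZ"}                    # hypothetical: two stations
assert (len(station_ids) == 1, "multiple stations")   # old form: always passes (tuple is truthy)
assert len(station_ids) == 1, "multiple stations"     # new form: raises AssertionError as intended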
@@ -953,6 +918,7 @@ def __getitem__(self, i):
# )
meta = self.read_mseed(
base_name,
response=self.response,
sampling_rate=self.sampling_rate,
highpass_filter=self.highpass_filter,
return_single_station=True,
119 changes: 56 additions & 63 deletions phasenet/predict.py
@@ -29,10 +29,10 @@
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# token_json = f"{os.environ['HOME']}/.config/gcloud/application_default_credentials.json"
token_json = "application_default_credentials.json"
with open(token_json, "r") as fp:
token = json.load(fp)
fs_gs = fsspec.filesystem("gs", token=token)
# token_json = "application_default_credentials.json"
# with open(token_json, "r") as fp:
# token = json.load(fp)
# fs_gs = fsspec.filesystem("gs", token=token)


def read_args():
@@ -151,28 +151,29 @@ def pred_fn(args, data_reader, figure_dir=None, prob_dir=None, log_dir=None):
dt=1.0 / args.sampling_rate,
)

picks.extend(picks_)
# picks.extend(picks_)

# ## save pick per file

if len(fname_batch) == 1:
# ### FIX: Hard code for NCEDC and SCEDC
tmp = fname_batch[0].decode().split(",")[0].lstrip("s3://").split("/")
parant_dir = "/".join(tmp[2:-1]) # remove s3://ncedc-pds/continuous and mseed file name
tmp = fname_batch[0].decode().split(",")[0].split("/")
subdir = "/".join(tmp[-1-3:-1])
fname = tmp[-1].rstrip("\n").rstrip(".mseed").rstrip(".ms") + ".csv"
csv_name = f"quakeflow_catalog/NC/phasenet/{parant_dir}/{fname}"
# csv_name = f"quakeflow_catalog/SC/phasenet/{parant_dir}/{fname}"
if not os.path.exists(os.path.join(args.result_dir, "picks", parant_dir)):
os.makedirs(os.path.join(args.result_dir, "picks", parant_dir), exist_ok=True)
# csv_name = f"quakeflow_catalog/NC/phasenet/{subdir}/{fname}"
# csv_name = f"quakeflow_catalog/SC/phasenet/{subdir}/{fname}"
if not os.path.exists(os.path.join(args.result_dir, "picks", subdir)):
os.makedirs(os.path.join(args.result_dir, "picks", subdir), exist_ok=True)
csv_file = os.path.join(args.result_dir, "picks", subdir, fname)

if len(picks_) == 0:
with fs_gs.open(csv_name, "w") as fp:
with open(csv_file, "w") as fp:
fp.write("")
else:
df = pd.DataFrame(picks_)
df = df[df["phase_index"] > 10]
if len(df) == 0:
with fs_gs.open(csv_name, "w") as fp:
with open(csv_file, "w") as fp:
fp.write("")
else:
df["phase_amplitude"] = df["phase_amplitude"].apply(lambda x: f"{x:.3e}")
@@ -189,24 +190,16 @@ def pred_fn(args, data_reader, figure_dir=None, prob_dir=None, log_dir=None):
]
]
df.sort_values(by=["phase_time"], inplace=True)
df.to_csv(
os.path.join(
args.result_dir,
"picks",
parant_dir,
fname,
),
index=False,
)
fs_gs.put(
os.path.join(
args.result_dir,
"picks",
parant_dir,
fname,
),
csv_name,
)
df.to_csv(csv_file, index=False)
# fs_gs.put(
# os.path.join(
# args.result_dir,
# "picks",
# subdir,
# fname,
# ),
# csv_name,
# )
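A sketch of the per-file CSV naming used above, assuming `fname_batch[0]` decodes to a comma-separated entry whose first field is a path such as s3://ncedc-pds/.../NC/2024/2024.001/XXX.mseed. The helper name is illustrative, and the extension handling uses `endswith` rather than the commit's `rstrip`:

import os

def pick_csv_path(result_dir, fname_entry):
    parts = fname_entry.decode().split(",")[0].split("/")
    subdir = "/".join(parts[-4:-1])            # keep the last three directory levels
    name = parts[-1].rstrip("\n")
    for ext in (".mseed", ".ms"):
        if name.endswith(ext):
            name = name[: -len(ext)]
            break
    csv_file = os.path.join(result_dir, "picks", subdir, name + ".csv")
    os.makedirs(os.path.dirname(csv_file), exist_ok=True)
    return csv_file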

if args.plot_figure:
if not (isinstance(fname_batch, np.ndarray) or isinstance(fname_batch, list)):
Expand All @@ -230,38 +223,38 @@ def pred_fn(args, data_reader, figure_dir=None, prob_dir=None, log_dir=None):
fname_batch = [x.decode() for x in fname_batch]
save_prob_h5(pred_batch, fname_batch, prob_h5)

if len(picks) > 0:
# save_picks(picks, args.result_dir, amps=amps, fname=args.result_fname+".csv")
# save_picks_json(picks, args.result_dir, dt=data_reader.dt, amps=amps, fname=args.result_fname+".json")
df = pd.DataFrame(picks)
# df["fname"] = df["file_name"]
# df["id"] = df["station_id"]
# df["timestamp"] = df["phase_time"]
# df["prob"] = df["phase_prob"]
# df["type"] = df["phase_type"]

base_columns = [
"station_id",
"begin_time",
"phase_index",
"phase_time",
"phase_score",
"phase_type",
"file_name",
]
if args.amplitude:
base_columns.append("phase_amplitude")
base_columns.append("phase_amp")
df["phase_amp"] = df["phase_amplitude"]

df = df[base_columns]
df.to_csv(os.path.join(args.result_dir, args.result_fname + ".csv"), index=False)

print(
f"Done with {len(df[df['phase_type'] == 'P'])} P-picks and {len(df[df['phase_type'] == 'S'])} S-picks"
)
else:
print(f"Done with 0 P-picks and 0 S-picks")
# if len(picks) > 0:
# # save_picks(picks, args.result_dir, amps=amps, fname=args.result_fname+".csv")
# # save_picks_json(picks, args.result_dir, dt=data_reader.dt, amps=amps, fname=args.result_fname+".json")
# df = pd.DataFrame(picks)
# # df["fname"] = df["file_name"]
# # df["id"] = df["station_id"]
# # df["timestamp"] = df["phase_time"]
# # df["prob"] = df["phase_prob"]
# # df["type"] = df["phase_type"]

# base_columns = [
# "station_id",
# "begin_time",
# "phase_index",
# "phase_time",
# "phase_score",
# "phase_type",
# "file_name",
# ]
# if args.amplitude:
# base_columns.append("phase_amplitude")
# base_columns.append("phase_amp")
# df["phase_amp"] = df["phase_amplitude"]

# df = df[base_columns]
# df.to_csv(os.path.join(args.result_dir, args.result_fname + ".csv"), index=False)

# print(
# f"Done with {len(df[df['phase_type'] == 'P'])} P-picks and {len(df[df['phase_type'] == 'S'])} S-picks"
# )
# else:
# print(f"Done with 0 P-picks and 0 S-picks")
return 0
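With the aggregate picks CSV disabled above, each input file leaves its own CSV under result_dir/picks/<subdir>/. A sketch of re-assembling them afterwards, with illustrative paths:

import glob
import os
import pandas as pd

csv_files = glob.glob(os.path.join("results", "picks", "**", "*.csv"), recursive=True)
# Empty placeholder CSVs (written when a file had no picks) are skipped.
frames = [pd.read_csv(f) for f in csv_files if os.path.getsize(f) > 0]
picks = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
print(f"Collected {len(picks)} picks from {len(csv_files)} files")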


