Skip to content

Commit

Permalink
update gradio app
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuwq0 committed Aug 21, 2023
1 parent 8d37e68 commit 460bfe1
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 56 deletions.
97 changes: 58 additions & 39 deletions docs/example_gradio.ipynb

Large diffs are not rendered by default.

44 changes: 27 additions & 17 deletions phasenet/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import requests
import tensorflow as tf
from fastapi import FastAPI
from fastapi import FastAPI, WebSocket
from kafka import KafkaProducer
from pydantic import BaseModel
from scipy.interpolate import interp1d
Expand Down Expand Up @@ -150,7 +150,6 @@ def format_picks(picks, dt, amplitudes):


def format_data(data):

# chn2idx = {"ENZ": {"E":0, "N":1, "Z":2},
# "123": {"3":0, "2":1, "1":2},
# "12Z": {"1":0, "2":1, "Z":2}}
Expand Down Expand Up @@ -189,7 +188,6 @@ def format_data(data):


def get_prediction(data, return_preds=False):

vec = np.array(data.vec)
vec, vec_raw = preprocess(vec)

Expand All @@ -198,7 +196,10 @@ def get_prediction(data, return_preds=False):

picks = extract_picks(preds, station_ids=data.id, begin_times=data.timestamp, waveforms=vec_raw)

picks = [{k: v for k, v in pick.items() if k in ["station_id", "phase_time", "phase_score", "phase_type", "dt"]} for pick in picks]
picks = [
{k: v for k, v in pick.items() if k in ["station_id", "phase_time", "phase_score", "phase_type", "dt"]}
for pick in picks
]

if return_preds:
return picks, preds
Expand All @@ -211,8 +212,9 @@ class Data(BaseModel):
# timestamp: Union[List[str], str]
# vec: Union[List[List[List[float]]], List[List[float]]]
id: List[str]
timestamp: List[str]
timestamp: List[Union[str, float, datetime]]
vec: Union[List[List[List[float]]], List[List[float]]]

dt: Optional[float] = 0.01
## gamma
stations: Optional[List[Dict[str, Union[float, str]]]] = None
Expand All @@ -223,7 +225,7 @@ class Data(BaseModel):
# def set_default_executor():
# from concurrent.futures import ThreadPoolExecutor
# import asyncio
#
#
# loop = asyncio.get_running_loop()
# loop.set_default_executor(
# ThreadPoolExecutor(max_workers=2)
Expand All @@ -232,49 +234,58 @@ class Data(BaseModel):

@app.post("/predict")
def predict(data: Data):

picks = get_prediction(data)

return picks


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
while True:
data = await websocket.receive_json()
# data = json.loads(data)
data = Data(**data)
picks = get_prediction(data)
await websocket.send_json(picks)
print("PhaseNet Updating...")


@app.post("/predict_prob")
def predict(data: Data):

picks, preds = get_prediction(data, True)

return picks, preds.tolist()


@app.post("/predict_phasenet2gamma")
def predict(data: Data):

picks = get_prediction(data)

# if use_kafka:
# print("Push picks to kafka...")
# for pick in picks:
# producer.send("phasenet_picks", key=pick["id"], value=pick)
try:
catalog = requests.post(f"{GAMMA_API_URL}/predict", json={"picks": picks,
"stations": data.stations,
"config": data.config})
catalog = requests.post(
f"{GAMMA_API_URL}/predict", json={"picks": picks, "stations": data.stations, "config": data.config}
)
print(catalog.json()["catalog"])
return catalog.json()
except Exception as error:
print(error)

return {}


@app.post("/predict_phasenet2gamma2ui")
def predict(data: Data):

picks = get_prediction(data)

try:
catalog = requests.post(f"{GAMMA_API_URL}/predict", json={"picks": picks,
"stations": data.stations,
"config": data.config})
catalog = requests.post(
f"{GAMMA_API_URL}/predict", json={"picks": picks, "stations": data.stations, "config": data.config}
)
print(catalog.json()["catalog"])
return catalog.json()
except Exception as error:
Expand All @@ -293,7 +304,6 @@ def predict(data: Data):

@app.post("/predict_stream_phasenet2gamma")
def predict(data: Data):

data = format_data(data)
# for i in range(len(data.id)):
# plt.clf()
Expand Down

0 comments on commit 460bfe1

Please sign in to comment.