Spaces:
Runtime error
Runtime error
# Gradio app that takes a seismic waveform as input and marks the predicted P and S phase onsets on it.
import gradio as gr | |
import numpy as np | |
from phasehunter.model import Onset_picker, Updated_onset_picker | |
from phasehunter.data_preparation import prepare_waveform | |
import torch | |
from scipy.stats import gaussian_kde | |
import obspy | |
from obspy.clients.fdsn import Client | |
from obspy.clients.fdsn.header import FDSNNoDataException, FDSNTimeoutException, FDSNInternalServerException | |
from obspy.geodetics.base import locations2degrees | |
from obspy.taup import TauPyModel | |
from obspy.taup.helper_classes import SlownessModelError | |
from obspy.clients.fdsn.header import URL_MAPPINGS | |
import matplotlib.pyplot as plt | |
def make_prediction(waveform):
    """Run the phase picker on a waveform stored in a ``.npy`` file.

    Relies on the module-level ``model`` that is loaded further down in
    this file.

    Args:
        waveform: Path (or file-like object) accepted by ``np.load``.

    Returns:
        Tuple ``(processed_input, p_phase, s_phase)``: the output of
        ``prepare_waveform`` followed by columns 0 and 1 of the model output.
    """
    raw_trace = np.load(waveform)
    model_input = prepare_waveform(raw_trace)

    # Inference only, so no autograd graph is recorded.
    with torch.no_grad():
        predictions = model(model_input)

    p_picks, s_picks = predictions[:, 0], predictions[:, 1]
    return model_input, p_picks, s_picks
def mark_phases(waveform):
    """Render a waveform with its predicted P and S onsets marked.

    Args:
        waveform: Path (or file-like object) to a ``.npy`` waveform; it is
            forwarded to ``make_prediction``.

    Returns:
        The rendered figure as an RGBA ``numpy`` array (height x width x 4),
        for display in a ``gr.Image`` output.
    """
    processed_input, p_phase, s_phase = make_prediction(waveform)
    n_samples = processed_input.shape[-1]

    # Picks are scaled by the trace length to get sample positions.
    # Presumably the model emits onsets normalised to [0, 1] -- TODO confirm.
    p_samples = np.asarray(p_phase) * n_samples
    s_samples = np.asarray(s_phase) * n_samples

    # A 1C (Z-only) input has an all-zero third channel. The old test,
    # `sum(channel == 0)`, was true as soon as ANY sample was zero. That
    # misclassified 3C data containing a single exact zero sample.
    is_single_component = not bool((processed_input[0][2] != 0).any())
    channel_labels = ('Norm. Ampl.',) if is_single_component else ('Z', 'N', 'E')
    figsize = (10, 2) if is_single_component else (10, 6)

    # One panel per channel, plus a bottom panel for the pick distributions.
    fig, ax = plt.subplots(nrows=len(channel_labels) + 1, figsize=figsize, sharex=True)
    try:
        for channel, label in enumerate(channel_labels):
            ax[channel].plot(processed_input[0][channel])
            ax[channel].set_ylabel(label)

        # Bottom panel: kernel density of the picks along the first output
        # axis, used as an uncertainty display, padded by 10 samples per side.
        for picks, color in ((p_samples, 'r'), (s_samples, 'b')):
            kde = gaussian_kde(picks)
            grid = np.linspace(picks.min() - 10, picks.max() + 10, 500)
            ax[-1].plot(grid, kde(grid), color=color)

        # Dashed line at the mean pick, drawn on every panel.
        for panel in ax:
            panel.axvline(p_samples.mean(), color='r', linestyle='--', label='P')
            panel.axvline(s_samples.mean(), color='b', linestyle='--', label='S')

        ax[-1].set_xlabel('Time, samples')
        ax[-1].set_ylabel('Uncert.')
        ax[-1].legend()
        fig.subplots_adjust(hspace=0., wspace=0.)

        # Rasterise the figure and copy its RGBA buffer into a numpy array.
        fig.canvas.draw()
        image = np.array(fig.canvas.buffer_rgba())
    finally:
        # Always release the figure. pyplot keeps a reference to every open
        # figure, so leaking one per failed request would grow server memory.
        plt.close(fig)
    return image
def download_data(timestamp, eq_lat, eq_lon, client_name, radius_km):
    """List stations with ``*H*`` channels in a box around an epicentre.

    Work in progress: the function queries station metadata from an FDSN
    data centre, prints each station and returns ``0``. Waveform download is
    not implemented yet.

    Args:
        timestamp: Origin time, as any value accepted by ``obspy.UTCDateTime``.
        eq_lat: Epicentre latitude in degrees.
        eq_lon: Epicentre longitude in degrees.
        client_name: FDSN data-centre name passed to ``Client``.
        radius_km: Half-width of the search box in km. It is converted to
            degrees with 111.2 km per degree.

    Returns:
        ``0`` (placeholder).

    Raises:
        ValueError: If the search box would reach +-90 deg latitude or
            +-180 deg longitude.
    """
    window = radius_km / 111.2
    # NOTE(review): the same degree window is used for longitude. A degree of
    # longitude spans fewer km away from the equator, so the box is narrower
    # than radius_km east-west; consider dividing by cos(latitude).

    # Validate first, so bad input fails before any client is set up.
    # Explicit raises replace `assert`, which is stripped under `python -O`.
    if not (eq_lat - window > -90 and eq_lat + window < 90):
        raise ValueError("Latitude out of bounds")
    if not (eq_lon - window > -180 and eq_lon + window < 180):
        raise ValueError("Longitude out of bounds")

    client = Client(client_name)
    starttime = obspy.UTCDateTime(timestamp)
    # 120 s after the origin time. The old code referenced the undefined name
    # `startime` here and always raised NameError.
    endtime = starttime + 120
    inv = client.get_stations(network="*", station="*", location="*", channel="*H*",
                              starttime=starttime, endtime=endtime,
                              minlatitude=eq_lat - window, maxlatitude=eq_lat + window,
                              minlongitude=eq_lon - window, maxlongitude=eq_lon + window,
                              level='channel')
    for network in inv:
        for station in network:
            print(station)
            # TODO: fetch waveforms for this station via client.get_waveforms
            # over [starttime, endtime].
    return 0
# Load the trained picker once, at import time. make_prediction calls this
# module-level `model` for every request.
model = Onset_picker.load_from_checkpoint(
    "./weights.ckpt",
    picker=Updated_onset_picker(),
    learning_rate=3e-4,
)
# Put the module into evaluation (inference) mode.
model.eval()
# Build the Gradio UI: one tab per way of supplying a waveform.
with gr.Blocks() as demo:
    gr.Markdown("# PhaseHunter")

    with gr.Tab("Default example"):
        # Bundled sample waveforms (.npy files) the user can choose from.
        inputs = gr.Dropdown(
            ["data/sample/sample_0.npy",
             "data/sample/sample_1.npy",
             "data/sample/sample_2.npy"],
            label="Sample waveform",
            info="Select one of the samples",
            value="data/sample/sample_0.npy",
        )
        predict_button = gr.Button("Predict phases")
        # `gr.outputs.*` is the deprecated pre-Blocks namespace (removed in
        # Gradio 4). `gr.Image` is the component-based equivalent.
        outputs = gr.Image(label='Waveform with Phases Marked', type='numpy')
        predict_button.click(mark_phases, inputs=inputs, outputs=outputs)

    with gr.Tab("Select earthquake from catalogue"):
        gr.Markdown('TEST')
        client_inputs = gr.Dropdown(
            choices=list(URL_MAPPINGS.keys()),
            label="FDSN Client",
            info="Select one of the available FDSN clients",
            value="IRIS",
            interactive=True,
        )
        with gr.Row():
            timestamp_inputs = gr.Textbox(value='2019-07-04 17:33:49',
                                          placeholder='YYYY-MM-DD HH:MM:SS',
                                          label="Timestamp",
                                          info="Timestamp of the earthquake",
                                          max_lines=1,
                                          interactive=True)
            eq_lat_inputs = gr.Number(value=35.766,
                                      label="Latitude",
                                      info="Latitude of the earthquake",
                                      interactive=True)
            eq_lo_inputs = gr.Number(value=117.605,
                                     label="Longitude",
                                     info="Longitude of the earthquake",
                                     interactive=True)
            radius_inputs = gr.Slider(minimum=1,
                                      maximum=150,
                                      value=50, label="Radius (km)",
                                      info="Select the radius around the earthquake to download data from",
                                      interactive=True)
        catalogue_button = gr.Button("Predict phases")
        # Inputs must follow download_data's signature:
        # (timestamp, eq_lat, eq_lon, client_name, radius_km).
        # The old wiring omitted client_name and also bound mark_phases to
        # this button using the first tab's dropdown.
        # NOTE(review): download_data is still a stub returning 0, which an
        # image output cannot display.
        catalogue_button.click(
            download_data,
            inputs=[timestamp_inputs, eq_lat_inputs, eq_lo_inputs,
                    client_inputs, radius_inputs],
            outputs=outputs,
        )

    with gr.Tab("Predict on your own waveform"):
        # TODO: no upload component is wired up yet; this tab only shows
        # the instructions below.
        gr.Markdown("""
        Please upload your waveform in .npy (numpy) format.
        Your waveform should be sampled at 100 sps and have 3 (Z, N, E) or 1 (Z) channels.
        """)

demo.launch()