# Spaces status: Runtime error
# Gradio app that takes a seismic waveform as input and marks two phases (P and S) on the waveform as output.
import os
from glob import glob

import gradio as gr
import numpy as np
import pandas as pd
import torch
from scipy.stats import gaussian_kde
from bmi_topography import Topography
import earthpy.spatial as es
import obspy
from obspy.clients.fdsn import Client
from obspy.clients.fdsn.header import FDSNNoDataException, FDSNTimeoutException, FDSNInternalServerException
from obspy.geodetics.base import locations2degrees
from obspy.taup import TauPyModel
from obspy.taup.helper_classes import SlownessModelError
from obspy.clients.fdsn.header import URL_MAPPINGS
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.colors import LightSource

from phasehunter.model import Onset_picker, Updated_onset_picker
from phasehunter.data_preparation import prepare_waveform
def make_prediction(waveform):
    """Run the module-level picker ``model`` on one ``.npy`` waveform.

    Args:
        waveform: anything ``np.load`` accepts (a path to a ``.npy`` file).

    Returns:
        ``(processed_input, p_phase, s_phase)``: the input after
        ``prepare_waveform`` together with column 0 (P) and column 1 (S)
        of the model output.
    """
    processed_input = prepare_waveform(np.load(waveform))
    with torch.no_grad():  # inference only: no autograd bookkeeping
        predictions = model(processed_input)
    return processed_input, predictions[:, 0], predictions[:, 1]
def mark_phases(waveform, uploaded_file):
    """Pick P and S phases on a single waveform and return the annotated plot.

    Args:
        waveform: path of the ``.npy`` sample chosen in the dropdown.
        uploaded_file: optional value of the ``gr.File`` input. Depending on
            the Gradio version / ``type`` setting this is either a file
            wrapper exposing the path as ``.name`` or the path string itself.
            When given, it takes precedence over *waveform*.

    Returns:
        NumPy array with the rendered figure's RGBA pixels: the trace(s) on
        top, the P (red) / S (blue) pick densities in the last panel, and
        dashed lines at the mean picks on every panel.
    """
    if uploaded_file is not None:
        # Accept both a file wrapper (``.name``) and a plain path string.
        waveform = getattr(uploaded_file, "name", uploaded_file)

    processed_input, p_phase, s_phase = make_prediction(waveform)
    n_samples = processed_input.shape[-1]

    # Treat the input as single-component only when the third channel is zero
    # everywhere (presumably the padding prepare_waveform adds for 1C input —
    # TODO confirm). Testing for *any* zero sample would misclassify a real
    # 3C trace that merely contains one exact zero.
    is_single_component = bool((processed_input[0][2] == 0).all())

    if is_single_component:
        fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)
        ax[0].plot(processed_input[0][0], color='black', lw=1)
        ax[0].set_ylabel('Norm. Ampl.')
    else:
        fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)
        for channel, label in enumerate(('Z', 'N', 'E')):
            ax[channel].plot(processed_input[0][channel], color='black', lw=1)
            ax[channel].set_ylabel(label)

    # Picks are fractions of the window: scale them to sample indices and draw
    # a kernel density estimate of each pick's spread in the bottom panel.
    for phase, color in ((p_phase, 'r'), (s_phase, 'b')):
        phase_in_samples = phase * n_samples
        kde = gaussian_kde(phase_in_samples)
        grid = np.linspace(phase_in_samples.min() - 10, phase_in_samples.max() + 10, 500)
        ax[-1].plot(grid, kde(grid), color=color)

    p_mean = p_phase.mean() * n_samples
    s_mean = s_phase.mean() * n_samples
    for a in ax:
        a.axvline(p_mean, color='r', linestyle='--', label='P')
        a.axvline(s_mean, color='b', linestyle='--', label='S')

    ax[-1].set_xlabel('Time, samples')
    ax[-1].set_ylabel('Uncert.')
    ax[-1].legend()

    fig.subplots_adjust(hspace=0., wspace=0.)

    # Rasterise the figure and return a copy of its pixels
    fig.canvas.draw()
    image = np.array(fig.canvas.renderer.buffer_rgba())
    plt.close(fig)
    return image
def bin_distances(distances, bin_size=10):
    """Return the index of the first distance that falls into each bin.

    Distances are grouped by ``distance // bin_size``, so *bin_size* must be
    in the same units as *distances* (kilometres, degrees, ...). The input is
    scanned in order, so the first member seen in a bin is also the one with
    the smallest index. The returned indices are ordered by the first
    appearance of each bin.

    >>> bin_distances([0, 5, 12, 3, 25])
    [0, 2, 4]
    """
    first_index_per_bin = {}
    for index, distance in enumerate(distances):
        # setdefault keeps the earliest index; later members of a bin are ignored
        first_index_per_bin.setdefault(distance // bin_size, index)
    return list(first_index_per_bin.values())
def variance_coefficient(residuals):
    """Return ``1 - var(residuals) / (max(residuals) - min(residuals))``.

    Larger values mean the residuals are tighter relative to their range.
    The ratio mixes squared units (variance) with linear units (range), so
    the result is NOT confined to [0, 1]. For example, residuals ``[0, 10]``
    give ``1 - 25 / 10 = -1.5``. Constant residuals (zero range) return
    ``1.0`` instead of dividing zero by zero.

    Args:
        residuals: array-like object with ``var``, ``max`` and ``min`` methods
            (e.g. a NumPy array).
    """
    spread = residuals.max() - residuals.min()
    if spread == 0:
        # No spread at all: perfectly tight, rather than NaN from 0 / 0.
        return 1.0
    return 1 - residuals.var() / spread
_CACHE_DIR = "data/cached"
_KM_PER_DEGREE = 111.2
_PICK_COLUMNS = ['station_name', 'starttime',
                 'p_phase', 'p_uncertainty', 's_phase', 's_uncertainty',
                 'velocity_p', 'velocity_s']
_FDSN_ERRORS = (IndexError, FDSNNoDataException, FDSNTimeoutException, FDSNInternalServerException)


def _figure_to_image(fig):
    """Draw *fig*, copy its RGBA pixel buffer into a NumPy array and close it."""
    fig.canvas.draw()
    image = np.array(fig.canvas.renderer.buffer_rgba())
    plt.close(fig)
    return image


def _message_result(message):
    """Return ``(image, empty picks table)`` where the image only shows *message*.

    Early exits use this, so they produce the same two outputs (image and
    picks dataframe) as a successful run.
    """
    fig, ax = plt.subplots()
    ax.text(0.5, 0.5, message)
    return _figure_to_image(fig), pd.DataFrame(columns=_PICK_COLUMNS)


def _fetch_station_stream(client, network_code, station_code, starttime, endtime):
    """Return the raw stream for one station window, using the MiniSEED cache.

    Reads ``<_CACHE_DIR>/<net>_<sta>_<starttime>.mseed`` when it exists.
    Otherwise it downloads all channels for the window from *client* and
    writes that cache file. FDSN exceptions propagate to the caller.
    """
    cache_path = f"{_CACHE_DIR}/{network_code}_{station_code}_{starttime}.mseed"
    if os.path.isfile(cache_path):
        print('Reading cached waveform')
        return obspy.read(cache_path)

    print('Downloading waveform')
    stream = client.get_waveforms(network=network_code, station=station_code, location="*", channel="*",
                                  starttime=starttime, endtime=endtime)
    # Create the cache directory on first use so the write cannot fail with FileNotFoundError.
    os.makedirs(_CACHE_DIR, exist_ok=True)
    stream.write(cache_path, format="MSEED")
    print('Finished downloading and caching waveform')
    return stream


def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source_depth_km, velocity_model, max_waveforms):
    """Pick P and S phases on a record section around an earthquake.

    The function proceeds in four steps:

    1. Queries the FDSN provider *client_name* for stations within
       *radius_km* of the epicentre.
    2. For each station, gets a 60 s window starting 15 s before the first
       TauP arrival, either downloaded or read from the ``data/cached``
       MiniSEED cache.
    3. Keeps stations with three equal-length components and runs the picker
       on them.
    4. Plots the section next to topography maps coloured by apparent P/S
       velocity.

    Args:
        client_name: FDSN client key (e.g. ``"IRIS"``).
        timestamp: earthquake origin time, anything ``obspy.UTCDateTime`` parses.
        eq_lat, eq_lon: epicentre coordinates in degrees.
        radius_km: search radius, converted to degrees at 111.2 km/degree.
        source_depth_km: source depth passed to TauP.
        velocity_model: TauP model name (e.g. ``"iasp91"``).
        max_waveforms: upper bound on the number of stations used (cast to ``int``).

    Returns:
        ``(image, picks)``:
        - ``image`` is the RGBA array of the figure.
        - ``picks`` is a DataFrame with ``_PICK_COLUMNS``: pick times in
          seconds after the origin time, their uncertainties in seconds, and
          apparent velocities in km/s.
        If the provider fails or no usable waveform is found, the image shows
        a message and ``picks`` is empty.

    Raises:
        ValueError: if the search window crosses the latitude/longitude limits.
    """
    distances, t0s, st_lats, st_lons, waveforms, names = [], [], [], [], [], []

    taup_model = TauPyModel(model=velocity_model)
    client = Client(client_name)

    window = radius_km / _KM_PER_DEGREE  # search half-width in degrees
    max_waveforms = int(max_waveforms)

    # Explicit checks rather than ``assert`` so they also run under ``python -O``.
    if not (eq_lat - window > -90 and eq_lat + window < 90):
        raise ValueError("Latitude out of bounds")
    if not (eq_lon - window > -180 and eq_lon + window < 180):
        raise ValueError("Longitude out of bounds")

    origin_time = obspy.UTCDateTime(timestamp)

    try:
        print('Starting to download inventory')
        inv = client.get_stations(network="*", station="*", location="*", channel="*H*",
                                  starttime=origin_time, endtime=origin_time + 120,
                                  minlatitude=(eq_lat - window), maxlatitude=(eq_lat + window),
                                  minlongitude=(eq_lon - window), maxlongitude=(eq_lon + window),
                                  level='station')
        print('Finished downloading inventory')
    except _FDSN_ERRORS:
        return _message_result('Something is wrong with the data provider, try another')

    for network in inv:
        # Skip the synthetic networks
        if network.code == 'SY':
            continue
        for station in network:
            print(f"Processing {network.code}.{station.code}...")
            distance = locations2degrees(eq_lat, eq_lon, station.latitude, station.longitude)
            arrivals = taup_model.get_travel_times(source_depth_in_km=source_depth_km,
                                                   distance_in_degree=distance,
                                                   phase_list=["P", "S"])
            if len(arrivals) == 0:
                continue

            # 60 s window starting 15 s before the first arrival returned by TauP
            starttime = origin_time + arrivals[0].time - 15
            endtime = starttime + 60
            try:
                stream = _fetch_station_stream(client, network.code, station.code, starttime, endtime)
            except _FDSN_ERRORS:
                print(f'Skipping {network.code}_{station.code}_{starttime}')
                continue

            stream = stream.select(channel="H[BH][ZNE]")
            stream = stream.merge(fill_value=0)
            stream = stream[:3]

            # Stacking needs exactly three components of identical length.
            if len(stream) != 3 or len({len(trace.data) for trace in stream}) > 1:
                continue

            try:
                prepared = prepare_waveform(np.stack([trace.data for trace in stream]))
            except Exception as err:  # best effort: one bad station must not abort the section
                print(f"Skipping {network.code}.{station.code}: {err}")
                continue

            distances.append(distance)
            t0s.append(starttime)
            st_lats.append(station.latitude)
            st_lons.append(station.longitude)
            waveforms.append(prepared)
            names.append(f"{network.code}.{station.code}")
            print(f"Added {network.code}.{station.code} to the list of waveforms")

    # If there are no waveforms, return an empty plot
    if not waveforms:
        return _message_result('No waveforms found')

    # Keep the first station of every 10 km distance bin, then cap the count
    # at max_waveforms by random sampling.
    first_distances = bin_distances(distances, bin_size=10 / _KM_PER_DEGREE)
    selection_indexes = np.random.choice(first_distances,
                                         min(len(first_distances), max_waveforms),
                                         replace=False)

    waveforms = np.array(waveforms)[selection_indexes]
    distances = np.array(distances)[selection_indexes]
    t0s = np.array(t0s)[selection_indexes]
    st_lats = np.array(st_lats)[selection_indexes]
    st_lons = np.array(st_lons)[selection_indexes]
    names = np.array(names)[selection_indexes]

    waveforms = [torch.tensor(waveform) for waveform in waveforms]
    n_stations = len(waveforms)

    print('Starting to run predictions')
    with torch.no_grad():
        output = model(torch.vstack(waveforms))

    p_phases = output[:, 0]
    s_phases = output[:, 1]

    # Station i owns output rows i, i + n_stations, i + 2 * n_stations, ...
    # Max confidence = smallest spread among the per-station pick distributions.
    p_max_confidence = np.min([p_phases[i::n_stations].std() for i in range(n_stations)])
    s_max_confidence = np.min([s_phases[i::n_stations].std() for i in range(n_stations)])

    print(f"Starting plotting {n_stations} waveforms")
    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 3))

    # Plot topography around all stations and the epicentre, plus a margin
    print('Fetching topography')
    params = Topography.DEFAULT.copy()
    extra_window = 0.5
    params["south"] = np.min([st_lats.min(), eq_lat]) - extra_window
    params["north"] = np.max([st_lats.max(), eq_lat]) + extra_window
    params["west"] = np.min([st_lons.min(), eq_lon]) - extra_window
    params["east"] = np.max([st_lons.max(), eq_lon]) + extra_window

    topo_map = Topography(**params)
    topo_map.fetch()
    topo_map.load()

    print('Plotting topo')
    hillshade = es.hillshade(topo_map.da[0], altitude=10)
    topo_map.da.plot(ax=ax[1], cmap='Greys', add_colorbar=False, add_labels=False)
    topo_map.da.plot(ax=ax[2], cmap='Greys', add_colorbar=False, add_labels=False)
    ax[1].imshow(hillshade, cmap="Greys", alpha=0.5)

    pick_rows = []
    for i in range(n_stations):
        print(f"Plotting waveform {i+1}/{n_stations}")
        current_P = p_phases[i::n_stations]
        current_S = s_phases[i::n_stations]
        n_samples = waveforms[i].shape[-1]
        trace_offset = distances[i] * _KM_PER_DEGREE  # each trace is drawn at its distance in km
        z_trace = waveforms[i][0, 0]

        # Matplotlib date numbers (days) for every sample at 100 samples/s.
        # Converting the start via ``.datetime`` gives date2num a real datetime,
        # and np.arange places sample k at exactly k / 100 s.
        x = mdates.date2num(t0s[i].datetime) + np.arange(n_samples) / 100 / 86400

        # Normalize confidence for the plot (1 for the station with the sharpest picks)
        p_conf = 1 / (current_P.std() / p_max_confidence).item()
        s_conf = 1 / (current_S.std() / s_max_confidence).item()

        # Clamp so a pick at (or beyond) either end of the window stays inside x.
        p_index = min(max(int(current_P.mean() * n_samples), 0), n_samples - 1)
        s_index = min(max(int(current_S.mean() * n_samples), 0), n_samples - 1)

        ax[0].plot(x, z_trace * 10 + trace_offset, color='black', alpha=0.5, lw=1)
        ax[0].scatter(x[p_index], z_trace.mean() + trace_offset, color='r', alpha=p_conf, marker='|')
        ax[0].scatter(x[s_index], z_trace.mean() + trace_offset, color='b', alpha=s_conf, marker='|')

        # Pick times in seconds after the origin: window offset + mean pick fraction * 60 s
        delta_t = t0s[i].timestamp - origin_time.timestamp
        p_time = (delta_t + current_P.mean() * 60).item()
        s_time = (delta_t + current_S.mean() * 60).item()
        velocity_p = trace_offset / p_time
        velocity_s = trace_offset / s_time

        print(f"Station {st_lats[i]}, {st_lons[i]} has P velocity {velocity_p} and S velocity {velocity_s}")

        pick_rows.append({'station_name': names[i], 'starttime': str(t0s[i]),
                          'p_phase': p_time, 'p_uncertainty': current_P.std().item() * 60,
                          's_phase': s_time, 's_uncertainty': current_S.std().item() * 60,
                          'velocity_p': velocity_p, 'velocity_s': velocity_s})

        # Straight station-to-epicentre path, coloured by the apparent velocity
        ray_lons = np.linspace(st_lons[i], eq_lon, 50)
        ray_lats = np.linspace(st_lats[i], eq_lat, 50)
        ax[1].scatter(ray_lons, ray_lats, c=np.full_like(ray_lons, velocity_p), alpha=0.5, vmin=0, vmax=8)
        ax[2].scatter(ray_lons, ray_lats, c=np.full_like(ray_lons, velocity_s), alpha=0.5, vmin=0, vmax=8)

    ax[0].set_ylabel('Z')
    ax[0].xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
    ax[0].xaxis.set_major_locator(mdates.SecondLocator(interval=20))

    # Add legend
    ax[0].scatter(None, None, color='r', marker='|', label='P')
    ax[0].scatter(None, None, color='b', marker='|', label='S')
    ax[0].legend()

    print('Plotting stations')
    for map_ax in ax[1:]:
        map_ax.scatter(st_lons, st_lats, color='b', label='Stations')
        map_ax.scatter(eq_lon, eq_lat, color='r', marker='*', label='Earthquake')

    # Colorbars for the velocity maps (the dummy scatters only carry the 0-8 km/s scale)
    cbar = plt.colorbar(ax[1].scatter(None, None, c=velocity_p, alpha=0.5, vmin=0, vmax=8), ax=ax[1])
    cbar.set_label('P Velocity (km/s)')
    ax[1].set_title('P Velocity')

    cbar = plt.colorbar(ax[2].scatter(None, None, c=velocity_s, alpha=0.5, vmin=0, vmax=8), ax=ax[2])
    cbar.set_label('S Velocity (km/s)')
    ax[2].set_title('S Velocity')

    fig.subplots_adjust(hspace=0., wspace=0.5)
    # Build the table once from records (DataFrame.append no longer exists in pandas 2.x).
    picks = pd.DataFrame(pick_rows, columns=_PICK_COLUMNS)
    return _figure_to_image(fig), picks
# Pretrained picker, loaded once at import time. make_prediction and
# predict_on_section use it through the module-level name ``model``.
_picker = Updated_onset_picker()
model = Onset_picker.load_from_checkpoint("./weights.ckpt", picker=_picker, learning_rate=3e-4)
# Inference mode (affects layers such as dropout / batch-norm, if present).
model.eval()
with gr.Blocks() as demo:
    gr.HTML("""<h1>PhaseHunter</h1>
<p>This app allows one to detect <mark style="background-color: red; color: white;">P</mark> and <mark style="background-color: blue; color: white;">S</mark> seismic phases along with <span style="background-image: linear-gradient(to right, #f12711, #f5af19);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
font-size: 24px;">
uncertainty
</span> of the detection.</p>
<ol>
<li>By selecting one of the sample waveforms.</li>
<li>By uploading your own waveform.</li>
<li>By selecting an earthquake from the global earthquake catalogue.</li>
</ol>
<p>Please upload your waveform in <code>.npy</code> (numpy) format.</p>
<p>Your waveform should be sampled at 100 samples per second and have 3 (Z, N, E) or 1 (Z) channels. If your file is longer than 60 seconds, the app will only use the first 60 seconds of the waveform.</p>
""")

    # Tab 1: pick phases on a bundled sample or an uploaded single-station waveform.
    with gr.Tab("Try on a single station"):
        with gr.Row():
            # Define the input and output types for Gradio
            inputs = gr.Dropdown(
                ["data/sample/sample_0.npy",
                 "data/sample/sample_1.npy",
                 "data/sample/sample_2.npy"],
                label="Sample waveform",
                info="Select one of the samples",
                value="data/sample/sample_0.npy"
            )
            upload = gr.File(label="Or upload your own waveform")

        button = gr.Button("Predict phases")
        outputs = gr.Image(label='Waveform with Phases Marked', type='numpy', interactive=False)
        button.click(mark_phases, inputs=[inputs, upload], outputs=outputs)

    # Tab 2: download a record section for a catalogue earthquake from an FDSN provider.
    with gr.Tab("Select earthquake from catalogue"):
        gr.Markdown("""Select an earthquake from the global earthquake catalogue and the app will download the waveform from the FDSN client of your choice.
        """)
        with gr.Row():
            client_inputs = gr.Dropdown(
                choices=list(URL_MAPPINGS.keys()),
                label="FDSN Client",
                info="Select one of the available FDSN clients",
                value="IRIS",
                interactive=True
            )
            velocity_inputs = gr.Dropdown(
                choices=['1066a', '1066b', 'ak135',
                         'ak135f', 'herrin', 'iasp91',
                         'jb', 'prem', 'pwdk'],
                label="1D velocity model",
                info="Velocity model for station selection",
                value="1066a",
                interactive=True
            )

        with gr.Column(scale=4):
            with gr.Row():
                timestamp_inputs = gr.Textbox(value='2019-07-04 17:33:49',
                                              placeholder='YYYY-MM-DD HH:MM:SS',
                                              label="Timestamp",
                                              info="Timestamp of the earthquake",
                                              max_lines=1,
                                              interactive=True)
                eq_lat_inputs = gr.Number(value=35.766,
                                          label="Latitude",
                                          info="Latitude of the earthquake",
                                          interactive=True)
                eq_lon_inputs = gr.Number(value=-117.605,
                                          label="Longitude",
                                          info="Longitude of the earthquake",
                                          interactive=True)
                source_depth_inputs = gr.Number(value=10,
                                                label="Source depth (km)",
                                                info="Depth of the earthquake",
                                                interactive=True)

        with gr.Column(scale=2):
            with gr.Row():
                # The slider grid is minimum + k * step. With minimum=10 both the
                # default (50) and the maximum (150) lie on the step-10 grid;
                # minimum=1 would only allow 1, 11, ..., 141.
                radius_inputs = gr.Slider(minimum=10,
                                          maximum=150,
                                          value=50, label="Radius (km)",
                                          step=10,
                                          info="""Select the radius around the earthquake to download data from.\n
                                          Note that the larger the radius, the longer the app will take to run.""",
                                          interactive=True)
                max_waveforms_inputs = gr.Slider(minimum=1,
                                                 maximum=100,
                                                 value=10,
                                                 label="Max waveforms per section",
                                                 step=1,
                                                 info="Maximum number of waveforms to show per section\n (to avoid long prediction times)",
                                                 interactive=True,
                                                 )

        button = gr.Button("Predict phases")
        output_image = gr.Image(label='Waveforms with Phases Marked', type='numpy', interactive=False)
        output_picks = gr.Dataframe(label='# Pick data', type='pandas', interactive=False)

        # predict_on_section returns (image, picks dataframe) for these two outputs.
        button.click(predict_on_section,
                     inputs=[client_inputs, timestamp_inputs,
                             eq_lat_inputs, eq_lon_inputs,
                             radius_inputs, source_depth_inputs,
                             velocity_inputs, max_waveforms_inputs],
                     outputs=[output_image, output_picks])

demo.launch()