{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PhaseHunter demo app\n",
    "\n",
    "Gradio app that takes a seismic waveform as input and marks the P and S phases, with their uncertainty, on it.\n",
    "The second tab downloads waveforms around a user-specified earthquake from an FDSN service and plots them as a record section together with apparent P and S velocity maps.\n",
    "\n",
    "Requires `./weights.ckpt` and the sample waveforms in `data/sample/`; downloaded waveforms are cached in `data/cached/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from glob import glob\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import gradio as gr\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.dates as mdates\n",
    "from matplotlib.colors import LightSource\n",
    "from scipy.stats import gaussian_kde\n",
    "\n",
    "import obspy\n",
    "from obspy.clients.fdsn import Client\n",
    "from obspy.clients.fdsn.header import FDSNNoDataException, FDSNTimeoutException, FDSNInternalServerException\n",
    "from obspy.clients.fdsn.header import URL_MAPPINGS\n",
    "from obspy.geodetics.base import locations2degrees\n",
    "from obspy.taup import TauPyModel\n",
    "from obspy.taup.helper_classes import SlownessModelError\n",
    "\n",
    "from bmi_topography import Topography\n",
    "import earthpy.spatial as es\n",
    "\n",
    "from phasehunter.model import Onset_picker, Updated_onset_picker\n",
    "from phasehunter.data_preparation import prepare_waveform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constants shared by the helpers below\n",
    "SAMPLING_RATE = 100  # Hz, the sampling rate the picker expects (see app description)\n",
    "KM_PER_DEGREE = 111.2  # approximate length of one degree of arc, km\n",
    "CACHE_DIR = \"data/cached\"  # where downloaded MiniSEED windows are cached"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _figure_to_image(fig):\n",
    "    \"\"\"Render a Matplotlib figure to an RGBA numpy array and close the figure.\"\"\"\n",
    "    fig.canvas.draw()\n",
    "    image = np.array(fig.canvas.renderer.buffer_rgba())\n",
    "    plt.close(fig)\n",
    "    return image\n",
    "\n",
    "\n",
    "def make_prediction(waveform):\n",
    "    \"\"\"Load a ``.npy`` waveform, preprocess it and run the phase picker.\n",
    "\n",
    "    Uses the module-level ``model`` (loaded in a later cell).\n",
    "\n",
    "    Returns:\n",
    "        ``(processed_input, p_phase, s_phase)``: the preprocessed input and\n",
    "        columns 0 (P) and 1 (S) of the model output. The picks are scaled by\n",
    "        the number of samples when plotted, so they are presumably fractions\n",
    "        of the window length -- TODO confirm.\n",
    "    \"\"\"\n",
    "    waveform = np.load(waveform)\n",
    "    processed_input = prepare_waveform(waveform)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        output = model(processed_input)\n",
    "\n",
    "    p_phase = output[:, 0]\n",
    "    s_phase = output[:, 1]\n",
    "    return processed_input, p_phase, s_phase\n",
    "\n",
    "\n",
    "def mark_phases(waveform, uploaded_file):\n",
    "    \"\"\"Plot a waveform with its predicted P/S onsets and their uncertainty.\n",
    "\n",
    "    Args:\n",
    "        waveform: path to one of the bundled sample ``.npy`` files.\n",
    "        uploaded_file: optional Gradio upload; when given, its ``.name``\n",
    "            path is used instead of ``waveform``.\n",
    "\n",
    "    Returns:\n",
    "        RGBA image (numpy array) with the trace(s), the mean P (red) and S\n",
    "        (blue) picks as dashed lines and a KDE of the picks in the bottom panel.\n",
    "    \"\"\"\n",
    "    if uploaded_file is not None:\n",
    "        waveform = uploaded_file.name\n",
    "\n",
    "    processed_input, p_phase, s_phase = make_prediction(waveform)\n",
    "    n_samples = processed_input.shape[-1]\n",
    "\n",
    "    # Treat the input as single-component only when the E channel is entirely\n",
    "    # zero (presumably prepare_waveform zero-pads missing channels -- TODO confirm).\n",
    "    # The previous `sum(... == 0)` test fired on any single zero sample, which\n",
    "    # made genuine 3C data plot as 1C.\n",
    "    is_single_component = bool((processed_input[0][2] == 0).all())\n",
    "\n",
    "    if is_single_component:\n",
    "        fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)\n",
    "        ax[0].plot(processed_input[0][0], color='black', lw=1)\n",
    "        ax[0].set_ylabel('Norm. Ampl.')\n",
    "    else:\n",
    "        fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)\n",
    "        for row, label in enumerate(['Z', 'N', 'E']):\n",
    "            ax[row].plot(processed_input[0][row], color='black', lw=1)\n",
    "            ax[row].set_ylabel(label)\n",
    "\n",
    "    # Bottom panel: kernel density of all P (red) and S (blue) picks = uncertainty\n",
    "    for phase, color in ((p_phase, 'r'), (s_phase, 'b')):\n",
    "        picks = phase*n_samples\n",
    "        kde = gaussian_kde(picks)\n",
    "        dist_space = np.linspace(min(picks)-10, max(picks)+10, 500)\n",
    "        ax[-1].plot(dist_space, kde(dist_space), color=color)\n",
    "\n",
    "    for a in ax:\n",
    "        a.axvline(p_phase.mean()*n_samples, color='r', linestyle='--', label='P')\n",
    "        a.axvline(s_phase.mean()*n_samples, color='b', linestyle='--', label='S')\n",
    "\n",
    "    ax[-1].set_xlabel('Time, samples')\n",
    "    ax[-1].set_ylabel('Uncert.')\n",
    "    ax[-1].legend()\n",
    "\n",
    "    plt.subplots_adjust(hspace=0., wspace=0.)\n",
    "    return _figure_to_image(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bin_distances(distances, bin_size=10):\n",
    "    \"\"\"Return the index of the first entry of ``distances`` in each distance bin.\n",
    "\n",
    "    Distances are bucketed with ``distance // bin_size`` (``bin_size`` is in the\n",
    "    same units as ``distances``). Indices are returned in order of each bin's\n",
    "    first appearance.\n",
    "    \"\"\"\n",
    "    first_index_per_bin = {}\n",
    "    for i, distance in enumerate(distances):\n",
    "        # setdefault keeps the first index seen; indices only increase, so a\n",
    "        # \"replace with a smaller index\" branch could never trigger.\n",
    "        first_index_per_bin.setdefault(distance // bin_size, i)\n",
    "    return list(first_index_per_bin.values())\n",
    "\n",
    "\n",
    "def variance_coefficient(residuals):\n",
    "    \"\"\"Return ``1 - var(residuals) / (max - min)`` as a spread score.\n",
    "\n",
    "    Constant residuals (zero range) give 1.0 instead of NaN from 0/0.\n",
    "    Not used elsewhere in this notebook.\n",
    "    \"\"\"\n",
    "    value_range = residuals.max() - residuals.min()\n",
    "    if value_range == 0:\n",
    "        return 1.0\n",
    "    return 1 - (residuals.var() / value_range)\n",
    "\n",
    "\n",
    "def _message_image(message):\n",
    "    \"\"\"Return an image that only shows ``message`` (used for error states).\"\"\"\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.text(0.5, 0.5, message, ha='center', va='center')\n",
    "    return _figure_to_image(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source_depth_km, velocity_model, max_waveforms):\n",
    "    \"\"\"Pick P and S phases on stations around an earthquake and plot them.\n",
    "\n",
    "    Stations inside a lat/lon box extending ``radius_km / KM_PER_DEGREE``\n",
    "    degrees around the epicentre are requested from the FDSN service\n",
    "    ``client_name``. For every station with a TauP arrival, a 60 s window\n",
    "    starting 15 s before the first arrival is downloaded (or read from\n",
    "    ``CACHE_DIR``). Up to ``max_waveforms`` waveforms, at most one per 10 km\n",
    "    distance bin, are run through ``model`` and drawn as a record section next\n",
    "    to maps of apparent P and S velocity.\n",
    "\n",
    "    Returns:\n",
    "        RGBA image (numpy array) of the figure, or of an error message when\n",
    "        the data provider fails or no usable waveforms are found.\n",
    "    \"\"\"\n",
    "    distances, t0s, st_lats, st_lons, waveforms = [], [], [], [], []\n",
    "\n",
    "    taup_model = TauPyModel(model=velocity_model)\n",
    "    client = Client(client_name)\n",
    "\n",
    "    window = radius_km / KM_PER_DEGREE  # half-width of the search box, degrees\n",
    "    max_waveforms = int(max_waveforms)\n",
    "\n",
    "    assert eq_lat - window > -90 and eq_lat + window < 90, \"Latitude out of bounds\"\n",
    "    assert eq_lon - window > -180 and eq_lon + window < 180, \"Longitude out of bounds\"\n",
    "\n",
    "    origin_time = obspy.UTCDateTime(timestamp)\n",
    "\n",
    "    try:\n",
    "        print('Starting to download inventory')\n",
    "        inv = client.get_stations(network=\"*\", station=\"*\", location=\"*\", channel=\"*H*\",\n",
    "                                  starttime=origin_time, endtime=origin_time + 120,\n",
    "                                  minlatitude=(eq_lat-window), maxlatitude=(eq_lat+window),\n",
    "                                  minlongitude=(eq_lon-window), maxlongitude=(eq_lon+window),\n",
    "                                  level='station')\n",
    "        print('Finished downloading inventory')\n",
    "    except (IndexError, FDSNNoDataException, FDSNTimeoutException, FDSNInternalServerException):\n",
    "        return _message_image('Something is wrong with the data provider, try another')\n",
    "\n",
    "    # The cache directory must exist before the first waveform.write()\n",
    "    os.makedirs(CACHE_DIR, exist_ok=True)\n",
    "\n",
    "    for network in inv:\n",
    "        # Skip the synthetic networks\n",
    "        if network.code == 'SY':\n",
    "            continue\n",
    "        for station in network:\n",
    "            print(f\"Processing {network.code}.{station.code}...\")\n",
    "            distance = locations2degrees(eq_lat, eq_lon, station.latitude, station.longitude)\n",
    "\n",
    "            arrivals = taup_model.get_travel_times(source_depth_in_km=source_depth_km,\n",
    "                                                   distance_in_degree=distance,\n",
    "                                                   phase_list=[\"P\", \"S\"])\n",
    "            if len(arrivals) == 0:\n",
    "                continue\n",
    "\n",
    "            # 60 s window starting 15 s before the first predicted arrival\n",
    "            starttime = origin_time + arrivals[0].time - 15\n",
    "            endtime = starttime + 60\n",
    "            cache_path = f\"{CACHE_DIR}/{network.code}_{station.code}_{starttime}.mseed\"\n",
    "            try:\n",
    "                if not os.path.exists(cache_path):\n",
    "                    print('Downloading waveform')\n",
    "                    waveform = client.get_waveforms(network=network.code, station=station.code, location=\"*\", channel=\"*\",\n",
    "                                                    starttime=starttime, endtime=endtime)\n",
    "                    waveform.write(cache_path, format=\"MSEED\")\n",
    "                    print('Finished downloading and caching waveform')\n",
    "                else:\n",
    "                    print('Reading cached waveform')\n",
    "                    waveform = obspy.read(cache_path)\n",
    "            except (IndexError, FDSNNoDataException, FDSNTimeoutException, FDSNInternalServerException):\n",
    "                print(f'Skipping {network.code}_{station.code}_{starttime}')\n",
    "                continue\n",
    "\n",
    "            waveform = waveform.select(channel=\"H[BH][ZNE]\")\n",
    "            waveform = waveform.merge(fill_value=0)\n",
    "            waveform = waveform[:3]\n",
    "\n",
    "            # Components of different lengths cannot be stacked\n",
    "            if len({len(trace.data) for trace in waveform}) > 1:\n",
    "                continue\n",
    "            if len(waveform) != 3:\n",
    "                continue\n",
    "\n",
    "            try:\n",
    "                waveform = prepare_waveform(np.stack([trace.data for trace in waveform]))\n",
    "            except Exception as exc:\n",
    "                # Best effort: skip stations whose data cannot be preprocessed,\n",
    "                # but report why instead of silently swallowing every error.\n",
    "                print(f\"Skipping {network.code}.{station.code}: {exc}\")\n",
    "                continue\n",
    "\n",
    "            distances.append(distance)\n",
    "            t0s.append(starttime)\n",
    "            st_lats.append(station.latitude)\n",
    "            st_lons.append(station.longitude)\n",
    "            waveforms.append(waveform)\n",
    "            print(f\"Added {network.code}.{station.code} to the list of waveforms\")\n",
    "\n",
    "    if len(waveforms) == 0:\n",
    "        return _message_image('No waveforms found')\n",
    "\n",
    "    # Keep one station per 10 km distance bin (distances are in degrees)\n",
    "    first_distances = bin_distances(distances, bin_size=10/KM_PER_DEGREE)\n",
    "\n",
    "    # Cap the number of waveforms to keep prediction and plotting fast\n",
    "    selection_indexes = np.random.choice(first_distances,\n",
    "                                         np.min([len(first_distances), max_waveforms]),\n",
    "                                         replace=False)\n",
    "\n",
    "    waveforms = np.array(waveforms)[selection_indexes]\n",
    "    distances = np.array(distances)[selection_indexes]\n",
    "    t0s = np.array(t0s)[selection_indexes]\n",
    "    st_lats = np.array(st_lats)[selection_indexes]\n",
    "    st_lons = np.array(st_lons)[selection_indexes]\n",
    "\n",
    "    waveforms = [torch.tensor(waveform) for waveform in waveforms]\n",
    "    n_waveforms = len(waveforms)\n",
    "\n",
    "    print('Starting to run predictions')\n",
    "    with torch.no_grad():\n",
    "        output = model(torch.vstack(waveforms))\n",
    "\n",
    "    p_phases = output[:, 0]\n",
    "    s_phases = output[:, 1]\n",
    "\n",
    "    # NOTE(review): picks of waveform i are read as rows i::n_waveforms, i.e. the\n",
    "    # model output is assumed to interleave waveforms. If prepare_waveform returns\n",
    "    # a block of rows per waveform, torch.vstack makes them contiguous instead --\n",
    "    # confirm the row layout.\n",
    "    # Max confidence = smallest spread of picks over all waveforms\n",
    "    p_max_confidence = np.min([p_phases[i::n_waveforms].std() for i in range(n_waveforms)])\n",
    "    s_max_confidence = np.min([s_phases[i::n_waveforms].std() for i in range(n_waveforms)])\n",
    "\n",
    "    print(f\"Starting plotting {n_waveforms} waveforms\")\n",
    "    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 3))\n",
    "\n",
    "    # Topography background for the two velocity maps\n",
    "    print('Fetching topography')\n",
    "    params = Topography.DEFAULT.copy()\n",
    "    extra_window = 0.5  # degrees of margin around stations and epicentre\n",
    "    params[\"south\"] = np.min([st_lats.min(), eq_lat])-extra_window\n",
    "    params[\"north\"] = np.max([st_lats.max(), eq_lat])+extra_window\n",
    "    params[\"west\"] = np.min([st_lons.min(), eq_lon])-extra_window\n",
    "    params[\"east\"] = np.max([st_lons.max(), eq_lon])+extra_window\n",
    "\n",
    "    topo_map = Topography(**params)\n",
    "    topo_map.fetch()\n",
    "    topo_map.load()\n",
    "\n",
    "    print('Plotting topo')\n",
    "    hillshade = es.hillshade(topo_map.da[0], altitude=10)\n",
    "\n",
    "    topo_map.da.plot(ax=ax[1], cmap='Greys', add_colorbar=False, add_labels=False)\n",
    "    topo_map.da.plot(ax=ax[2], cmap='Greys', add_colorbar=False, add_labels=False)\n",
    "    ax[1].imshow(hillshade, cmap=\"Greys\", alpha=0.5)\n",
    "\n",
    "    for i in range(n_waveforms):\n",
    "        print(f\"Plotting waveform {i+1}/{n_waveforms}\")\n",
    "        current_P = p_phases[i::n_waveforms]\n",
    "        current_S = s_phases[i::n_waveforms]\n",
    "\n",
    "        n_samples = waveforms[i][0].shape[-1]\n",
    "        window_s = n_samples / SAMPLING_RATE  # window length in seconds\n",
    "        # One time stamp per sample. np.linspace(0, 6000, 6000) stepped by\n",
    "        # 6000/5999 samples and drifted off the true sample times.\n",
    "        x = [t0s[i] + pd.Timedelta(seconds=k/SAMPLING_RATE) for k in range(n_samples)]\n",
    "        x = mdates.date2num(x)\n",
    "\n",
    "        # Normalize confidence for the plot: 1 for the least uncertain waveform\n",
    "        p_conf = 1/(current_P.std()/p_max_confidence).item()\n",
    "        s_conf = 1/(current_S.std()/s_max_confidence).item()\n",
    "\n",
    "        distance_km = distances[i]*KM_PER_DEGREE\n",
    "        ax[0].plot(x, waveforms[i][0, 0]*10+distance_km, color='black', alpha=0.5, lw=1)\n",
    "\n",
    "        # Clamp pick indices so a pick at the window edge cannot index past x\n",
    "        p_index = min(max(int(current_P.mean()*n_samples), 0), n_samples - 1)\n",
    "        s_index = min(max(int(current_S.mean()*n_samples), 0), n_samples - 1)\n",
    "        baseline = waveforms[i][0, 0].mean()+distance_km\n",
    "        ax[0].scatter(x[p_index], baseline, color='r', alpha=p_conf, marker='|')\n",
    "        ax[0].scatter(x[s_index], baseline, color='b', alpha=s_conf, marker='|')\n",
    "\n",
    "        # Apparent velocity = epicentral distance / (window offset + pick time)\n",
    "        delta_t = t0s[i].timestamp - origin_time.timestamp\n",
    "        velocity_p = distance_km/(delta_t+current_P.mean()*window_s).item()\n",
    "        velocity_s = distance_km/(delta_t+current_S.mean()*window_s).item()\n",
    "\n",
    "        print(f\"Station {st_lats[i]}, {st_lons[i]} has P velocity {velocity_p} and S velocity {velocity_s}\")\n",
    "\n",
    "        # Draw the station-epicentre path coloured by its apparent velocity\n",
    "        path_lons = np.linspace(st_lons[i], eq_lon, 50)\n",
    "        path_lats = np.linspace(st_lats[i], eq_lat, 50)\n",
    "        ax[1].scatter(path_lons, path_lats, c=np.zeros_like(path_lons)+velocity_p, alpha=0.5, vmin=0, vmax=8)\n",
    "        ax[2].scatter(path_lons, path_lats, c=np.zeros_like(path_lons)+velocity_s, alpha=0.5, vmin=0, vmax=8)\n",
    "\n",
    "    ax[0].set_ylabel('Z')\n",
    "    ax[0].xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))\n",
    "    ax[0].xaxis.set_major_locator(mdates.SecondLocator(interval=20))\n",
    "\n",
    "    # Add legend\n",
    "    ax[0].scatter(None, None, color='r', marker='|', label='P')\n",
    "    ax[0].scatter(None, None, color='b', marker='|', label='S')\n",
    "    ax[0].legend()\n",
    "\n",
    "    print('Plotting stations')\n",
    "    for i in range(1, 3):\n",
    "        ax[i].scatter(st_lons, st_lats, color='b', label='Stations')\n",
    "        ax[i].scatter(eq_lon, eq_lat, color='r', marker='*', label='Earthquake')\n",
    "\n",
    "    # Generate colorbars for the velocity plots\n",
    "    cbar = plt.colorbar(ax[1].scatter(None, None, c=velocity_p, alpha=0.5, vmin=0, vmax=8), ax=ax[1])\n",
    "    cbar.set_label('P Velocity (km/s)')\n",
    "    ax[1].set_title('P Velocity')\n",
    "\n",
    "    cbar = plt.colorbar(ax[2].scatter(None, None, c=velocity_s, alpha=0.5, vmin=0, vmax=8), ax=ax[2])\n",
    "    cbar.set_label('S Velocity (km/s)')\n",
    "    ax[2].set_title('S Velocity')\n",
    "\n",
    "    plt.subplots_adjust(hspace=0., wspace=0.5)\n",
    "    return _figure_to_image(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the trained picker; `model` is read as a global by the functions above\n",
    "model = Onset_picker.load_from_checkpoint(\"./weights.ckpt\",\n",
    "                                          picker=Updated_onset_picker(),\n",
    "                                          learning_rate=3e-4)\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with gr.Blocks() as demo:\n",
    "    gr.HTML(\"\"\"

PhaseHunter

\n", "

This app detects P and S seismic phases together with the uncertainty of each detection. There are three ways to try it:

\n", "
    \n", "
  1. By selecting one of the sample waveforms.
\n", "
  2. By uploading your own waveform.
\n", "
  3. By specifying an earthquake from a global earthquake catalogue (origin time, location and depth).
\n", "
\n", "

Please upload your waveform in .npy (numpy) format.

\n", "

Your waveform should be sampled at 100 samples per second and have 3 (Z, N, E) or 1 (Z) channels. If your file is longer than 60 seconds, the app will only use the first 60 seconds of the waveform.

\n",
    "    \"\"\")\n",
    "    with gr.Tab(\"Try on a single station\"):\n",
    "        with gr.Row():\n",
    "            # Define the input and output types for Gradio\n",
    "            inputs = gr.Dropdown(\n",
    "                [\"data/sample/sample_0.npy\",\n",
    "                 \"data/sample/sample_1.npy\",\n",
    "                 \"data/sample/sample_2.npy\"],\n",
    "                label=\"Sample waveform\",\n",
    "                info=\"Select one of the samples\",\n",
    "                value=\"data/sample/sample_0.npy\"\n",
    "            )\n",
    "\n",
    "            upload = gr.File(label=\"Or upload your own waveform\")\n",
    "\n",
    "        button = gr.Button(\"Predict phases\")\n",
    "        outputs = gr.Image(label='Waveform with Phases Marked', type='numpy', interactive=False)\n",
    "\n",
    "        button.click(mark_phases, inputs=[inputs, upload], outputs=outputs)\n",
    "\n",
    "    with gr.Tab(\"Select earthquake from catalogue\"):\n",
    "        # The user types in the event parameters; nothing is looked up in a catalogue here\n",
    "        gr.Markdown(\"\"\"Enter the origin time, epicentre and depth of an earthquake (for example from a global earthquake catalogue) and the app will download waveforms of nearby stations from the FDSN client of your choice.\n",
    "        \"\"\")\n",
    "\n",
    "        client_inputs = gr.Dropdown(\n",
    "            choices=list(URL_MAPPINGS.keys()),\n",
    "            label=\"FDSN Client\",\n",
    "            info=\"Select one of the available FDSN clients\",\n",
    "            value=\"IRIS\",\n",
    "            interactive=True\n",
    "        )\n",
    "\n",
    "        with gr.Row():\n",
    "            timestamp_inputs = gr.Textbox(value='2019-07-04 17:33:49',\n",
    "                                          placeholder='YYYY-MM-DD HH:MM:SS',\n",
    "                                          label=\"Timestamp\",\n",
    "                                          info=\"Timestamp of the earthquake\",\n",
    "                                          max_lines=1,\n",
    "                                          interactive=True)\n",
    "\n",
    "            eq_lat_inputs = gr.Number(value=35.766,\n",
    "                                      label=\"Latitude\",\n",
    "                                      info=\"Latitude of the earthquake\",\n",
    "                                      interactive=True)\n",
    "\n",
    "            eq_lon_inputs = gr.Number(value=-117.605,\n",
    "                                      label=\"Longitude\",\n",
    "                                      info=\"Longitude of the earthquake\",\n",
    "                                      interactive=True)\n",
    "\n",
    "            source_depth_inputs = gr.Number(value=10,\n",
    "                                            label=\"Source depth (km)\",\n",
    "                                            info=\"Depth of the earthquake\",\n",
    "                                            interactive=True)\n",
    "\n",
    "            # minimum=10 puts the step-10 grid on 10..150, so both the default (50)\n",
    "            # and the maximum are selectable (minimum=1 gave 1, 11, ..., 141).\n",
    "            radius_inputs = gr.Slider(minimum=10,\n",
    "                                      maximum=150,\n",
    "                                      value=50, label=\"Radius (km)\",\n",
    "                                      step=10,\n",
    "                                      info=\"\"\"Select the radius around the earthquake to download data from.\\n \n",
    "                                      Note that the larger the radius, the longer the app will take to run.\"\"\",\n",
    "                                      interactive=True)\n",
    "\n",
    "            velocity_inputs = gr.Dropdown(\n",
    "                choices=['1066a', '1066b', 'ak135',\n",
    "                         'ak135f', 'herrin', 'iasp91',\n",
    "                         'jb', 'prem', 'pwdk'],\n",
    "                label=\"1D velocity model\",\n",
    "                info=\"Velocity model for station selection\",\n",
    "                value=\"1066a\",\n",
    "                interactive=True\n",
    "            )\n",
    "\n",
    "            max_waveforms_inputs = gr.Slider(minimum=1,\n",
    "                                             maximum=100,\n",
    "                                             value=10,\n",
    "                                             label=\"Max waveforms per section\",\n",
    "                                             step=1,\n",
    "                                             info=\"Maximum number of waveforms to show per section\\n (to avoid long prediction times)\",\n",
    "                                             interactive=True,\n",
    "                                             )\n",
    "\n",
    "        # Separate name so the single-station `button` is not shadowed\n",
    "        section_button = gr.Button(\"Predict phases\")\n",
    "        outputs_section = gr.Image(label='Waveforms with Phases Marked', type='numpy', interactive=False)\n",
    "\n",
    "        section_button.click(predict_on_section,\n",
    "                             inputs=[client_inputs, timestamp_inputs,\n",
    "                                     eq_lat_inputs, eq_lon_inputs,\n",
    "                                     radius_inputs, source_depth_inputs,\n",
    "                                     velocity_inputs, max_waveforms_inputs],\n",
    "                             outputs=outputs_section)\n",
    "\n",
    "demo.launch()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "phasehunter",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "6bf57068982d7b420bddaaf1d0614a7795947176033057024cf47d8ca2c1c4cd"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}