crimeacs commited on
Commit
15dbd99
·
1 Parent(s): d4306c7

section plot now works

Browse files
Files changed (3) hide show
  1. Gradio_app.ipynb +83 -28
  2. app.py +31 -14
  3. requirements.txt +1 -0
Gradio_app.ipynb CHANGED
@@ -152,7 +152,28 @@
152
  },
153
  {
154
  "cell_type": "code",
155
- "execution_count": 64,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  "metadata": {},
157
  "outputs": [
158
  {
@@ -167,7 +188,7 @@
167
  "name": "stdout",
168
  "output_type": "stream",
169
  "text": [
170
- "Running on local URL: http://127.0.0.1:7914\n",
171
  "\n",
172
  "To create a public link, set `share=True` in `launch()`.\n"
173
  ]
@@ -175,7 +196,7 @@
175
  {
176
  "data": {
177
  "text/html": [
178
- "<div><iframe src=\"http://127.0.0.1:7914/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
179
  ],
180
  "text/plain": [
181
  "<IPython.core.display.HTML object>"
@@ -188,16 +209,9 @@
188
  "data": {
189
  "text/plain": []
190
  },
191
- "execution_count": 64,
192
  "metadata": {},
193
  "output_type": "execute_result"
194
- },
195
- {
196
- "name": "stdout",
197
- "output_type": "stream",
198
- "text": [
199
- "torch.Size([256])\n"
200
- ]
201
  }
202
  ],
203
  "source": [
@@ -205,6 +219,7 @@
205
  "\n",
206
  "import gradio as gr\n",
207
  "import numpy as np\n",
 
208
  "from phasehunter.model import Onset_picker, Updated_onset_picker\n",
209
  "from phasehunter.data_preparation import prepare_waveform\n",
210
  "import torch\n",
@@ -221,6 +236,7 @@
221
  "from obspy.clients.fdsn.header import URL_MAPPINGS\n",
222
  "\n",
223
  "import matplotlib.pyplot as plt\n",
 
224
  "\n",
225
  "def make_prediction(waveform):\n",
226
  " waveform = np.load(waveform)\n",
@@ -328,7 +344,10 @@
328
  " continue\n",
329
  "\n",
330
  " if len(waveform) == 3:\n",
331
- " waveform = prepare_waveform(np.stack([x.data for x in waveform]))\n",
 
 
 
332
  " \n",
333
  " distances.append(distance)\n",
334
  " t0s.append(starttime)\n",
@@ -346,19 +365,31 @@
346
  " p_phases = output[:, 0]\n",
347
  " s_phases = output[:, 1]\n",
348
  "\n",
349
- "\n",
350
- " print(p_phases.shape)\n",
351
- " # for i in range(len(waveforms)):\n",
352
- " # current_P = P_batch[i::len(waveforms)].cpu()\n",
353
- " # current_S_batch = S_batch[i::len(waveforms)].cpu()\n",
354
- " # current_Pg_batch = Pg_batch[i::len(waveforms)].cpu()\n",
355
- " # current_Sg_batch = Sg_batch[i::len(waveforms)].cpu()\n",
356
- " # current_Pn_batch = Pn_batch[i::len(waveforms)].cpu()\n",
357
- " # current_Sn_batch = Sn_batch[i::len(waveforms)].cpu()\n",
358
- " \n",
359
- " fig,ax = plt.subplots()\n",
360
- " ax.scatter(st_lats, st_lons)\n",
361
- " fig.canvas.draw()\n",
 
 
 
 
 
 
 
 
 
 
 
 
362
  " image = np.array(fig.canvas.renderer.buffer_rgba())\n",
363
  " plt.close(fig)\n",
364
  "\n",
@@ -470,10 +501,34 @@
470
  },
471
  {
472
  "cell_type": "code",
473
- "execution_count": null,
474
  "metadata": {},
475
- "outputs": [],
476
- "source": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  }
478
  ],
479
  "metadata": {
 
152
  },
153
  {
154
  "cell_type": "code",
155
+ "execution_count": 75,
156
+ "metadata": {},
157
+ "outputs": [
158
+ {
159
+ "ename": "NameError",
160
+ "evalue": "name 't0s' is not defined",
161
+ "output_type": "error",
162
+ "traceback": [
163
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
164
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
165
+ "Cell \u001b[0;32mIn[75], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mpandas\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mpd\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m pd\u001b[39m.\u001b[39mdate_range(start\u001b[39m=\u001b[39mt0s[i], periods\u001b[39m=\u001b[39mwaveforms[i][\u001b[39m0\u001b[39m]\u001b[39m.\u001b[39mshape[\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m], freq\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39m1s\u001b[39m\u001b[39m'\u001b[39m)\n",
166
+ "\u001b[0;31mNameError\u001b[0m: name 't0s' is not defined"
167
+ ]
168
+ }
169
+ ],
170
+ "source": [
171
+ "import pandas as pd\n"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": 114,
177
  "metadata": {},
178
  "outputs": [
179
  {
 
188
  "name": "stdout",
189
  "output_type": "stream",
190
  "text": [
191
+ "Running on local URL: http://127.0.0.1:7935\n",
192
  "\n",
193
  "To create a public link, set `share=True` in `launch()`.\n"
194
  ]
 
196
  {
197
  "data": {
198
  "text/html": [
199
+ "<div><iframe src=\"http://127.0.0.1:7935/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
200
  ],
201
  "text/plain": [
202
  "<IPython.core.display.HTML object>"
 
209
  "data": {
210
  "text/plain": []
211
  },
212
+ "execution_count": 114,
213
  "metadata": {},
214
  "output_type": "execute_result"
 
 
 
 
 
 
 
215
  }
216
  ],
217
  "source": [
 
219
  "\n",
220
  "import gradio as gr\n",
221
  "import numpy as np\n",
222
+ "import pandas as pd\n",
223
  "from phasehunter.model import Onset_picker, Updated_onset_picker\n",
224
  "from phasehunter.data_preparation import prepare_waveform\n",
225
  "import torch\n",
 
236
  "from obspy.clients.fdsn.header import URL_MAPPINGS\n",
237
  "\n",
238
  "import matplotlib.pyplot as plt\n",
239
+ "import matplotlib.dates as mdates\n",
240
  "\n",
241
  "def make_prediction(waveform):\n",
242
  " waveform = np.load(waveform)\n",
 
344
  " continue\n",
345
  "\n",
346
  " if len(waveform) == 3:\n",
347
+ " try:\n",
348
+ " waveform = prepare_waveform(np.stack([x.data for x in waveform]))\n",
349
+ " except:\n",
350
+ " continue\n",
351
  " \n",
352
  " distances.append(distance)\n",
353
  " t0s.append(starttime)\n",
 
365
  " p_phases = output[:, 0]\n",
366
  " s_phases = output[:, 1]\n",
367
  "\n",
368
+ " fig, ax = plt.subplots(nrows=1, figsize=(10, 3), sharex=True)\n",
369
+ " for i in range(len(waveforms)):\n",
370
+ " current_P = p_phases[i::len(waveforms)]\n",
371
+ " current_S = s_phases[i::len(waveforms)]\n",
372
+ " x = [t0s[i] + pd.Timedelta(seconds=k/100) for k in np.linspace(0,6000,6000)]\n",
373
+ " x = mdates.date2num(x)\n",
374
+ " ax.plot(x, waveforms[i][0, 0]+distances[i]*111.2, color='black', alpha=0.5)\n",
375
+ " ax.scatter(x[int(current_P.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='r')\n",
376
+ " ax.scatter(x[int(current_S.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='b')\n",
377
+ " ax.set_ylabel('Z')\n",
378
+ "\n",
379
+ " ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))\n",
380
+ " ax.xaxis.set_major_locator(mdates.SecondLocator(interval=10))\n",
381
+ "\n",
382
+ " # for a in ax:\n",
383
+ " # a.axvline(current_P.mean()*waveforms[i][0].shape[-1], color='r', linestyle='--', label='P')\n",
384
+ " # a.axvline(current_S.mean()*waveforms[i][0].shape[-1], color='b', linestyle='--', label='S')\n",
385
+ "\n",
386
+ " # ax[-1].set_xlabel('Time, samples')\n",
387
+ " # ax[-1].set_ylabel('Uncert.')\n",
388
+ " # ax[-1].legend()\n",
389
+ "\n",
390
+ " plt.subplots_adjust(hspace=0., wspace=0.)\n",
391
+ " \n",
392
+ " fig.canvas.draw();\n",
393
  " image = np.array(fig.canvas.renderer.buffer_rgba())\n",
394
  " plt.close(fig)\n",
395
  "\n",
 
501
  },
502
  {
503
  "cell_type": "code",
504
+ "execution_count": 105,
505
  "metadata": {},
506
+ "outputs": [
507
+ {
508
+ "data": {
509
+ "text/plain": [
510
+ "DatetimeIndex(['2019-07-04 17:33:49', '2019-07-04 17:33:50',\n",
511
+ " '2019-07-04 17:33:51', '2019-07-04 17:33:52',\n",
512
+ " '2019-07-04 17:33:53', '2019-07-04 17:33:54',\n",
513
+ " '2019-07-04 17:33:55', '2019-07-04 17:33:56',\n",
514
+ " '2019-07-04 17:33:57', '2019-07-04 17:33:58',\n",
515
+ " ...\n",
516
+ " '2019-07-04 19:13:39', '2019-07-04 19:13:40',\n",
517
+ " '2019-07-04 19:13:41', '2019-07-04 19:13:42',\n",
518
+ " '2019-07-04 19:13:43', '2019-07-04 19:13:44',\n",
519
+ " '2019-07-04 19:13:45', '2019-07-04 19:13:46',\n",
520
+ " '2019-07-04 19:13:47', '2019-07-04 19:13:48'],\n",
521
+ " dtype='datetime64[ns]', length=6000, freq='S')"
522
+ ]
523
+ },
524
+ "execution_count": 105,
525
+ "metadata": {},
526
+ "output_type": "execute_result"
527
+ }
528
+ ],
529
+ "source": [
530
+ "pd.date_range(start=obspy.UTCDateTime(\"2019-07-04 17:33:49\").timestamp*1e9, periods=6000, freq='s')"
531
+ ]
532
  }
533
  ],
534
  "metadata": {
app.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  import gradio as gr
4
  import numpy as np
 
5
  from phasehunter.model import Onset_picker, Updated_onset_picker
6
  from phasehunter.data_preparation import prepare_waveform
7
  import torch
@@ -18,6 +19,7 @@ from obspy.taup.helper_classes import SlownessModelError
18
  from obspy.clients.fdsn.header import URL_MAPPINGS
19
 
20
  import matplotlib.pyplot as plt
 
21
 
22
  def make_prediction(waveform):
23
  waveform = np.load(waveform)
@@ -125,7 +127,10 @@ def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source
125
  continue
126
 
127
  if len(waveform) == 3:
128
- waveform = prepare_waveform(np.stack([x.data for x in waveform]))
 
 
 
129
 
130
  distances.append(distance)
131
  t0s.append(starttime)
@@ -143,19 +148,31 @@ def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source
143
  p_phases = output[:, 0]
144
  s_phases = output[:, 1]
145
 
146
-
147
- print(p_phases.shape)
148
- # for i in range(len(waveforms)):
149
- # current_P = P_batch[i::len(waveforms)].cpu()
150
- # current_S_batch = S_batch[i::len(waveforms)].cpu()
151
- # current_Pg_batch = Pg_batch[i::len(waveforms)].cpu()
152
- # current_Sg_batch = Sg_batch[i::len(waveforms)].cpu()
153
- # current_Pn_batch = Pn_batch[i::len(waveforms)].cpu()
154
- # current_Sn_batch = Sn_batch[i::len(waveforms)].cpu()
155
-
156
- fig,ax = plt.subplots()
157
- ax.scatter(st_lats, st_lons)
158
- fig.canvas.draw()
 
 
 
 
 
 
 
 
 
 
 
 
159
  image = np.array(fig.canvas.renderer.buffer_rgba())
160
  plt.close(fig)
161
 
 
2
 
3
  import gradio as gr
4
  import numpy as np
5
+ import pandas as pd
6
  from phasehunter.model import Onset_picker, Updated_onset_picker
7
  from phasehunter.data_preparation import prepare_waveform
8
  import torch
 
19
  from obspy.clients.fdsn.header import URL_MAPPINGS
20
 
21
  import matplotlib.pyplot as plt
22
+ import matplotlib.dates as mdates
23
 
24
  def make_prediction(waveform):
25
  waveform = np.load(waveform)
 
127
  continue
128
 
129
  if len(waveform) == 3:
130
+ try:
131
+ waveform = prepare_waveform(np.stack([x.data for x in waveform]))
132
+ except:
133
+ continue
134
 
135
  distances.append(distance)
136
  t0s.append(starttime)
 
148
  p_phases = output[:, 0]
149
  s_phases = output[:, 1]
150
 
151
+ fig, ax = plt.subplots(nrows=1, figsize=(10, 3), sharex=True)
152
+ for i in range(len(waveforms)):
153
+ current_P = p_phases[i::len(waveforms)]
154
+ current_S = s_phases[i::len(waveforms)]
155
+ x = [t0s[i] + pd.Timedelta(seconds=k/100) for k in np.linspace(0,6000,6000)]
156
+ x = mdates.date2num(x)
157
+ ax.plot(x, waveforms[i][0, 0]+distances[i]*111.2, color='black', alpha=0.5)
158
+ ax.scatter(x[int(current_P.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='r')
159
+ ax.scatter(x[int(current_S.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='b')
160
+ ax.set_ylabel('Z')
161
+
162
+ ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
163
+ ax.xaxis.set_major_locator(mdates.SecondLocator(interval=10))
164
+
165
+ # for a in ax:
166
+ # a.axvline(current_P.mean()*waveforms[i][0].shape[-1], color='r', linestyle='--', label='P')
167
+ # a.axvline(current_S.mean()*waveforms[i][0].shape[-1], color='b', linestyle='--', label='S')
168
+
169
+ # ax[-1].set_xlabel('Time, samples')
170
+ # ax[-1].set_ylabel('Uncert.')
171
+ # ax[-1].legend()
172
+
173
+ plt.subplots_adjust(hspace=0., wspace=0.)
174
+
175
+ fig.canvas.draw();
176
  image = np.array(fig.canvas.renderer.buffer_rgba())
177
  plt.close(fig)
178
 
requirements.txt CHANGED
@@ -12,4 +12,5 @@ torchvision==0.15.1
12
  tqdm==4.65.0
13
  webdataset==0.2.48
14
  obspy
 
15
  git+http://github.com/nikitadurasov/masksembles
 
12
  tqdm==4.65.0
13
  webdataset==0.2.48
14
  obspy
15
+ pandas
16
  git+http://github.com/nikitadurasov/masksembles