crimeacs commited on
Commit
eeca930
·
1 Parent(s): 82bf5ac

Updated layout

Browse files
Files changed (5) hide show
  1. .DS_Store +0 -0
  2. Gradio_app.ipynb +140 -98
  3. app.py +85 -83
  4. phasehunter/model.py +0 -313
  5. phasehunter/training.py +0 -104
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
Gradio_app.ipynb CHANGED
@@ -2,29 +2,14 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 5,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "model = Onset_picker.load_from_checkpoint(\"./weights.ckpt\",\n",
10
- " picker=Updated_onset_picker(),\n",
11
- " learning_rate=3e-4)\n",
12
- "model.eval()\n",
13
- "model.freeze()\n",
14
- "script = model.to_torchscript()\n",
15
- "torch.jit.save(script, \"model.pt\")"
16
- ]
17
- },
18
- {
19
- "cell_type": "code",
20
- "execution_count": 32,
21
  "metadata": {},
22
  "outputs": [
23
  {
24
  "name": "stdout",
25
  "output_type": "stream",
26
  "text": [
27
- "Running on local URL: http://127.0.0.1:7878\n",
28
  "\n",
29
  "To create a public link, set `share=True` in `launch()`.\n"
30
  ]
@@ -32,7 +17,7 @@
32
  {
33
  "data": {
34
  "text/html": [
35
- "<div><iframe src=\"http://127.0.0.1:7878/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
36
  ],
37
  "text/plain": [
38
  "<IPython.core.display.HTML object>"
@@ -45,7 +30,7 @@
45
  "data": {
46
  "text/plain": []
47
  },
48
- "execution_count": 32,
49
  "metadata": {},
50
  "output_type": "execute_result"
51
  },
@@ -116,13 +101,69 @@
116
  "name": "stderr",
117
  "output_type": "stream",
118
  "text": [
119
- "/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/2440224661.py:224: FutureWarning: The input object of type 'Tensor' is an array-like implementing one of the corresponding protocols (`__array__`, `__array_interface__` or `__array_struct__`); but not a sequence (or 0-D). In the future, this object will be coerced as if it was first converted using `np.array(obj)`. To retain the old behaviour, you have to either modify the type 'Tensor', or assign to an empty array created with `np.empty(correct_shape, dtype=object)`.\n",
120
  " waveforms = np.array(waveforms)[selection_indexes]\n",
121
- "/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/2440224661.py:224: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
122
  " waveforms = np.array(waveforms)[selection_indexes]\n",
123
- "/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/2440224661.py:231: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
124
  " waveforms = [torch.tensor(waveform) for waveform in waveforms]\n"
125
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  }
127
  ],
128
  "source": [
@@ -149,7 +190,7 @@
149
  "\n",
150
  "import matplotlib.pyplot as plt\n",
151
  "import matplotlib.dates as mdates\n",
152
- "from matplotlib.colors import LightSource\n",
153
  "\n",
154
  "from glob import glob\n",
155
  "\n",
@@ -309,8 +350,8 @@
309
  " \n",
310
  " waveform = waveform.select(channel=\"H[BH][ZNE]\")\n",
311
  " waveform = waveform.merge(fill_value=0)\n",
312
- " waveform = waveform[:3]\n",
313
- " \n",
314
  " len_check = [len(x.data) for x in waveform]\n",
315
  " if len(set(len_check)) > 1:\n",
316
  " continue\n",
@@ -371,8 +412,8 @@
371
  " s_max_confidence = np.min([s_phases[i::len(waveforms)].std() for i in range(len(waveforms))])\n",
372
  "\n",
373
  " print(f\"Starting plotting {len(waveforms)} waveforms\")\n",
374
- " fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 3))\n",
375
- "\n",
376
  " # Plot topography\n",
377
  " print('Fetching topography')\n",
378
  " params = Topography.DEFAULT.copy()\n",
@@ -417,9 +458,6 @@
417
  " ax[0].scatter(x[int(current_S.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='b', alpha=s_conf, marker='|')\n",
418
  " ax[0].set_ylabel('Z')\n",
419
  "\n",
420
- " ax[0].xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))\n",
421
- " ax[0].xaxis.set_major_locator(mdates.SecondLocator(interval=20))\n",
422
- "\n",
423
  " delta_t = t0s[i].timestamp - obspy.UTCDateTime(timestamp).timestamp\n",
424
  "\n",
425
  " velocity_p = (distances[i]*111.2)/(delta_t+current_P.mean()*60).item()\n",
@@ -437,30 +475,37 @@
437
  " y = np.linspace(st_lats[i], eq_lat, 50)\n",
438
  " \n",
439
  " # Plot the array\n",
440
- " ax[1].scatter(x, y, c=np.zeros_like(x)+velocity_p, alpha=0.5, vmin=0, vmax=8)\n",
441
- " ax[2].scatter(x, y, c=np.zeros_like(x)+velocity_s, alpha=0.5, vmin=0, vmax=8)\n",
442
  "\n",
443
  " # Add legend\n",
444
  " ax[0].scatter(None, None, color='r', marker='|', label='P')\n",
445
  " ax[0].scatter(None, None, color='b', marker='|', label='S')\n",
 
 
446
  " ax[0].legend()\n",
447
  "\n",
448
  " print('Plotting stations')\n",
449
  " for i in range(1,3):\n",
450
  " ax[i].scatter(st_lons, st_lats, color='b', label='Stations')\n",
451
  " ax[i].scatter(eq_lon, eq_lat, color='r', marker='*', label='Earthquake')\n",
 
 
452
  "\n",
453
- " # Generate colorbar for the velocity plot\n",
454
- " cbar = plt.colorbar(ax[1].scatter(None, None, c=velocity_p, alpha=0.5, vmin=0, vmax=8), ax=ax[1])\n",
455
- " cbar.set_label('P Velocity (km/s)')\n",
456
- " ax[1].set_title('P Velocity')\n",
 
457
  "\n",
458
- " cbar = plt.colorbar(ax[2].scatter(None, None, c=velocity_s, alpha=0.5, vmin=0, vmax=8), ax=ax[2])\n",
459
- " cbar.set_label('S Velocity (km/s)')\n",
460
  " ax[2].set_title('S Velocity')\n",
461
  "\n",
 
 
 
462
  " plt.subplots_adjust(hspace=0., wspace=0.5)\n",
463
- "\n",
464
  " fig.canvas.draw();\n",
465
  " image = np.array(fig.canvas.renderer.buffer_rgba())\n",
466
  " plt.close(fig)\n",
@@ -482,7 +527,6 @@
482
  " }\n",
483
  "</style></h1> \n",
484
  " \n",
485
- "\n",
486
  " <p style=\"font-size: 16px; margin-bottom: 20px;\">Detect <span style=\"background-image: linear-gradient(to right, #ED213A, #93291E); \n",
487
  " -webkit-background-clip: text;\n",
488
  " -webkit-text-fill-color: transparent;\n",
@@ -531,68 +575,66 @@
531
  " </div>\n",
532
  " \"\"\")\n",
533
  " with gr.Row(): \n",
534
- " client_inputs = gr.Dropdown(\n",
535
- " choices = list(URL_MAPPINGS.keys()), \n",
536
- " label=\"FDSN Client\", \n",
537
- " info=\"Select one of the available FDSN clients\",\n",
538
- " value = \"IRIS\",\n",
539
- " interactive=True\n",
540
- " )\n",
541
- "\n",
542
- " velocity_inputs = gr.Dropdown(\n",
543
- " choices = ['1066a', '1066b', 'ak135', \n",
544
- " 'ak135f', 'herrin', 'iasp91', \n",
545
- " 'jb', 'prem', 'pwdk'], \n",
546
- " label=\"1D velocity model\", \n",
547
- " info=\"Velocity model for station selection\",\n",
548
- " value = \"1066a\",\n",
549
- " interactive=True\n",
550
- " )\n",
 
551
  "\n",
552
- " with gr.Column(scale=4):\n",
553
- " with gr.Row(): \n",
554
- " timestamp_inputs = gr.Textbox(value='2019-07-04 17:33:49',\n",
555
- " placeholder='YYYY-MM-DD HH:MM:SS',\n",
556
- " label=\"Timestamp\",\n",
557
- " info=\"Timestamp of the earthquake\",\n",
558
- " max_lines=1,\n",
559
- " interactive=True)\n",
560
- " \n",
561
- " eq_lat_inputs = gr.Number(value=35.766, \n",
562
- " label=\"Latitude\", \n",
563
- " info=\"Latitude of the earthquake\",\n",
 
 
 
 
 
 
 
 
 
 
564
  " interactive=True)\n",
565
- " \n",
566
- " eq_lon_inputs = gr.Number(value=-117.605,\n",
567
- " label=\"Longitude\",\n",
568
- " info=\"Longitude of the earthquake\",\n",
569
- " interactive=True)\n",
570
- " \n",
571
- " source_depth_inputs = gr.Number(value=10,\n",
572
- " label=\"Source depth (km)\",\n",
573
- " info=\"Depth of the earthquake\",\n",
574
- " interactive=True)\n",
575
  " \n",
576
- "\n",
577
- " \n",
578
  " with gr.Column(scale=2):\n",
579
- " with gr.Row(): \n",
580
- " radius_inputs = gr.Slider(minimum=1, \n",
581
- " maximum=150, \n",
582
- " value=50, label=\"Radius (km)\", \n",
583
- " step=10,\n",
584
- " info=\"\"\"Select the radius around the earthquake to download data from.\\n \n",
585
- " Note that the larger the radius, the longer the app will take to run.\"\"\",\n",
586
- " interactive=True)\n",
587
- " \n",
588
- " max_waveforms_inputs = gr.Slider(minimum=1,\n",
589
- " maximum=100,\n",
590
- " value=10,\n",
591
- " label=\"Max waveforms per section\",\n",
592
- " step=1,\n",
593
- " info=\"Maximum number of waveforms to show per section\\n (to avoid long prediction times)\",\n",
594
- " interactive=True,\n",
595
- " )\n",
596
  " \n",
597
  " button = gr.Button(\"Predict phases\")\n",
598
  " output_image = gr.Image(label='Waveforms with Phases Marked', type='numpy', interactive=False)\n",
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 51,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  "metadata": {},
7
  "outputs": [
8
  {
9
  "name": "stdout",
10
  "output_type": "stream",
11
  "text": [
12
+ "Running on local URL: http://127.0.0.1:7897\n",
13
  "\n",
14
  "To create a public link, set `share=True` in `launch()`.\n"
15
  ]
 
17
  {
18
  "data": {
19
  "text/html": [
20
+ "<div><iframe src=\"http://127.0.0.1:7897/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
21
  ],
22
  "text/plain": [
23
  "<IPython.core.display.HTML object>"
 
30
  "data": {
31
  "text/plain": []
32
  },
33
+ "execution_count": 51,
34
  "metadata": {},
35
  "output_type": "execute_result"
36
  },
 
101
  "name": "stderr",
102
  "output_type": "stream",
103
  "text": [
104
+ "/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/1938231065.py:224: FutureWarning: The input object of type 'Tensor' is an array-like implementing one of the corresponding protocols (`__array__`, `__array_interface__` or `__array_struct__`); but not a sequence (or 0-D). In the future, this object will be coerced as if it was first converted using `np.array(obj)`. To retain the old behaviour, you have to either modify the type 'Tensor', or assign to an empty array created with `np.empty(correct_shape, dtype=object)`.\n",
105
  " waveforms = np.array(waveforms)[selection_indexes]\n",
106
+ "/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/1938231065.py:224: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
107
  " waveforms = np.array(waveforms)[selection_indexes]\n",
108
+ "/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/1938231065.py:231: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
109
  " waveforms = [torch.tensor(waveform) for waveform in waveforms]\n"
110
  ]
111
+ },
112
+ {
113
+ "name": "stdout",
114
+ "output_type": "stream",
115
+ "text": [
116
+ "Starting plotting 3 waveforms\n",
117
+ "Fetching topography\n",
118
+ "Plotting topo\n"
119
+ ]
120
+ },
121
+ {
122
+ "name": "stderr",
123
+ "output_type": "stream",
124
+ "text": [
125
+ "/Users/anovosel/miniconda3/envs/phasehunter/lib/python3.11/site-packages/bmi_topography/api_key.py:49: UserWarning: You are using a demo key to fetch data from OpenTopography, functionality will be limited. See https://bmi-topography.readthedocs.io/en/latest/#api-key for more information.\n",
126
+ " warnings.warn(\n"
127
+ ]
128
+ },
129
+ {
130
+ "name": "stdout",
131
+ "output_type": "stream",
132
+ "text": [
133
+ "Plotting waveform 1/3\n",
134
+ "Station 36.11758, -117.85486 has P velocity 4.987805380766392 and S velocity 2.9782985042350987\n",
135
+ "Plotting waveform 2/3\n"
136
+ ]
137
+ },
138
+ {
139
+ "name": "stderr",
140
+ "output_type": "stream",
141
+ "text": [
142
+ "/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/1938231065.py:299: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.\n",
143
+ " output_picks = output_picks.append(pd.DataFrame({'station_name': [names[i]], 'starttime' : [str(t0s[i])],\n",
144
+ "/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/1938231065.py:299: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.\n",
145
+ " output_picks = output_picks.append(pd.DataFrame({'station_name': [names[i]], 'starttime' : [str(t0s[i])],\n"
146
+ ]
147
+ },
148
+ {
149
+ "name": "stdout",
150
+ "output_type": "stream",
151
+ "text": [
152
+ "Station 35.98249, -117.80885 has P velocity 4.255522557803516 and S velocity 2.2929437916670583\n",
153
+ "Plotting waveform 3/3\n",
154
+ "Station 35.69235, -117.75051 has P velocity 2.979034174961547 and S velocity 1.3728192788753049\n",
155
+ "Plotting stations\n"
156
+ ]
157
+ },
158
+ {
159
+ "name": "stderr",
160
+ "output_type": "stream",
161
+ "text": [
162
+ "/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/1938231065.py:299: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. Use pandas.concat instead.\n",
163
+ " output_picks = output_picks.append(pd.DataFrame({'station_name': [names[i]], 'starttime' : [str(t0s[i])],\n",
164
+ "/var/folders/_g/3q5q8_dj0ydcpktxlwxb5vrh0000gq/T/ipykernel_27324/1938231065.py:324: UserWarning: FixedFormatter should only be used together with FixedLocator\n",
165
+ " ax[i].set_xticklabels(ax[i].get_xticks(), rotation = 50)\n"
166
+ ]
167
  }
168
  ],
169
  "source": [
 
190
  "\n",
191
  "import matplotlib.pyplot as plt\n",
192
  "import matplotlib.dates as mdates\n",
193
+ "from mpl_toolkits.axes_grid1 import ImageGrid\n",
194
  "\n",
195
  "from glob import glob\n",
196
  "\n",
 
350
  " \n",
351
  " waveform = waveform.select(channel=\"H[BH][ZNE]\")\n",
352
  " waveform = waveform.merge(fill_value=0)\n",
353
+ " waveform = waveform[:3].sort(keys=['channel'], reverse=True)\n",
354
+ "\n",
355
  " len_check = [len(x.data) for x in waveform]\n",
356
  " if len(set(len_check)) > 1:\n",
357
  " continue\n",
 
412
  " s_max_confidence = np.min([s_phases[i::len(waveforms)].std() for i in range(len(waveforms))])\n",
413
  "\n",
414
  " print(f\"Starting plotting {len(waveforms)} waveforms\")\n",
415
+ " fig, ax = plt.subplots(ncols=3, figsize=(10, 3))\n",
416
+ " \n",
417
  " # Plot topography\n",
418
  " print('Fetching topography')\n",
419
  " params = Topography.DEFAULT.copy()\n",
 
458
  " ax[0].scatter(x[int(current_S.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='b', alpha=s_conf, marker='|')\n",
459
  " ax[0].set_ylabel('Z')\n",
460
  "\n",
 
 
 
461
  " delta_t = t0s[i].timestamp - obspy.UTCDateTime(timestamp).timestamp\n",
462
  "\n",
463
  " velocity_p = (distances[i]*111.2)/(delta_t+current_P.mean()*60).item()\n",
 
475
  " y = np.linspace(st_lats[i], eq_lat, 50)\n",
476
  " \n",
477
  " # Plot the array\n",
478
+ " ax[1].scatter(x, y, c=np.zeros_like(x)+velocity_p, alpha=0.1, vmin=0, vmax=8)\n",
479
+ " ax[2].scatter(x, y, c=np.zeros_like(x)+velocity_s, alpha=0.1, vmin=0, vmax=8)\n",
480
  "\n",
481
  " # Add legend\n",
482
  " ax[0].scatter(None, None, color='r', marker='|', label='P')\n",
483
  " ax[0].scatter(None, None, color='b', marker='|', label='S')\n",
484
+ " ax[0].xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))\n",
485
+ " ax[0].xaxis.set_major_locator(mdates.SecondLocator(interval=20))\n",
486
  " ax[0].legend()\n",
487
  "\n",
488
  " print('Plotting stations')\n",
489
  " for i in range(1,3):\n",
490
  " ax[i].scatter(st_lons, st_lats, color='b', label='Stations')\n",
491
  " ax[i].scatter(eq_lon, eq_lat, color='r', marker='*', label='Earthquake')\n",
492
+ " ax[i].set_aspect('equal')\n",
493
+ " ax[i].set_xticklabels(ax[i].get_xticks(), rotation = 50)\n",
494
  "\n",
495
+ " fig.subplots_adjust(bottom=0.1, top=0.9, left=0.1, right=0.8,\n",
496
+ " wspace=0.02, hspace=0.02)\n",
497
+ " \n",
498
+ " cb_ax = fig.add_axes([0.83, 0.1, 0.02, 0.8])\n",
499
+ " cbar = fig.colorbar(ax[2].scatter(None, None, c=velocity_p, alpha=0.5, vmin=0, vmax=8), cax=cb_ax)\n",
500
  "\n",
501
+ " cbar.set_label('Velocity (km/s)')\n",
502
+ " ax[1].set_title('P Velocity')\n",
503
  " ax[2].set_title('S Velocity')\n",
504
  "\n",
505
+ " for a in ax:\n",
506
+ " a.tick_params(axis='both', which='major', labelsize=8)\n",
507
+ " \n",
508
  " plt.subplots_adjust(hspace=0., wspace=0.5)\n",
 
509
  " fig.canvas.draw();\n",
510
  " image = np.array(fig.canvas.renderer.buffer_rgba())\n",
511
  " plt.close(fig)\n",
 
527
  " }\n",
528
  "</style></h1> \n",
529
  " \n",
 
530
  " <p style=\"font-size: 16px; margin-bottom: 20px;\">Detect <span style=\"background-image: linear-gradient(to right, #ED213A, #93291E); \n",
531
  " -webkit-background-clip: text;\n",
532
  " -webkit-text-fill-color: transparent;\n",
 
575
  " </div>\n",
576
  " \"\"\")\n",
577
  " with gr.Row(): \n",
578
+ " with gr.Column(scale=2):\n",
579
+ " client_inputs = gr.Dropdown(\n",
580
+ " choices = list(URL_MAPPINGS.keys()), \n",
581
+ " label=\"FDSN Client\", \n",
582
+ " info=\"Select one of the available FDSN clients\",\n",
583
+ " value = \"IRIS\",\n",
584
+ " interactive=True\n",
585
+ " )\n",
586
+ "\n",
587
+ " velocity_inputs = gr.Dropdown(\n",
588
+ " choices = ['1066a', '1066b', 'ak135', \n",
589
+ " 'ak135f', 'herrin', 'iasp91', \n",
590
+ " 'jb', 'prem', 'pwdk'], \n",
591
+ " label=\"1D velocity model\", \n",
592
+ " info=\"Velocity model for station selection\",\n",
593
+ " value = \"1066a\",\n",
594
+ " interactive=True\n",
595
+ " )\n",
596
  "\n",
597
+ " with gr.Column(scale=2):\n",
598
+ " timestamp_inputs = gr.Textbox(value='2019-07-04 17:33:49',\n",
599
+ " placeholder='YYYY-MM-DD HH:MM:SS',\n",
600
+ " label=\"Timestamp\",\n",
601
+ " info=\"Timestamp of the earthquake\",\n",
602
+ " max_lines=1,\n",
603
+ " interactive=True)\n",
604
+ " \n",
605
+ " source_depth_inputs = gr.Number(value=10,\n",
606
+ " label=\"Source depth (km)\",\n",
607
+ " info=\"Depth of the earthquake\",\n",
608
+ " interactive=True)\n",
609
+ " \n",
610
+ " with gr.Column(scale=2):\n",
611
+ " eq_lat_inputs = gr.Number(value=35.766, \n",
612
+ " label=\"Latitude\", \n",
613
+ " info=\"Latitude of the earthquake\",\n",
614
+ " interactive=True)\n",
615
+ " \n",
616
+ " eq_lon_inputs = gr.Number(value=-117.605,\n",
617
+ " label=\"Longitude\",\n",
618
+ " info=\"Longitude of the earthquake\",\n",
619
  " interactive=True)\n",
 
 
 
 
 
 
 
 
 
 
620
  " \n",
 
 
621
  " with gr.Column(scale=2):\n",
622
+ " radius_inputs = gr.Slider(minimum=1, \n",
623
+ " maximum=200, \n",
624
+ " value=50, label=\"Radius (km)\", \n",
625
+ " step=10,\n",
626
+ " info=\"\"\"Select the radius around the earthquake to download data from.\\n \n",
627
+ " Note that the larger the radius, the longer the app will take to run.\"\"\",\n",
628
+ " interactive=True)\n",
629
+ " \n",
630
+ " max_waveforms_inputs = gr.Slider(minimum=1,\n",
631
+ " maximum=100,\n",
632
+ " value=10,\n",
633
+ " label=\"Max waveforms per section\",\n",
634
+ " step=1,\n",
635
+ " info=\"Maximum number of waveforms to show per section\\n (to avoid long prediction times)\",\n",
636
+ " interactive=True,\n",
637
+ " )\n",
 
638
  " \n",
639
  " button = gr.Button(\"Predict phases\")\n",
640
  " output_image = gr.Image(label='Waveforms with Phases Marked', type='numpy', interactive=False)\n",
app.py CHANGED
@@ -21,7 +21,7 @@ from obspy.clients.fdsn.header import URL_MAPPINGS
21
 
22
  import matplotlib.pyplot as plt
23
  import matplotlib.dates as mdates
24
- from matplotlib.colors import LightSource
25
 
26
  from glob import glob
27
 
@@ -181,8 +181,8 @@ def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source
181
 
182
  waveform = waveform.select(channel="H[BH][ZNE]")
183
  waveform = waveform.merge(fill_value=0)
184
- waveform = waveform[:3]
185
-
186
  len_check = [len(x.data) for x in waveform]
187
  if len(set(len_check)) > 1:
188
  continue
@@ -243,8 +243,8 @@ def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source
243
  s_max_confidence = np.min([s_phases[i::len(waveforms)].std() for i in range(len(waveforms))])
244
 
245
  print(f"Starting plotting {len(waveforms)} waveforms")
246
- fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 3))
247
-
248
  # Plot topography
249
  print('Fetching topography')
250
  params = Topography.DEFAULT.copy()
@@ -289,9 +289,6 @@ def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source
289
  ax[0].scatter(x[int(current_S.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='b', alpha=s_conf, marker='|')
290
  ax[0].set_ylabel('Z')
291
 
292
- ax[0].xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
293
- ax[0].xaxis.set_major_locator(mdates.SecondLocator(interval=20))
294
-
295
  delta_t = t0s[i].timestamp - obspy.UTCDateTime(timestamp).timestamp
296
 
297
  velocity_p = (distances[i]*111.2)/(delta_t+current_P.mean()*60).item()
@@ -309,30 +306,37 @@ def predict_on_section(client_name, timestamp, eq_lat, eq_lon, radius_km, source
309
  y = np.linspace(st_lats[i], eq_lat, 50)
310
 
311
  # Plot the array
312
- ax[1].scatter(x, y, c=np.zeros_like(x)+velocity_p, alpha=0.5, vmin=0, vmax=8)
313
- ax[2].scatter(x, y, c=np.zeros_like(x)+velocity_s, alpha=0.5, vmin=0, vmax=8)
314
 
315
  # Add legend
316
  ax[0].scatter(None, None, color='r', marker='|', label='P')
317
  ax[0].scatter(None, None, color='b', marker='|', label='S')
 
 
318
  ax[0].legend()
319
 
320
  print('Plotting stations')
321
  for i in range(1,3):
322
  ax[i].scatter(st_lons, st_lats, color='b', label='Stations')
323
  ax[i].scatter(eq_lon, eq_lat, color='r', marker='*', label='Earthquake')
 
 
324
 
325
- # Generate colorbar for the velocity plot
326
- cbar = plt.colorbar(ax[1].scatter(None, None, c=velocity_p, alpha=0.5, vmin=0, vmax=8), ax=ax[1])
327
- cbar.set_label('P Velocity (km/s)')
328
- ax[1].set_title('P Velocity')
 
329
 
330
- cbar = plt.colorbar(ax[2].scatter(None, None, c=velocity_s, alpha=0.5, vmin=0, vmax=8), ax=ax[2])
331
- cbar.set_label('S Velocity (km/s)')
332
  ax[2].set_title('S Velocity')
333
 
 
 
 
334
  plt.subplots_adjust(hspace=0., wspace=0.5)
335
-
336
  fig.canvas.draw();
337
  image = np.array(fig.canvas.renderer.buffer_rgba())
338
  plt.close(fig)
@@ -354,7 +358,6 @@ with gr.Blocks() as demo:
354
  }
355
  </style></h1>
356
 
357
-
358
  <p style="font-size: 16px; margin-bottom: 20px;">Detect <span style="background-image: linear-gradient(to right, #ED213A, #93291E);
359
  -webkit-background-clip: text;
360
  -webkit-text-fill-color: transparent;
@@ -393,77 +396,76 @@ with gr.Blocks() as demo:
393
  button.click(mark_phases, inputs=[inputs, upload], outputs=outputs)
394
 
395
  with gr.Tab("Select earthquake from catalogue"):
396
- gr.Markdown("""
397
- Select an earthquake from the global earthquake catalogue and the app will download the waveform from the FDSN client of your choice.
398
- The app will use a velocity model of your choice to select appropriate time windows for each station within specify radius of the earthquake.
399
- The app will then analyze the waveforms and mark the detected phases on the waveform.
400
- Pick data for each waveform is reported in seconds from the start of the waveform.
401
- Velocities are derived from distance and travel time determined by PhaseHunter picks ($v = \mathrm{distance}/\mathrm{predicted_pick_time}$).
402
- Backround of velocity plot is colored by DEM.
 
403
  """)
404
  with gr.Row():
405
- client_inputs = gr.Dropdown(
406
- choices = list(URL_MAPPINGS.keys()),
407
- label="FDSN Client",
408
- info="Select one of the available FDSN clients",
409
- value = "IRIS",
410
- interactive=True
411
- )
412
-
413
- velocity_inputs = gr.Dropdown(
414
- choices = ['1066a', '1066b', 'ak135',
415
- 'ak135f', 'herrin', 'iasp91',
416
- 'jb', 'prem', 'pwdk'],
417
- label="1D velocity model",
418
- info="Velocity model for station selection",
419
- value = "1066a",
420
- interactive=True
421
- )
 
422
 
423
- with gr.Column(scale=4):
424
- with gr.Row():
425
- timestamp_inputs = gr.Textbox(value='2019-07-04 17:33:49',
426
- placeholder='YYYY-MM-DD HH:MM:SS',
427
- label="Timestamp",
428
- info="Timestamp of the earthquake",
429
- max_lines=1,
430
- interactive=True)
431
-
432
- eq_lat_inputs = gr.Number(value=35.766,
433
- label="Latitude",
434
- info="Latitude of the earthquake",
435
  interactive=True)
436
-
437
- eq_lon_inputs = gr.Number(value=-117.605,
438
- label="Longitude",
439
- info="Longitude of the earthquake",
440
- interactive=True)
441
-
442
- source_depth_inputs = gr.Number(value=10,
443
- label="Source depth (km)",
444
- info="Depth of the earthquake",
445
- interactive=True)
446
 
447
-
448
-
 
 
 
449
  with gr.Column(scale=2):
450
- with gr.Row():
451
- radius_inputs = gr.Slider(minimum=1,
452
- maximum=150,
453
- value=50, label="Radius (km)",
454
- step=10,
455
- info="""Select the radius around the earthquake to download data from.\n
456
- Note that the larger the radius, the longer the app will take to run.""",
457
- interactive=True)
458
-
459
- max_waveforms_inputs = gr.Slider(minimum=1,
460
- maximum=100,
461
- value=10,
462
- label="Max waveforms per section",
463
- step=1,
464
- info="Maximum number of waveforms to show per section\n (to avoid long prediction times)",
465
- interactive=True,
466
- )
 
 
 
 
 
 
 
 
 
 
467
 
468
  button = gr.Button("Predict phases")
469
  output_image = gr.Image(label='Waveforms with Phases Marked', type='numpy', interactive=False)
 
21
 
22
  import matplotlib.pyplot as plt
23
  import matplotlib.dates as mdates
24
+ from mpl_toolkits.axes_grid1 import ImageGrid
25
 
26
  from glob import glob
27
 
 
181
 
182
  waveform = waveform.select(channel="H[BH][ZNE]")
183
  waveform = waveform.merge(fill_value=0)
184
+ waveform = waveform[:3].sort(keys=['channel'], reverse=True)
185
+
186
  len_check = [len(x.data) for x in waveform]
187
  if len(set(len_check)) > 1:
188
  continue
 
243
  s_max_confidence = np.min([s_phases[i::len(waveforms)].std() for i in range(len(waveforms))])
244
 
245
  print(f"Starting plotting {len(waveforms)} waveforms")
246
+ fig, ax = plt.subplots(ncols=3, figsize=(10, 3))
247
+
248
  # Plot topography
249
  print('Fetching topography')
250
  params = Topography.DEFAULT.copy()
 
289
  ax[0].scatter(x[int(current_S.mean()*waveforms[i][0].shape[-1])], waveforms[i][0, 0].mean()+distances[i]*111.2, color='b', alpha=s_conf, marker='|')
290
  ax[0].set_ylabel('Z')
291
 
 
 
 
292
  delta_t = t0s[i].timestamp - obspy.UTCDateTime(timestamp).timestamp
293
 
294
  velocity_p = (distances[i]*111.2)/(delta_t+current_P.mean()*60).item()
 
306
  y = np.linspace(st_lats[i], eq_lat, 50)
307
 
308
  # Plot the array
309
+ ax[1].scatter(x, y, c=np.zeros_like(x)+velocity_p, alpha=0.1, vmin=0, vmax=8)
310
+ ax[2].scatter(x, y, c=np.zeros_like(x)+velocity_s, alpha=0.1, vmin=0, vmax=8)
311
 
312
  # Add legend
313
  ax[0].scatter(None, None, color='r', marker='|', label='P')
314
  ax[0].scatter(None, None, color='b', marker='|', label='S')
315
+ ax[0].xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
316
+ ax[0].xaxis.set_major_locator(mdates.SecondLocator(interval=20))
317
  ax[0].legend()
318
 
319
  print('Plotting stations')
320
  for i in range(1,3):
321
  ax[i].scatter(st_lons, st_lats, color='b', label='Stations')
322
  ax[i].scatter(eq_lon, eq_lat, color='r', marker='*', label='Earthquake')
323
+ ax[i].set_aspect('equal')
324
+ ax[i].set_xticklabels(ax[i].get_xticks(), rotation = 50)
325
 
326
+ fig.subplots_adjust(bottom=0.1, top=0.9, left=0.1, right=0.8,
327
+ wspace=0.02, hspace=0.02)
328
+
329
+ cb_ax = fig.add_axes([0.83, 0.1, 0.02, 0.8])
330
+ cbar = fig.colorbar(ax[2].scatter(None, None, c=velocity_p, alpha=0.5, vmin=0, vmax=8), cax=cb_ax)
331
 
332
+ cbar.set_label('Velocity (km/s)')
333
+ ax[1].set_title('P Velocity')
334
  ax[2].set_title('S Velocity')
335
 
336
+ for a in ax:
337
+ a.tick_params(axis='both', which='major', labelsize=8)
338
+
339
  plt.subplots_adjust(hspace=0., wspace=0.5)
 
340
  fig.canvas.draw();
341
  image = np.array(fig.canvas.renderer.buffer_rgba())
342
  plt.close(fig)
 
358
  }
359
  </style></h1>
360
 
 
361
  <p style="font-size: 16px; margin-bottom: 20px;">Detect <span style="background-image: linear-gradient(to right, #ED213A, #93291E);
362
  -webkit-background-clip: text;
363
  -webkit-text-fill-color: transparent;
 
396
  button.click(mark_phases, inputs=[inputs, upload], outputs=outputs)
397
 
398
  with gr.Tab("Select earthquake from catalogue"):
399
+
400
+ gr.HTML("""
401
+ <div style="padding: 20px; border-radius: 10px; font-size: 16px;">
402
+ <p style="font-weight: bold; font-size: 24px; margin-bottom: 20px;">Using PhaseHunter to Analyze Seismic Waveforms</p>
403
+ <p>Select an earthquake from the global earthquake catalogue and the app will download the waveform from the FDSN client of your choice. The app will use a velocity model of your choice to select appropriate time windows for each station within a specified radius of the earthquake.</p>
404
+ <p>The app will then analyze the waveforms and mark the detected phases on the waveform. Pick data for each waveform is reported in seconds from the start of the waveform.</p>
405
+ <p>Velocities are derived from distance and travel time determined by PhaseHunter picks (<span style="font-style: italic;">v = distance/predicted_pick_time</span>). The background of the velocity plot is colored by DEM.</p>
406
+ </div>
407
  """)
408
  with gr.Row():
409
+ with gr.Column(scale=2):
410
+ client_inputs = gr.Dropdown(
411
+ choices = list(URL_MAPPINGS.keys()),
412
+ label="FDSN Client",
413
+ info="Select one of the available FDSN clients",
414
+ value = "IRIS",
415
+ interactive=True
416
+ )
417
+
418
+ velocity_inputs = gr.Dropdown(
419
+ choices = ['1066a', '1066b', 'ak135',
420
+ 'ak135f', 'herrin', 'iasp91',
421
+ 'jb', 'prem', 'pwdk'],
422
+ label="1D velocity model",
423
+ info="Velocity model for station selection",
424
+ value = "1066a",
425
+ interactive=True
426
+ )
427
 
428
+ with gr.Column(scale=2):
429
+ timestamp_inputs = gr.Textbox(value='2019-07-04 17:33:49',
430
+ placeholder='YYYY-MM-DD HH:MM:SS',
431
+ label="Timestamp",
432
+ info="Timestamp of the earthquake",
433
+ max_lines=1,
 
 
 
 
 
 
434
  interactive=True)
 
 
 
 
 
 
 
 
 
 
435
 
436
+ source_depth_inputs = gr.Number(value=10,
437
+ label="Source depth (km)",
438
+ info="Depth of the earthquake",
439
+ interactive=True)
440
+
441
  with gr.Column(scale=2):
442
+ eq_lat_inputs = gr.Number(value=35.766,
443
+ label="Latitude",
444
+ info="Latitude of the earthquake",
445
+ interactive=True)
446
+
447
+ eq_lon_inputs = gr.Number(value=-117.605,
448
+ label="Longitude",
449
+ info="Longitude of the earthquake",
450
+ interactive=True)
451
+
452
+ with gr.Column(scale=2):
453
+ radius_inputs = gr.Slider(minimum=1,
454
+ maximum=200,
455
+ value=50, label="Radius (km)",
456
+ step=10,
457
+ info="""Select the radius around the earthquake to download data from.\n
458
+ Note that the larger the radius, the longer the app will take to run.""",
459
+ interactive=True)
460
+
461
+ max_waveforms_inputs = gr.Slider(minimum=1,
462
+ maximum=100,
463
+ value=10,
464
+ label="Max waveforms per section",
465
+ step=1,
466
+ info="Maximum number of waveforms to show per section\n (to avoid long prediction times)",
467
+ interactive=True,
468
+ )
469
 
470
  button = gr.Button("Predict phases")
471
  output_image = gr.Image(label='Waveforms with Phases Marked', type='numpy', interactive=False)
phasehunter/model.py DELETED
@@ -1,313 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn.functional as F
4
- from torch import nn
5
- from torchmetrics import MeanAbsoluteError
6
- from torch.optim.lr_scheduler import ReduceLROnPlateau
7
-
8
- import lightning as pl
9
-
10
- class BlurPool1D(nn.Module):
11
- def __init__(self, channels, pad_type="reflect", filt_size=3, stride=2, pad_off=0):
12
- super(BlurPool1D, self).__init__()
13
- self.filt_size = filt_size
14
- self.pad_off = pad_off
15
- self.pad_sizes = [
16
- int(1.0 * (filt_size - 1) / 2),
17
- int(np.ceil(1.0 * (filt_size - 1) / 2)),
18
- ]
19
- self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
20
- self.stride = stride
21
- self.off = int((self.stride - 1) / 2.0)
22
- self.channels = channels
23
-
24
- # print('Filter size [%i]' % filt_size)
25
- if self.filt_size == 1:
26
- a = np.array(
27
- [
28
- 1.0,
29
- ]
30
- )
31
- elif self.filt_size == 2:
32
- a = np.array([1.0, 1.0])
33
- elif self.filt_size == 3:
34
- a = np.array([1.0, 2.0, 1.0])
35
- elif self.filt_size == 4:
36
- a = np.array([1.0, 3.0, 3.0, 1.0])
37
- elif self.filt_size == 5:
38
- a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
39
- elif self.filt_size == 6:
40
- a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
41
- elif self.filt_size == 7:
42
- a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])
43
-
44
- filt = torch.Tensor(a)
45
- filt = filt / torch.sum(filt)
46
- self.register_buffer("filt", filt[None, None, :].repeat((self.channels, 1, 1)))
47
-
48
- self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)
49
-
50
- def forward(self, inp):
51
- if self.filt_size == 1:
52
- if self.pad_off == 0:
53
- return inp[:, :, :: self.stride]
54
- else:
55
- return self.pad(inp)[:, :, :: self.stride]
56
- else:
57
- return F.conv1d(
58
- self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]
59
- )
60
-
61
-
62
- def get_pad_layer_1d(pad_type):
63
- if pad_type in ["refl", "reflect"]:
64
- PadLayer = nn.ReflectionPad1d
65
- elif pad_type in ["repl", "replicate"]:
66
- PadLayer = nn.ReplicationPad1d
67
- elif pad_type == "zero":
68
- PadLayer = nn.ZeroPad1d
69
- else:
70
- print("Pad type [%s] not recognized" % pad_type)
71
- return PadLayer
72
-
73
-
74
- from masksembles import common
75
-
76
-
77
- class Masksembles1D(nn.Module):
78
- def __init__(self, channels: int, n: int, scale: float):
79
- super().__init__()
80
-
81
- self.channels = channels
82
- self.n = n
83
- self.scale = scale
84
-
85
- masks = common.generation_wrapper(channels, n, scale)
86
- masks = torch.from_numpy(masks)
87
-
88
- self.masks = torch.nn.Parameter(masks, requires_grad=False)
89
-
90
- def forward(self, inputs):
91
- batch = inputs.shape[0]
92
- x = torch.split(inputs.unsqueeze(1), batch // self.n, dim=0)
93
- x = torch.cat(x, dim=1).permute([1, 0, 2, 3])
94
- x = x * self.masks.unsqueeze(1).unsqueeze(-1)
95
- x = torch.cat(torch.split(x, 1, dim=0), dim=1)
96
-
97
- return x.squeeze(0).type(inputs.dtype)
98
-
99
-
100
- class BasicBlock(nn.Module):
101
- expansion = 1
102
-
103
- def __init__(self, in_planes, planes, stride=1, kernel_size=7, groups=1):
104
- super(BasicBlock, self).__init__()
105
- self.conv1 = nn.Conv1d(
106
- in_planes,
107
- planes,
108
- kernel_size=kernel_size,
109
- stride=stride,
110
- padding="same",
111
- bias=False,
112
- )
113
- self.bn1 = nn.BatchNorm1d(planes)
114
- self.conv2 = nn.Conv1d(
115
- planes,
116
- planes,
117
- kernel_size=kernel_size,
118
- stride=1,
119
- padding="same",
120
- bias=False,
121
- )
122
- self.bn2 = nn.BatchNorm1d(planes)
123
-
124
- self.shortcut = nn.Sequential(
125
- nn.Conv1d(
126
- in_planes,
127
- self.expansion * planes,
128
- kernel_size=1,
129
- stride=stride,
130
- padding="same",
131
- bias=False,
132
- ),
133
- nn.BatchNorm1d(self.expansion * planes),
134
- )
135
-
136
- def forward(self, x):
137
- out = F.relu(self.bn1(self.conv1(x)))
138
- out = self.bn2(self.conv2(out))
139
- out += self.shortcut(x)
140
- out = F.relu(out)
141
- return out
142
-
143
-
144
- class Updated_onset_picker(nn.Module):
145
- def __init__(
146
- self,
147
- ):
148
- super().__init__()
149
-
150
- # self.activation = nn.ReLU()
151
- # self.maxpool = nn.MaxPool1d(2)
152
-
153
- self.n_masks = 128
154
-
155
- self.block1 = nn.Sequential(
156
- BasicBlock(3, 8, kernel_size=7, groups=1),
157
- nn.GELU(),
158
- BlurPool1D(8, filt_size=3, stride=2),
159
- nn.GroupNorm(2, 8),
160
- )
161
-
162
- self.block2 = nn.Sequential(
163
- BasicBlock(8, 16, kernel_size=7, groups=8),
164
- nn.GELU(),
165
- BlurPool1D(16, filt_size=3, stride=2),
166
- nn.GroupNorm(2, 16),
167
- )
168
-
169
- self.block3 = nn.Sequential(
170
- BasicBlock(16, 32, kernel_size=7, groups=16),
171
- nn.GELU(),
172
- BlurPool1D(32, filt_size=3, stride=2),
173
- nn.GroupNorm(2, 32),
174
- )
175
-
176
- self.block4 = nn.Sequential(
177
- BasicBlock(32, 64, kernel_size=7, groups=32),
178
- nn.GELU(),
179
- BlurPool1D(64, filt_size=3, stride=2),
180
- nn.GroupNorm(2, 64),
181
- )
182
-
183
- self.block5 = nn.Sequential(
184
- BasicBlock(64, 128, kernel_size=7, groups=64),
185
- nn.GELU(),
186
- BlurPool1D(128, filt_size=3, stride=2),
187
- nn.GroupNorm(2, 128),
188
- )
189
-
190
- self.block6 = nn.Sequential(
191
- Masksembles1D(128, self.n_masks, 2.0),
192
- BasicBlock(128, 256, kernel_size=7, groups=128),
193
- nn.GELU(),
194
- BlurPool1D(256, filt_size=3, stride=2),
195
- nn.GroupNorm(2, 256),
196
- )
197
-
198
- self.block7 = nn.Sequential(
199
- Masksembles1D(256, self.n_masks, 2.0),
200
- BasicBlock(256, 512, kernel_size=7, groups=256),
201
- BlurPool1D(512, filt_size=3, stride=2),
202
- nn.GELU(),
203
- nn.GroupNorm(2, 512),
204
- )
205
-
206
- self.block8 = nn.Sequential(
207
- Masksembles1D(512, self.n_masks, 2.0),
208
- BasicBlock(512, 1024, kernel_size=7, groups=512),
209
- BlurPool1D(1024, filt_size=3, stride=2),
210
- nn.GELU(),
211
- nn.GroupNorm(2, 1024),
212
- )
213
-
214
- self.block9 = nn.Sequential(
215
- Masksembles1D(1024, self.n_masks, 2.0),
216
- BasicBlock(1024, 128, kernel_size=7, groups=128),
217
- # BlurPool1D(512, filt_size=3, stride=2),
218
- # nn.GELU(),
219
- # nn.GroupNorm(2,512),
220
- )
221
-
222
- self.out = nn.Sequential(nn.Linear(3072, 2), nn.Sigmoid())
223
-
224
- def forward(self, x):
225
- # Feature extraction
226
-
227
- x = self.block1(x)
228
- x = self.block2(x)
229
-
230
- x = self.block3(x)
231
- x = self.block4(x)
232
-
233
- x = self.block5(x)
234
- x = self.block6(x)
235
-
236
- x = self.block7(x)
237
- x = self.block8(x)
238
-
239
- x = self.block9(x)
240
-
241
- # Regressor
242
- x = x.flatten(start_dim=1)
243
- x = self.out(x)
244
-
245
- return x
246
-
247
- class Onset_picker(pl.LightningModule):
248
- def __init__(self, picker, learning_rate):
249
- super().__init__()
250
- self.picker = picker
251
- self.learning_rate = learning_rate
252
- self.save_hyperparameters(ignore=['picker'])
253
- self.mae = MeanAbsoluteError()
254
-
255
- def compute_loss(self, y, pick, mae_name=False):
256
- y_filt = y[y != 0]
257
- pick_filt = pick[y != 0]
258
- if len(y_filt) > 0:
259
- loss = F.l1_loss(y_filt, pick_filt.flatten())
260
- if mae_name != False:
261
- mae_phase = self.mae(y_filt, pick_filt.flatten())*60
262
- self.log(f'MAE/{mae_name}_val', mae_phase, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
263
- else:
264
- loss = 0
265
- return loss
266
-
267
- def training_step(self, batch, batch_idx):
268
- # training_step defines the train loop.
269
- x, y_p, y_s = batch
270
- # x, y_p, y_s, y_pg, y_sg, y_pn, y_sn = batch
271
-
272
- picks = self.picker(x)
273
-
274
- p_pick = picks[:,0]
275
- s_pick = picks[:,1]
276
-
277
- p_loss = self.compute_loss(y_p, p_pick)
278
- s_loss = self.compute_loss(y_s, s_pick)
279
-
280
- loss = (p_loss+s_loss)/2
281
-
282
- self.log('Loss/train', loss, on_step=True, on_epoch=False, prog_bar=True, sync_dist=True)
283
-
284
- return loss
285
-
286
- def validation_step(self, batch, batch_idx):
287
-
288
- x, y_p, y_s = batch
289
-
290
- picks = self.picker(x)
291
-
292
- p_pick = picks[:,0]
293
- s_pick = picks[:,1]
294
-
295
- p_loss = self.compute_loss(y_p, p_pick, mae_name='P')
296
- s_loss = self.compute_loss(y_s, s_pick, mae_name='S')
297
-
298
- loss = (p_loss+s_loss)/2
299
-
300
- self.log('Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
301
-
302
- return loss
303
-
304
- def configure_optimizers(self):
305
- optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
306
- scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, cooldown=10, threshold=1e-3)
307
- # scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, 3e-4, epochs=300, steps_per_epoch=len(train_loader))
308
- monitor = 'Loss/train'
309
- return {"optimizer": optimizer, "lr_scheduler": scheduler, 'monitor': monitor}
310
-
311
- def forward(self, x):
312
- picks = self.picker(x)
313
- return picks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
phasehunter/training.py DELETED
@@ -1,104 +0,0 @@
1
-
2
- import torch
3
-
4
- from data_preparation import augment, collation_fn, my_split_by_node
5
- from model import Onset_picker, Updated_onset_picker
6
-
7
- import webdataset as wds
8
-
9
- from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
10
- from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
11
- from lightning.pytorch.strategies import DDPStrategy
12
- from lightning import seed_everything
13
- import lightning as pl
14
-
15
- seed_everything(42, workers=False)
16
- torch.set_float32_matmul_precision('medium')
17
-
18
- batch_size = 256
19
- num_workers = 16 #int(os.cpu_count())
20
- n_iters_in_epoch = 5000
21
-
22
- train_dataset = (
23
- wds.WebDataset("data/sample/shard-00{0000..0001}.tar",
24
- # splitter=my_split_by_worker,
25
- nodesplitter=my_split_by_node)
26
- .decode()
27
- .map(augment)
28
- .shuffle(5000)
29
- .batched(batchsize=batch_size,
30
- collation_fn=collation_fn,
31
- partial=False
32
- )
33
- ).with_epoch(n_iters_in_epoch//num_workers)
34
-
35
-
36
- val_dataset = (
37
- wds.WebDataset("data/sample/shard-00{0000..0000}.tar",
38
- # splitter=my_split_by_worker,
39
- nodesplitter=my_split_by_node)
40
- .decode()
41
- .map(augment)
42
- .repeat()
43
- .batched(batchsize=batch_size,
44
- collation_fn=collation_fn,
45
- partial=False
46
- )
47
- ).with_epoch(100)
48
-
49
-
50
- train_loader = wds.WebLoader(train_dataset,
51
- num_workers=num_workers,
52
- shuffle=False,
53
- pin_memory=True,
54
- batch_size=None)
55
-
56
- val_loader = wds.WebLoader(val_dataset,
57
- num_workers=0,
58
- shuffle=False,
59
- pin_memory=True,
60
- batch_size=None)
61
-
62
-
63
-
64
- # model
65
- model = Onset_picker(picker=Updated_onset_picker(),
66
- learning_rate=3e-4)
67
- # model = torch.compile(model, mode="reduce-overhead")
68
-
69
- logger = TensorBoardLogger("tensorboard_logdir", name="FAST")
70
-
71
- checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="Loss/val", filename="chkp-{epoch:02d}")
72
- lr_callback = LearningRateMonitor(logging_interval='epoch')
73
- # swa_callback = StochasticWeightAveraging(swa_lrs=0.05)
74
-
75
- # # train model
76
- trainer = pl.Trainer(
77
- precision='16-mixed',
78
-
79
- callbacks=[checkpoint_callback, lr_callback],
80
-
81
- devices='auto',
82
- accelerator='auto',
83
-
84
- strategy=DDPStrategy(find_unused_parameters=False,
85
- static_graph=True,
86
- gradient_as_bucket_view=True),
87
- benchmark=True,
88
-
89
- gradient_clip_val=0.5,
90
- # ckpt_path='path/to/saved/checkpoints/chkp.ckpt',
91
-
92
- # fast_dev_run=True,
93
-
94
- logger=logger,
95
- log_every_n_steps=50,
96
- enable_progress_bar=True,
97
-
98
- max_epochs=300,
99
- )
100
-
101
- trainer.fit(model=model,
102
- train_dataloaders=train_loader,
103
- val_dataloaders=val_loader,
104
- )