crimeacs commited on
Commit
bfda450
·
1 Parent(s): a237ee5

Now original waveform is displayed

Browse files
Files changed (2) hide show
  1. Gradio_app.ipynb +16 -20
  2. app.py +7 -6
Gradio_app.ipynb CHANGED
@@ -2,14 +2,14 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 13,
6
  "metadata": {},
7
  "outputs": [
8
  {
9
  "name": "stdout",
10
  "output_type": "stream",
11
  "text": [
12
- "Running on local URL: http://127.0.0.1:7866\n",
13
  "\n",
14
  "To create a public link, set `share=True` in `launch()`.\n"
15
  ]
@@ -17,7 +17,7 @@
17
  {
18
  "data": {
19
  "text/html": [
20
- "<div><iframe src=\"http://127.0.0.1:7866/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
21
  ],
22
  "text/plain": [
23
  "<IPython.core.display.HTML object>"
@@ -30,22 +30,16 @@
30
  "data": {
31
  "text/plain": []
32
  },
33
- "execution_count": 13,
34
  "metadata": {},
35
  "output_type": "execute_result"
36
  },
37
- {
38
- "name": "stderr",
39
- "output_type": "stream",
40
- "text": [
41
- "No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.\n"
42
- ]
43
- },
44
  {
45
  "name": "stdout",
46
  "output_type": "stream",
47
  "text": [
48
- "0.13119522482156754\n"
 
49
  ]
50
  }
51
  ],
@@ -83,8 +77,9 @@
83
  " if len(waveform.shape) == 1:\n",
84
  " waveform = waveform.reshape(1, waveform.shape[0])\n",
85
  "\n",
 
86
  " processed_input = prepare_waveform(waveform)\n",
87
- " \n",
88
  " # Make prediction\n",
89
  " with torch.inference_mode():\n",
90
  " output = model(processed_input)\n",
@@ -92,33 +87,34 @@
92
  " p_phase = output[:, 0]\n",
93
  " s_phase = output[:, 1]\n",
94
  "\n",
95
- " return processed_input, p_phase, s_phase\n",
 
96
  "\n",
97
  "def mark_phases(waveform, uploaded_file, p_thres, s_thres):\n",
98
  "\n",
99
  " if uploaded_file is not None:\n",
100
  " waveform = uploaded_file.name\n",
101
  "\n",
102
- " processed_input, p_phase, s_phase = make_prediction(waveform)\n",
103
  "\n",
104
  " # Create a plot of the waveform with the phases marked\n",
105
  " if sum(processed_input[0][2] == 0): #if input is 1C\n",
106
  " fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)\n",
107
  "\n",
108
- " ax[0].plot(processed_input[0][0], color='black', lw=1)\n",
109
  " ax[0].set_ylabel('Norm. Ampl.')\n",
110
  "\n",
111
  " else: #if input is 3C\n",
112
  " fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)\n",
113
- " ax[0].plot(processed_input[0][0], color='black', lw=1)\n",
114
- " ax[1].plot(processed_input[0][1], color='black', lw=1)\n",
115
- " ax[2].plot(processed_input[0][2], color='black', lw=1)\n",
116
  "\n",
117
  " ax[0].set_ylabel('Z')\n",
118
  " ax[1].set_ylabel('N')\n",
119
  " ax[2].set_ylabel('E')\n",
120
  "\n",
121
- " print(p_phase.std().item()*60)\n",
122
  " do_we_have_p = (p_phase.std().item()*60 < p_thres)\n",
123
  " if do_we_have_p:\n",
124
  " p_phase_plot = p_phase*processed_input.shape[-1]\n",
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 16,
6
  "metadata": {},
7
  "outputs": [
8
  {
9
  "name": "stdout",
10
  "output_type": "stream",
11
  "text": [
12
+ "Running on local URL: http://127.0.0.1:7869\n",
13
  "\n",
14
  "To create a public link, set `share=True` in `launch()`.\n"
15
  ]
 
17
  {
18
  "data": {
19
  "text/html": [
20
+ "<div><iframe src=\"http://127.0.0.1:7869/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
21
  ],
22
  "text/plain": [
23
  "<IPython.core.display.HTML object>"
 
30
  "data": {
31
  "text/plain": []
32
  },
33
+ "execution_count": 16,
34
  "metadata": {},
35
  "output_type": "execute_result"
36
  },
 
 
 
 
 
 
 
37
  {
38
  "name": "stdout",
39
  "output_type": "stream",
40
  "text": [
41
+ "4\n",
42
+ "0.02744414610788226\n"
43
  ]
44
  }
45
  ],
 
77
  " if len(waveform.shape) == 1:\n",
78
  " waveform = waveform.reshape(1, waveform.shape[0])\n",
79
  "\n",
80
+ " orig_waveform = waveform[:, :6000].copy()\n",
81
  " processed_input = prepare_waveform(waveform)\n",
82
+ "\n",
83
  " # Make prediction\n",
84
  " with torch.inference_mode():\n",
85
  " output = model(processed_input)\n",
 
87
  " p_phase = output[:, 0]\n",
88
  " s_phase = output[:, 1]\n",
89
  "\n",
90
+ " return processed_input, p_phase, s_phase, orig_waveform\n",
91
+ "\n",
92
  "\n",
93
  "def mark_phases(waveform, uploaded_file, p_thres, s_thres):\n",
94
  "\n",
95
  " if uploaded_file is not None:\n",
96
  " waveform = uploaded_file.name\n",
97
  "\n",
98
+ " processed_input, p_phase, s_phase, orig_waveform = make_prediction(waveform)\n",
99
  "\n",
100
  " # Create a plot of the waveform with the phases marked\n",
101
  " if sum(processed_input[0][2] == 0): #if input is 1C\n",
102
  " fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)\n",
103
  "\n",
104
+ " ax[0].plot(orig_waveform[0], color='black', lw=1)\n",
105
  " ax[0].set_ylabel('Norm. Ampl.')\n",
106
  "\n",
107
  " else: #if input is 3C\n",
108
  " fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)\n",
109
+ " ax[0].plot(orig_waveform[0], color='black', lw=1)\n",
110
+ " ax[1].plot(orig_waveform[1], color='black', lw=1)\n",
111
+ " ax[2].plot(orig_waveform[2], color='black', lw=1)\n",
112
  "\n",
113
  " ax[0].set_ylabel('Z')\n",
114
  " ax[1].set_ylabel('N')\n",
115
  " ax[2].set_ylabel('E')\n",
116
  "\n",
117
+ "\n",
118
  " do_we_have_p = (p_phase.std().item()*60 < p_thres)\n",
119
  " if do_we_have_p:\n",
120
  " p_phase_plot = p_phase*processed_input.shape[-1]\n",
app.py CHANGED
@@ -36,6 +36,7 @@ def make_prediction(waveform):
36
  if len(waveform.shape) == 1:
37
  waveform = waveform.reshape(1, waveform.shape[0])
38
 
 
39
  processed_input = prepare_waveform(waveform)
40
 
41
  # Make prediction
@@ -45,7 +46,7 @@ def make_prediction(waveform):
45
  p_phase = output[:, 0]
46
  s_phase = output[:, 1]
47
 
48
- return processed_input, p_phase, s_phase
49
 
50
 
51
  def mark_phases(waveform, uploaded_file, p_thres, s_thres):
@@ -53,20 +54,20 @@ def mark_phases(waveform, uploaded_file, p_thres, s_thres):
53
  if uploaded_file is not None:
54
  waveform = uploaded_file.name
55
 
56
- processed_input, p_phase, s_phase = make_prediction(waveform)
57
 
58
  # Create a plot of the waveform with the phases marked
59
  if sum(processed_input[0][2] == 0): # if input is 1C
60
  fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)
61
 
62
- ax[0].plot(processed_input[0][0], color="black", lw=1)
63
  ax[0].set_ylabel("Norm. Ampl.")
64
 
65
  else: # if input is 3C
66
  fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)
67
- ax[0].plot(processed_input[0][0], color="black", lw=1)
68
- ax[1].plot(processed_input[0][1], color="black", lw=1)
69
- ax[2].plot(processed_input[0][2], color="black", lw=1)
70
 
71
  ax[0].set_ylabel("Z")
72
  ax[1].set_ylabel("N")
 
36
  if len(waveform.shape) == 1:
37
  waveform = waveform.reshape(1, waveform.shape[0])
38
 
39
+ orig_waveform = waveform[:, :6000].copy()
40
  processed_input = prepare_waveform(waveform)
41
 
42
  # Make prediction
 
46
  p_phase = output[:, 0]
47
  s_phase = output[:, 1]
48
 
49
+ return processed_input, p_phase, s_phase, orig_waveform
50
 
51
 
52
  def mark_phases(waveform, uploaded_file, p_thres, s_thres):
 
54
  if uploaded_file is not None:
55
  waveform = uploaded_file.name
56
 
57
+ processed_input, p_phase, s_phase, orig_waveform = make_prediction(waveform)
58
 
59
  # Create a plot of the waveform with the phases marked
60
  if sum(processed_input[0][2] == 0): # if input is 1C
61
  fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)
62
 
63
+ ax[0].plot(orig_waveform[0], color="black", lw=1)
64
  ax[0].set_ylabel("Norm. Ampl.")
65
 
66
  else: # if input is 3C
67
  fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)
68
+ ax[0].plot(orig_waveform[0], color="black", lw=1)
69
+ ax[1].plot(orig_waveform[1], color="black", lw=1)
70
+ ax[2].plot(orig_waveform[2], color="black", lw=1)
71
 
72
  ax[0].set_ylabel("Z")
73
  ax[1].set_ylabel("N")