bpiyush committed
Commit a0be511 · verified · 1 Parent(s): 41ce9e0

Upload util.py with huggingface_hub

Files changed (1)
  1. util.py +392 -0
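The commit message says the file was uploaded with `huggingface_hub`. For reference, a minimal sketch of such an upload via the Python client is shown below; the repo id and repo type are assumptions for illustration, not taken from this commit.

from huggingface_hub import HfApi

# Hypothetical call mirroring "Upload util.py with huggingface_hub";
# repo_id and repo_type are assumed, not recorded in the commit.
api = HfApi()
api.upload_file(
    path_or_fileobj="util.py",
    path_in_repo="util.py",
    repo_id="bpiyush/SoundOfWater",  # assumed Space name
    repo_type="space",
    commit_message="Upload util.py with huggingface_hub",
)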
util.py ADDED
@@ -0,0 +1,392 @@
+ custom_css = """
+ <style>
+ .container {
+     max-width: 100% !important;
+     padding-left: 0 !important;
+     padding-right: 0 !important;
+ }
+ .header {
+     padding: 30px;
+     margin-bottom: 30px;
+     text-align: center;
+     font-family: 'Helvetica Neue', Arial, sans-serif;
+     box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+ }
+ .header h1 {
+     font-size: 36px;
+     margin-bottom: 15px;
+     font-weight: bold;
+     color: #333333; /* Explicitly set heading color */
+ }
+ .header h2 {
+     font-size: 24px;
+     margin-bottom: 10px;
+     color: #333333; /* Explicitly set subheading color */
+ }
+ .header p {
+     font-size: 18px;
+     margin: 5px 0;
+     color: #666666;
+ }
+ .blue-text {
+     color: #4a90e2;
+ }
+ /* Custom styles for slider container */
+ .slider-container {
+     background-color: white !important;
+     padding-top: 0.9em;
+     padding-bottom: 0.9em;
+ }
+ /* Add gap before examples */
+ .examples-holder {
+     margin-top: 2em;
+ }
+ /* Set fixed size for example videos */
+ .gradio-container .gradio-examples .gr-sample {
+     width: 240px !important;
+     height: 135px !important;
+     object-fit: cover;
+     display: inline-block;
+     margin-right: 10px;
+ }
+ .gradio-container .gradio-examples {
+     display: flex;
+     flex-wrap: wrap;
+     gap: 10px;
+ }
+ /* Ensure the parent container does not stretch */
+ .gradio-container .gradio-examples {
+     max-width: 100%;
+     overflow: hidden;
+ }
+ /* Additional styles to ensure proper sizing in Safari */
+ .gradio-container .gradio-examples .gr-sample img {
+     width: 240px !important;
+     height: 135px !important;
+     object-fit: cover;
+ }
+ </style>
+ """
+
+ custom_html = custom_css + """
+ <div class="header">
+     <h1><span class="blue-text">The Sound of Water</span>: Inferring Physical Properties from Pouring Liquids</h1>
+     <p><a href='https://bpiyush.github.io/pouring-water-website/'>Project Page</a> |
+     <a href='https://github.com/bpiyush/SoundOfWater'>Github</a> |
+     <a href='#'>Paper</a> |
+     <a href='https://huggingface.co/datasets/bpiyush/sound-of-water'>Data</a> |
+     <a href='https://huggingface.co/bpiyush/sound-of-water-models'>Models</a></p>
+ </div>
+ """
+
+ tips = """
+ <div>
+     <br><br>
+     Please give us a 🌟 on <a href='https://github.com/bpiyush/SoundOfWater'>Github</a> if you like our work!
+     Tips to get better results:
+     <ul>
+         <li>Make sure there is not too much background noise, so that the pouring sound is clearly audible.</li>
+         <li>The video stream is not used during inference, only the audio.</li>
+     </ul>
+ </div>
+ """
+
+ import os
+ import sys
+
+ import gradio as gr
+ import torch
+ import numpy as np
+ import matplotlib.pyplot as plt
+ plt.rcParams["font.family"] = "serif"
+ import decord
+ import PIL, PIL.Image
+ import librosa
+ from IPython.display import Markdown, display
+ import pandas as pd
+
+ import shared.utils as su
+ import sound_of_water.audio_pitch.model as audio_models
+ import sound_of_water.data.audio_loader as audio_loader
+ import sound_of_water.data.audio_transforms as at
+ import sound_of_water.data.csv_loader as csv_loader
+
+
+ def read_html_file(file):
+     with open(file) as f:
+         return f.read()
+
+
+
+ def define_axes(figsize=(13, 4), width_ratios=[0.22, 0.78]):
+     fig, axes = plt.subplots(
+         1, 2, figsize=figsize, width_ratios=width_ratios,
+         layout="constrained",
+     )
+     return fig, axes
+
+
+ def show_frame_and_spectrogram(frame, spectrogram, visualise_args, axes=None):
+     """Shows the frame and spectrogram side by side."""
+
+     if axes is None:
+         fig, axes = define_axes()
+     else:
+         assert len(axes) == 2
+
+     ax = axes[0]
+     ax.imshow(frame, aspect="auto")
+     ax.set_title("Example frame")
+     ax.set_xticks([])
+     ax.set_yticks([])
+     ax = axes[1]
+     audio_loader.show_logmelspectrogram(
+         S=spectrogram,
+         ax=ax,
+         show=False,
+         sr=visualise_args["sr"],
+         n_fft=visualise_args["n_fft"],
+         hop_length=visualise_args["hop_length"],
+     )
+
+
+ def scatter_pitch(ax, t, f, s=60, marker="o", color="limegreen", label="Pitch"):
+     """Scatter plot of pitch."""
+     ax.scatter(t, f, color=color, label=label, s=s, marker=marker)
+     ax.set_xlabel("Time (s)")
+     ax.set_ylabel("Frequency (Hz)")
+     ax.legend(loc="upper left")
+
+
+ # Load video frame
+ def load_frame(video_path):
+     vr = decord.VideoReader(video_path, num_threads=1)
+     frame = PIL.Image.fromarray(vr[0].asnumpy())
+     frame = audio_loader.crop_or_pad_to_size(frame, size=(270, 480))
+     return frame
+
+
+ def load_spectrogram(video_path):
+     y = audio_loader.load_audio_clips(
+         audio_path=video_path,
+         clips=None,
+         load_entire=True,
+         cut_to_clip_len=False,
+         **aload_args,
+     )[0]
+     S = audio_loader.librosa_harmonic_spectrogram_db(
+         y,
+         sr=visualise_args["sr"],
+         n_fft=visualise_args["n_fft"],
+         hop_length=visualise_args["hop_length"],
+         n_mels=visualise_args['n_mels'],
+     )
+     return S
+
+
+ # Load audio
+ visualise_args = {
+     "sr": 16000,
+     "n_fft": 400,
+     "hop_length": 320,
+     "n_mels": 64,
+     "margin": 16.,
+     "C": 340 * 100.,
+     "audio_output_fps": 49.,
+     "w_max": 100.,
+     "n_bins": 64,
+ }
+ aload_args = {
+     "sr": 16000,
+     "clip_len": None,
+     "backend": "decord",
+ }
+
+
+ cfg_backbone = {
+     "name": "Wav2Vec2WithTimeEncoding",
+     "args": dict(),
+ }
+ backbone = getattr(audio_models, cfg_backbone["name"])(
+     **cfg_backbone["args"],
+ )
+
+
+ cfg_model = {
+     "name": "WavelengthWithTime",
+     "args": {
+         "axial": True,
+         "axial_bins": 64,
+         "radial": True,
+         "radial_bins": 64,
+         "freeze_backbone": True,
+         "train_backbone_modules": [6, 7, 8, 9, 10, 11],
+         "act": "softmax",
+         "criterion": "kl_div",
+     }
+ }
+
+
+ def load_model():
+     model = getattr(audio_models, cfg_model["name"])(
+         backbone=backbone, **cfg_model["args"],
+     )
+     su.misc.num_params(model)
+
+
+     # Load the model weights from trained checkpoint
+     # NOTE: Be sure to set the correct path to the checkpoint
+     su.log.print_update("[:::] Loading checkpoint ", color="cyan", fillchar=".", pos="left")
+     # ckpt_dir = "/work/piyush/pretrained_checkpoints/SoundOfWater"
+     ckpt_dir = "./checkpoints"
+     ckpt_path = os.path.join(
+         ckpt_dir,
+         "dsr9mf13_ep100_step12423_real_finetuned_with_cosupervision.pth",
+     )
+     assert os.path.exists(ckpt_path), \
+         f"Checkpoint not found at {ckpt_path}."
+     print("Loading checkpoint from: ", ckpt_path)
+     ckpt = torch.load(ckpt_path, map_location="cpu")
+     msg = model.load_state_dict(ckpt)
+     print(msg)
+     return model
+
+
+ # Define audio transforms
+ cfg_transform = {
+     "audio": {
+         "wave": [
+             {
+                 "name": "AddNoise",
+                 "args": {
+                     "noise_level": 0.001
+                 },
+                 "augmentation": True,
+             },
+             {
+                 "name": "ChangeVolume",
+                 "args": {
+                     "volume_factor": [0.8, 1.2]
+                 },
+                 "augmentation": True,
+             },
+             {
+                 "name": "Wav2Vec2WaveformProcessor",
+                 "args": {
+                     "model_name": "facebook/wav2vec2-base-960h",
+                     "sr": 16000
+                 }
+             }
+         ],
+         "spec": None,
+     }
+ }
+ audio_transform = at.define_audio_transforms(
+     cfg_transform, augment=False,
+ )
+
+ # Define audio pipeline arguments
+ apipe_args = {
+     "spec_args": None,
+     "stack": True,
+ }
+
+
+ def load_audio_tensor(video_path):
+     # Load and transform input audio
+     audio = audio_loader.load_and_process_audio(
+         audio_path=video_path,
+         clips=None,
+         load_entire=True,
+         cut_to_clip_len=False,
+         audio_transform=audio_transform,
+         aload_args=aload_args,
+         apipe_args=apipe_args,
+     )[0]
+     return audio
+
+
+ def get_model_output(audio, model):
+     with torch.no_grad():
+         NS = audio.shape[-1]
+         duration = NS / 16000
+         t = torch.tensor([[0, duration]]).unsqueeze(0)
+         x = audio.unsqueeze(0)
+         z_audio = model.backbone(x, t)[0][0].cpu()
+         y_audio = model(x, t)["axial"][0][0].cpu()
+     return z_audio, y_audio
+
+
+ def show_output(frame, S, y_audio, z_audio):
+     # duration = S.shape[-1] / visualise_args["sr"]
+     # print(S.shape, y_audio.shape, z_audio.shape)
+     duration = librosa.get_duration(
+         S=S,
+         sr=visualise_args["sr"],
+         n_fft=visualise_args["n_fft"],
+         hop_length=visualise_args["hop_length"],
+     )
+     timestamps = np.linspace(0., duration, 25)
+
+     # Get timestamps at evaluation frames
+     n_frames = len(y_audio)
+     timestamps_eval = librosa.frames_to_time(
+         np.arange(n_frames),
+         sr=visualise_args['sr'],
+         n_fft=visualise_args['n_fft'],
+         hop_length=visualise_args['hop_length'],
+     )
+     # Get predicted frequencies at these times
+     wavelengths = y_audio @ torch.linspace(
+         0, visualise_args['w_max'], visualise_args['n_bins'],
+     )
+     f_pred = visualise_args['C'] / wavelengths
+     # Pick only those timestamps where we define the true pitch
+     indices = su.misc.find_nearest_indices(timestamps_eval, timestamps)
+     f_pred = f_pred[indices]
+
+     # print(timestamps, f_pred)
+
+     # Show the predicted pitch overlaid on the spectrogram
+     fig, axes = define_axes()
+     show_frame_and_spectrogram(frame, S, visualise_args, axes=axes)
+     scatter_pitch(axes[1], timestamps, f_pred, color="white", label="Estimated pitch", marker="o", s=70)
+     axes[1].set_title("Predicted pitch overlaid on the spectrogram")
+     # plt.show()
+
+     # Convert to PIL Image and return the Image
+     from PIL import Image
+
+     # Draw the figure to a canvas
+     canvas = fig.canvas
+     canvas.draw()
+
+     # Get the RGB buffer from the figure
+     w, h = fig.canvas.get_width_height()
+     buf = canvas.tostring_rgb()
+
+     # Create a PIL image from the RGB data
+     image = Image.frombytes("RGB", (w, h), buf)
+
+
+     # Get physical properties
+     l_pred = su.physics.estimate_length_of_air_column(wavelengths)
+     l_pred_mean = l_pred.mean().item()
+     l_pred_mean = np.round(l_pred_mean, 2)
+     H_pred = su.physics.estimate_cylinder_height(wavelengths)
+     H_pred = np.round(H_pred, 2)
+     R_pred = su.physics.estimate_cylinder_radius(wavelengths)
+     R_pred = np.round(R_pred, 2)
+     # print(f"Estimated length: {l_pred_mean} cm, Estimated height: {H_pred} cm, Estimated radius: {R_pred} cm")
+     df_show = pd.DataFrame({
+         "Physical Property": ["Container height", "Container radius", "Length of air column (mean)"],
+         "Estimated Value (in cm)": [H_pred, R_pred, l_pred_mean],
+     })
+
+
+     tsne_image = su.visualize.show_temporal_tsne(
+         z_audio.detach().numpy(), timestamps_eval, show=False,
+         figsize=(6, 5), title="Temporal t-SNE of latent features",
+         return_as_pil=True,
+     )
+
+     return image, df_show, tsne_image
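
For context, the helpers defined above are meant to be chained into a single inference pipeline: load a display frame and spectrogram, load the transformed audio, run the model, and render the outputs. The sketch below shows one plausible way to wire them together and expose them as a Gradio demo; the example path and the choice of Gradio components are assumptions, not part of util.py.

import gradio as gr

import util

# Build the model and load the checkpoint as defined in util.load_model().
model = util.load_model()
model.eval()

def predict(video_path):
    frame = util.load_frame(video_path)         # first frame, for display only
    S = util.load_spectrogram(video_path)       # log-mel spectrogram of the audio track
    audio = util.load_audio_tensor(video_path)  # transformed waveform tensor
    z_audio, y_audio = util.get_model_output(audio, model)
    # Returns: pitch-over-spectrogram image, table of estimated physical
    # properties, and a t-SNE plot of the latent features.
    return util.show_output(frame, S, y_audio, z_audio)

# Run on a local clip (hypothetical path):
# image, df_show, tsne_image = predict("examples/pouring.mp4")

# Or serve it as a simple demo (component choices are assumptions):
demo = gr.Interface(
    fn=predict,
    inputs=gr.Video(label="Video of pouring water"),
    outputs=[
        gr.Image(label="Predicted pitch"),
        gr.Dataframe(label="Estimated physical properties"),
        gr.Image(label="Temporal t-SNE"),
    ],
)
# demo.launch()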