Taino commited on
Commit
b8f40f2
·
1 Parent(s): 41b0801

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +236 -0
main.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys, re, json
2
+ import argparse
3
+ #import whisper_timestamped as wt
4
+ from pdb import set_trace as b
5
+ from pprint import pprint as pp
6
+ from profanity_check import predict, predict_prob
7
+ from pydub import AudioSegment
8
+ from pydub.playback import play
9
+ from subprocess import Popen, PIPE
10
+
11
+
12
+ def parse_args():
13
+ """
14
+ """
15
+ parser = argparse.ArgumentParser(
16
+ description=('Tool to mute profanities in a song (source separation -> speech recognition -> profanity detection -> mask profanities -> re-mix)'),
17
+ usage=('see <py main.py --help> or run as local web app with streamlit: <streamlit run main.py>')
18
+ )
19
+
20
+ parser.add_argument(
21
+ '-i',
22
+ '--input',
23
+ default=None,
24
+ nargs='?',
25
+ #required=True,
26
+ help=("path to a mp3")
27
+ )
28
+ parser.add_argument(
29
+ '-m',
30
+ '--model',
31
+ default='small',
32
+ nargs='?',
33
+ help=("model used by whisper for speech recognition: tiny, small (default), medium or large")
34
+ )
35
+ parser.add_argument(
36
+ '-p',
37
+ '--play',
38
+ default=False,
39
+ action='store_true',
40
+ help=("play output audio at the end")
41
+ )
42
+ parser.add_argument(
43
+ '-v',
44
+ '--verbose',
45
+ default=False,
46
+ action='store_true',
47
+ help=("print transcribed text and detected profanities to screen")
48
+ )
49
+ return parser.parse_args()
50
+
51
+
52
+
53
+ def main(args, input_file=None, model_size=None, verbose=False, play_output=False):
54
+ """
55
+ """
56
+ if not input_file:
57
+ input_file = args.input
58
+
59
+ if not model_size:
60
+ model_size = args.model
61
+
62
+ if not verbose:
63
+ verbose = args.verbose
64
+
65
+ if not play_output:
66
+ play_output = args.play
67
+
68
+ # exit if input file not found
69
+ if not os.path.isfile(input_file):
70
+ print('Error: --input file not found')
71
+ sys.exit()
72
+
73
+ print(f'\nProcessing input file: {input_file}')
74
+
75
+ # split audio into vocals + accompaniment
76
+ print('Running source separation')
77
+ stems_dir = source_separation(input_file)
78
+ vocal_stem = os.path.join(stems_dir, 'vocals.wav')
79
+ instr_stem = os.path.join(stems_dir, 'no_vocals.wav')
80
+ print(f'Vocal stem written to: {vocal_stem}')
81
+
82
+ # speech rec (audio->text)
83
+ print('Transcribe vocal stem into text with word-level timestamps')
84
+ #cmd = f'whisper_timestamped --task transcribe --model {model_size} {vocal_stem}'
85
+ #stdout, stderr = Popen(cmd, stdout=PIPE, stderr=PIPE, shell=True, executable='/bin/bash').communicate()
86
+ #text = json.loads('\n'.join(stdout.decode('utf8').split('\n')[1:]))
87
+
88
+ import whisper_timestamped as wt
89
+ audio = wt.load_audio(vocal_stem)
90
+ model = wt.load_model(model_size, device='cpu')
91
+ text = wt.transcribe(model, audio, language='en')
92
+
93
+ if verbose:
94
+ print('\nTranscribed text:')
95
+ print(text['text']+'\n')
96
+
97
+ # checking for profanities in text
98
+ print('Run profanity detection on text')
99
+ profanities = profanity_detection(text)
100
+ if not profanities:
101
+ print(f'No profanities found in {input_file} - exiting')
102
+ sys.exit()
103
+ if verbose:
104
+ print('profanities found in text:')
105
+ pp(profanities)
106
+
107
+ # masking
108
+ print('Mask profanities in vocal stem')
109
+ vocals = mask_profanities(vocal_stem, profanities)
110
+
111
+ # re-mixing
112
+ print('Merge instrumentals stem and masked vocals stem')
113
+ mix = AudioSegment.from_wav(instr_stem).overlay(vocals)
114
+
115
+ # write mix to file
116
+ outpath = input_file.replace('.mp3', '_masked.mp3').replace('.wav', '_masked.wav')
117
+ if input_file.endswith('.wav'):
118
+ mix.export(outpath, format="wav")
119
+ elif input_file.endswith('.mp3'):
120
+ mix.export(outpath, format="mp3")
121
+ print(f'Mixed file written to: {outpath}')
122
+
123
+ # play output
124
+ if play_output:
125
+ print('\nPlaying output...')
126
+ play(mix)
127
+
128
+ return outpath
129
+
130
+
131
+ def source_separation(inpath):
132
+ """
133
+ Execute shell command to run demucs and pipe stdout/stderr back to python
134
+ """
135
+ cmd = f'demucs --two-stems=vocals --jobs 8 "{inpath}"'
136
+ stdout, stderr = Popen(cmd, stdout=PIPE, stderr=PIPE, shell=True, executable='/bin/bash').communicate()
137
+ stdout = stdout.decode('utf8')
138
+
139
+ # exit if demucs error'd out
140
+ if stderr:
141
+ stderr = stderr.decode('utf-8').lower()
142
+ if 'error' in stderr or 'not exist' in stderr:
143
+ print(stderr.decode('utf8').split('\n')[0])
144
+ sys.exit()
145
+
146
+ # parse stems directory path from stdout and return it if successful
147
+ stems_dir = ''.join(re.findall('/.*', stdout)).replace('.mp3','').replace('.wav','').replace('samples/','')
148
+ if not os.path.isdir(stems_dir):
149
+ print(f'Error: output stem directory "{stems_dir}" not found')
150
+ sys.exit()
151
+
152
+ return stems_dir
153
+
154
+
155
+ def profanity_detection(text):
156
+ """
157
+ """
158
+ # detect profanities in text
159
+ profs = []
160
+ for segment in text['segments']:
161
+ for word in segment['words']:
162
+ #if word['confidence']<.25:
163
+ # print(word)
164
+ text = word['text'].replace('.','').replace(',','').lower()
165
+
166
+ # skip false positives
167
+ if text in ['cancer', 'hell', 'junk', 'die', 'lame', 'freak', 'freaky', 'white', 'stink', 'shut', 'spit', 'mouth','orders','eat','clouds']:
168
+ continue
169
+
170
+ # assume anything returned by whisper with more than 1 * is profanity e.g n***a
171
+ if '**' in text:
172
+ profs.append(word)
173
+ continue
174
+
175
+ # add true negatives
176
+ if text in ['bitchy', 'puss']:
177
+ profs.append(word)
178
+ continue
179
+
180
+ # run profanity detection - returns 1 (True) or 0 (False)
181
+ if predict([word['text']])[0]:
182
+ profs.append(word)
183
+
184
+ return profs
185
+
186
+
187
+ def mask_profanities(vocal_stem, profanities):
188
+ """
189
+ """
190
+ # load vocal stem and mask profanities
191
+ vocals = AudioSegment.from_wav(vocal_stem)
192
+ for prof in profanities:
193
+ mask = vocals[prof['start']*1000:prof['end']*1000] # pydub works in milliseconds
194
+ mask -= 50 # reduce lvl by some dB (enough to ~mute it)
195
+ #mask = mask.silent(len(mask))
196
+ #mask = mask.fade_in(100).fade_out(100) # it prepends/appends fades so end up with longer mask
197
+ start = vocals[:prof['start']*1000]
198
+ end = vocals[prof['end']*1000:]
199
+ #print(f"masking {prof['text']} from {prof['start']} to {prof['end']}")
200
+ vocals = start + mask + end
201
+
202
+ return vocals
203
+
204
+
205
+
206
+ if __name__ == "__main__":
207
+ args = parse_args()
208
+
209
+ if len(sys.argv)>1:
210
+ main(args)
211
+ else:
212
+ import streamlit as st
213
+ st.title('Saylss')
214
+ model = st.selectbox('Choose model size:', ('tiny','small','medium','large'), index=1)
215
+ uploaded_file = st.file_uploader("Choose input track", type=[".mp3",".wav"], accept_multiple_files=False)
216
+
217
+ if uploaded_file is not None:
218
+
219
+ # display input audio
220
+ #st.text('Play input track:')
221
+ audio_bytes_input = uploaded_file.read()
222
+ st.audio(audio_bytes_input, format='audio/wav')
223
+
224
+ # run code
225
+ with st.spinner('Processing input audio...'):
226
+ outpath = main(args, input_file=os.path.join('audio/samples',uploaded_file.name), model_size=model)
227
+
228
+ # display output audio
229
+ #st.text('Play output Track:')
230
+ st.text('\nOutput:')
231
+ audio_file = open(outpath, 'rb')
232
+ audio_bytes = audio_file.read()
233
+ st.audio(audio_bytes, format='audio/wav')
234
+
235
+
236
+