RamAnanth1 commited on
Commit
b9f865a
·
verified ·
1 Parent(s): b7aef95

Initial commit

Browse files

Minor modifications to TTS-arena code

Files changed (1) hide show
  1. app.py +429 -0
app.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # An arena for 3D generations with code inspired from TTS arena
2
+
3
+ import gradio as gr
4
+ import pandas as pd
5
+ from langdetect import detect
6
+ from datasets import load_dataset
7
+ import threading, time, uuid, sqlite3, shutil, os, random, asyncio, threading
8
+ from pathlib import Path
9
+ from huggingface_hub import CommitScheduler, delete_file, hf_hub_download
10
+ from gradio_client import Client
11
+ from detoxify import Detoxify
12
+ import os
13
+ import tempfile
14
+
15
+ toxicity = Detoxify('original')
16
+
17
+ ####################################
18
+ # Constants
19
+ ####################################
20
+ AVAILABLE_MODELS = {
21
+ 'TripoSR': 'TripSR',
22
+ 'Shape-E': 'shap-e',
23
+ }
24
+
25
+ SPACE_ID = os.getenv('SPACE_ID')
26
+ MAX_SAMPLE_TXT_LENGTH = 300
27
+ MIN_SAMPLE_TXT_LENGTH = 10
28
+ DB_DATASET_ID = os.getenv('DATASET_ID')
29
+ DB_NAME = "database.db"
30
+
31
+ # If /data available => means local storage is enabled => let's use it!
32
+ DB_PATH = f"/data/{DB_NAME}" if os.path.isdir("/data") else DB_NAME
33
+ print(f"Using {DB_PATH}")
34
+
35
+ ####################################
36
+ # Functions
37
+ ####################################
38
+
39
+ def create_db_if_missing():
40
+ conn = get_db()
41
+ cursor = conn.cursor()
42
+ cursor.execute('''
43
+ CREATE TABLE IF NOT EXISTS model (
44
+ name TEXT UNIQUE,
45
+ upvote INTEGER,
46
+ downvote INTEGER
47
+ );
48
+ ''')
49
+ cursor.execute('''
50
+ CREATE TABLE IF NOT EXISTS vote (
51
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
52
+ username TEXT,
53
+ model TEXT,
54
+ vote INTEGER,
55
+ timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
56
+ );
57
+ ''')
58
+ cursor.execute('''
59
+ CREATE TABLE IF NOT EXISTS votelog (
60
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
61
+ username TEXT,
62
+ chosen TEXT,
63
+ rejected TEXT,
64
+ timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
65
+ );
66
+ ''')
67
+
68
+ def get_db():
69
+ return sqlite3.connect(DB_PATH)
70
+
71
+ ####################################
72
+ # Space initialization
73
+ ####################################
74
+
75
+ # Download existing DB
76
+ if not os.path.isfile(DB_PATH):
77
+ print("Downloading DB...")
78
+ try:
79
+ cache_path = hf_hub_download(repo_id=DB_DATASET_ID, repo_type='dataset', filename=DB_NAME)
80
+ shutil.copyfile(cache_path, DB_PATH)
81
+ print("Downloaded DB")
82
+ except Exception as e:
83
+ print("Error while downloading DB:", e)
84
+
85
+ # Create DB table (if doesn't exist)
86
+ create_db_if_missing()
87
+
88
+ # Sync local DB with remote repo every 5 minute (only if a change is detected)
89
+ scheduler = CommitScheduler(
90
+ repo_id=DB_DATASET_ID,
91
+ repo_type="dataset",
92
+ folder_path=Path(DB_PATH).parent,
93
+ every=5,
94
+ allow_patterns=DB_NAME,
95
+ )
96
+
97
+ ####################################
98
+ # Router API
99
+ ####################################
100
+ router = Client("RamAnanth1/3D-Arena-Router", hf_token=os.getenv('HF_TOKEN'))
101
+ ####################################
102
+ # Gradio app
103
+ ####################################
104
+ MUST_BE_LOGGEDIN = "Please login with Hugging Face to participate in the 3D Arena."
105
+ DESCR = """
106
+ # ⚔️3D Arena: Benchmarking Image-to-3D models
107
+
108
+ Vote to help the community find the best Image-to-3D model!
109
+ """.strip()
110
+
111
+ INSTR = """
112
+ ## 🗳️ Vote
113
+
114
+ * Input image to generate a 3D reconstruction.
115
+ * View the responses of the models, one after the other.
116
+ * Vote on which model made a better reconstruction.
117
+ * _Note: Model names are revealed after the vote is cast._
118
+
119
+ Note: It may take up to 60 seconds to get a response.
120
+ """.strip()
121
+ request = ''
122
+ if SPACE_ID:
123
+ request = f"""
124
+ ### Request a model
125
+
126
+ Please [create a Discussion](https://huggingface.co/spaces/{SPACE_ID}/discussions/new) to request a model.
127
+ """
128
+ ABOUT = f"""
129
+ ## 📄 About
130
+
131
+ The 3D Arena evaluates leading 3D generation model. It is inspired by LMsys's [Chatbot Arena](https://chat.lmsys.org/) and [TTS-Arena](https://huggingface.co/spaces/TTS-AGI/TTS-Arena).
132
+
133
+ ### The Arena
134
+
135
+ The leaderboard allows a user to input an image, for which a 3D reconstruction be synthesized by two models. After viewing each sample, the user can vote on which model works better. Due to the risks of human bias and abuse, model names are revealed only after a vote is submitted.
136
+
137
+ {request}
138
+
139
+
140
+ """.strip()
141
+ LDESC = """
142
+ ## 🏆 Leaderboard
143
+
144
+ Vote to help the community find the best Image-to-3D model!
145
+
146
+ The leaderboard displays models in descending order of how suitable the models are (based on votes cast by the community).
147
+
148
+ Important: In order to help keep results fair, the leaderboard hides results by default until the number of votes passes a threshold. Tick the `Reveal preliminary results` to show models without sufficient votes. Please note that preliminary results may be inaccurate.
149
+ """.strip()
150
+
151
+ def del_db(txt):
152
+ if not txt.lower() == 'delete db':
153
+ raise gr.Error('You did not enter "delete db"')
154
+
155
+ # Delete local + remote
156
+ os.remove(DB_PATH)
157
+ delete_file(path_in_repo=DB_NAME, repo_id=DB_DATASET_ID, repo_type='dataset')
158
+
159
+ # Recreate
160
+ create_db_if_missing()
161
+ return 'Delete DB'
162
+
163
+ theme = gr.themes.Monochrome(
164
+ primary_hue="indigo",
165
+ secondary_hue="blue",
166
+ neutral_hue="slate",
167
+ radius_size=gr.themes.sizes.radius_sm,
168
+ font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
169
+ )
170
+ css = ".generating {visibility: hidden}"
171
+
172
+ model_names = {
173
+ 'TripoSR': 'TripoSR',
174
+ 'Shap-E': 'Shap-E',
175
+ }
176
+ model_licenses = {
177
+ 'TripoSR': 'MIT License',
178
+ 'Shap-E': 'MIT License'
179
+ }
180
+ model_links = {
181
+ 'TripoSR': 'https://github.com/VAST-AI-Research/TripoSR',
182
+ 'Shap-E': 'https://github.com/openai/shap-e',
183
+ }
184
+
185
+ def model_license(name):
186
+ print(name)
187
+ for k, v in AVAILABLE_MODELS.items():
188
+ if k == name:
189
+ if v in model_licenses:
190
+ return model_licenses[v]
191
+ print('---')
192
+ return 'Unknown'
193
+ def get_leaderboard(reveal_prelim = False):
194
+ conn = get_db()
195
+ cursor = conn.cursor()
196
+ sql = 'SELECT name, upvote, downvote FROM model'
197
+ # if not reveal_prelim: sql += ' WHERE EXISTS (SELECT 1 FROM model WHERE (upvote + downvote) > 750)'
198
+ if not reveal_prelim: sql += ' WHERE (upvote + downvote) > 500'
199
+ cursor.execute(sql)
200
+ data = cursor.fetchall()
201
+ df = pd.DataFrame(data, columns=['name', 'upvote', 'downvote'])
202
+ # df['license'] = df['name'].map(model_license)
203
+ df['name'] = df['name'].replace(model_names)
204
+ df['votes'] = df['upvote'] + df['downvote']
205
+ # df['score'] = round((df['upvote'] / df['votes']) * 100, 2) # Percentage score
206
+
207
+ ## ELO SCORE
208
+ df['score'] = 1200
209
+ for i in range(len(df)):
210
+ for j in range(len(df)):
211
+ if i != j:
212
+ expected_a = 1 / (1 + 10 ** ((df['score'][j] - df['score'][i]) / 400))
213
+ expected_b = 1 / (1 + 10 ** ((df['score'][i] - df['score'][j]) / 400))
214
+ actual_a = df['upvote'][i] / df['votes'][i]
215
+ actual_b = df['upvote'][j] / df['votes'][j]
216
+ df.at[i, 'score'] += 32 * (actual_a - expected_a)
217
+ df.at[j, 'score'] += 32 * (actual_b - expected_b)
218
+ df['score'] = round(df['score'])
219
+ ## ELO SCORE
220
+ df = df.sort_values(by='score', ascending=False)
221
+ df['order'] = ['#' + str(i + 1) for i in range(len(df))]
222
+ # df = df[['name', 'score', 'upvote', 'votes']]
223
+ # df = df[['order', 'name', 'score', 'license', 'votes']]
224
+ df = df[['order', 'name', 'score', 'votes']]
225
+ return df
226
+
227
+ def mkuuid(uid):
228
+ if not uid:
229
+ uid = uuid.uuid4()
230
+ return uid
231
+
232
+ def upvote_model(model, uname):
233
+ conn = get_db()
234
+ cursor = conn.cursor()
235
+ cursor.execute('UPDATE model SET upvote = upvote + 1 WHERE name = ?', (model,))
236
+ if cursor.rowcount == 0:
237
+ cursor.execute('INSERT OR REPLACE INTO model (name, upvote, downvote) VALUES (?, 1, 0)', (model,))
238
+ cursor.execute('INSERT INTO vote (username, model, vote) VALUES (?, ?, ?)', (uname, model, 1,))
239
+ with scheduler.lock:
240
+ conn.commit()
241
+ cursor.close()
242
+
243
+ def downvote_model(model, uname):
244
+ conn = get_db()
245
+ cursor = conn.cursor()
246
+ cursor.execute('UPDATE model SET downvote = downvote + 1 WHERE name = ?', (model,))
247
+ if cursor.rowcount == 0:
248
+ cursor.execute('INSERT OR REPLACE INTO model (name, upvote, downvote) VALUES (?, 0, 1)', (model,))
249
+ cursor.execute('INSERT INTO vote (username, model, vote) VALUES (?, ?, ?)', (uname, model, -1,))
250
+ with scheduler.lock:
251
+ conn.commit()
252
+ cursor.close()
253
+
254
+ def a_is_better(model1, model2, userid):
255
+ userid = mkuuid(userid)
256
+ if model1 and model2:
257
+ conn = get_db()
258
+ cursor = conn.cursor()
259
+ cursor.execute('INSERT INTO votelog (username, chosen, rejected) VALUES (?, ?, ?)', (str(userid), model1, model2,))
260
+ with scheduler.lock:
261
+ conn.commit()
262
+ cursor.close()
263
+ upvote_model(model1, str(userid))
264
+ downvote_model(model2, str(userid))
265
+ return reload(model1, model2, userid, chose_a=True)
266
+ def b_is_better(model1, model2, userid):
267
+ userid = mkuuid(userid)
268
+ if model1 and model2:
269
+ conn = get_db()
270
+ cursor = conn.cursor()
271
+ cursor.execute('INSERT INTO votelog (username, chosen, rejected) VALUES (?, ?, ?)', (str(userid), model2, model1,))
272
+ with scheduler.lock:
273
+ conn.commit()
274
+ cursor.close()
275
+ upvote_model(model2, str(userid))
276
+ downvote_model(model1, str(userid))
277
+ return reload(model1, model2, userid, chose_b=True)
278
+ def both_bad(model1, model2, userid):
279
+ userid = mkuuid(userid)
280
+ if model1 and model2:
281
+ downvote_model(model1, str(userid))
282
+ downvote_model(model2, str(userid))
283
+ return reload(model1, model2, userid)
284
+ def both_good(model1, model2, userid):
285
+ userid = mkuuid(userid)
286
+ if model1 and model2:
287
+ upvote_model(model1, str(userid))
288
+ upvote_model(model2, str(userid))
289
+ return reload(model1, model2, userid)
290
+ def reload(chosenmodel1=None, chosenmodel2=None, userid=None, chose_a=False, chose_b=False):
291
+ out = [
292
+ gr.update(interactive=False, visible=False),
293
+ gr.update(interactive=False, visible=False)
294
+ ]
295
+ if chose_a == True:
296
+ out.append(gr.update(value=f'Your vote: {chosenmodel1}', interactive=False, visible=True))
297
+ out.append(gr.update(value=f'{chosenmodel2}', interactive=False, visible=True))
298
+ else:
299
+ out.append(gr.update(value=f'{chosenmodel1}', interactive=False, visible=True))
300
+ out.append(gr.update(value=f'Your vote: {chosenmodel2}', interactive=False, visible=True))
301
+ out.append(gr.update(visible=True))
302
+ return out
303
+
304
+ with gr.Blocks() as leaderboard:
305
+ gr.Markdown(LDESC)
306
+ # df = gr.Dataframe(interactive=False, value=get_leaderboard())
307
+ df = gr.Dataframe(interactive=False, min_width=0, wrap=True, column_widths=[30, 200, 50, 50])
308
+ with gr.Row():
309
+ reveal_prelim = gr.Checkbox(label="Reveal preliminary results", info="Show all models, including models with very few human ratings.", scale=1)
310
+ reloadbtn = gr.Button("Refresh", scale=3)
311
+ reveal_prelim.input(get_leaderboard, inputs=[reveal_prelim], outputs=[df])
312
+ leaderboard.load(get_leaderboard, inputs=[reveal_prelim], outputs=[df])
313
+ reloadbtn.click(get_leaderboard, inputs=[reveal_prelim], outputs=[df])
314
+
315
+ def synthandreturn(text):
316
+ text = text.strip()
317
+ if len(text) > MAX_SAMPLE_TXT_LENGTH:
318
+ raise gr.Error(f'You exceeded the limit of {MAX_SAMPLE_TXT_LENGTH} characters')
319
+ if len(text) < MIN_SAMPLE_TXT_LENGTH:
320
+ raise gr.Error(f'Please input a text longer than {MIN_SAMPLE_TXT_LENGTH} characters')
321
+ if (
322
+ # test toxicity
323
+ toxicity.predict(text)['toxicity'] > 0.8
324
+ ):
325
+ print(f'Detected toxic content! "{text}"')
326
+ raise gr.Error('Your text failed the toxicity test')
327
+ if not text:
328
+ raise gr.Error(f'You did not enter any text')
329
+ # Check language
330
+ try:
331
+ if not detect(text) == "en":
332
+ gr.Warning('Warning: The input text may not be in English')
333
+ except:
334
+ pass
335
+ # Get two random models
336
+ mdl1, mdl2 = random.sample(list(AVAILABLE_MODELS.keys()), 2)
337
+ log_text(text)
338
+ print("[debug] Using", mdl1, mdl2)
339
+ def predict_and_update_result(text, model, result_storage):
340
+ try:
341
+ if model in AVAILABLE_MODELS:
342
+ result = router.predict(text, AVAILABLE_MODELS[model].lower(), api_name="/synthesize")
343
+ else:
344
+ result = router.predict(text, model.lower(), api_name="/synthesize")
345
+ except:
346
+ raise gr.Error('Unable to call API, please try again :)')
347
+ print('Done with', model)
348
+ try:
349
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
350
+ audio = AudioSegment.from_file(result)
351
+ current_sr = audio.frame_rate
352
+ if current_sr > 24000:
353
+ audio = audio.set_frame_rate(24000)
354
+ try:
355
+ print('Trying to normalize audio')
356
+ audio = match_target_amplitude(audio, -20)
357
+ except:
358
+ print('[WARN] Unable to normalize audio')
359
+ audio.export(f.name, format="wav")
360
+ os.unlink(result)
361
+ result = f.name
362
+ except:
363
+ pass
364
+
365
+ result_storage[model] = result
366
+
367
+ results = {}
368
+ thread1 = threading.Thread(target=predict_and_update_result, args=(text, mdl1, results))
369
+ thread2 = threading.Thread(target=predict_and_update_result, args=(text, mdl2, results))
370
+ thread1.start()
371
+ thread2.start()
372
+ thread1.join()
373
+ thread2.join()
374
+
375
+ return (
376
+ text,
377
+ "Synthesize",
378
+ gr.update(visible=True), # r2
379
+ mdl1, # model1
380
+ mdl2, # model2
381
+ gr.update(visible=True, value=results[mdl1]), # aud1
382
+ gr.update(visible=True, value=results[mdl2]), # aud2
383
+ gr.update(visible=True, interactive=True),
384
+ gr.update(visible=True, interactive=True),
385
+ gr.update(visible=False),
386
+ gr.update(visible=False),
387
+ gr.update(visible=False), #nxt round btn
388
+ )
389
+
390
+ def clear_stuff():
391
+ return "", "Synthesize", gr.update(visible=False), '', '', gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
392
+
393
+ with gr.Blocks() as vote:
394
+ useridstate = gr.State()
395
+ gr.Markdown(INSTR)
396
+ with gr.Group():
397
+ with gr.Row():
398
+ text = gr.Textbox(container=False, show_label=False, placeholder="Enter text to synthesize", lines=1, max_lines=1, scale=9999999, min_width=0)
399
+ btn = gr.Button("Synthesize", variant='primary')
400
+ model1 = gr.Textbox(interactive=False, lines=1, max_lines=1, visible=False)
401
+ model2 = gr.Textbox(interactive=False, lines=1, max_lines=1, visible=False)
402
+ with gr.Row(visible=False) as r2:
403
+ with gr.Column():
404
+ with gr.Group():
405
+ aud1 = gr.Audio(interactive=False, show_label=False, show_download_button=False, show_share_button=False, waveform_options={'waveform_progress_color': '#3C82F6'})
406
+ abetter = gr.Button("A is better", variant='primary')
407
+ prevmodel1 = gr.Textbox(interactive=False, show_label=False, container=False, value="Vote to reveal model A", text_align="center", lines=1, max_lines=1, visible=False)
408
+ with gr.Column():
409
+ with gr.Group():
410
+ aud2 = gr.Audio(interactive=False, show_label=False, show_download_button=False, show_share_button=False, waveform_options={'waveform_progress_color': '#3C82F6'})
411
+ bbetter = gr.Button("B is better", variant='primary')
412
+ prevmodel2 = gr.Textbox(interactive=False, show_label=False, container=False, value="Vote to reveal model B", text_align="center", lines=1, max_lines=1, visible=False)
413
+ nxtroundbtn = gr.Button('Next round', visible=False)
414
+ outputs = [text, btn, r2, model1, model2, aud1, aud2, abetter, bbetter, prevmodel1, prevmodel2, nxtroundbtn]
415
+ btn.click(synthandreturn, inputs=[text], outputs=outputs)
416
+ nxtroundbtn.click(clear_stuff, outputs=outputs)
417
+
418
+
419
+ nxt_outputs = [abetter, bbetter, prevmodel1, prevmodel2, nxtroundbtn]
420
+ abetter.click(a_is_better, outputs=nxt_outputs, inputs=[model1, model2, useridstate])
421
+ bbetter.click(b_is_better, outputs=nxt_outputs, inputs=[model1, model2, useridstate])
422
+
423
+ with gr.Blocks() as about:
424
+ gr.Markdown(ABOUT)
425
+ with gr.Blocks(theme=theme, css=css, title="3D Arena") as demo:
426
+ gr.Markdown(DESCR)
427
+ gr.TabbedInterface([vote, leaderboard, about], ['🗳️ Vote', '🏆 Leaderboard', '📄 About'])
428
+
429
+ demo.queue(api_open=False, default_concurrency_limit=40).launch(show_api=False)