Spaces:
Runtime error
Runtime error
RamAnanth1
commited on
Initial commit
Browse filesMinor modifications to TTS-arena code
app.py
ADDED
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# An arena for 3D generations with code inspired from TTS arena
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
import pandas as pd
|
5 |
+
from langdetect import detect
|
6 |
+
from datasets import load_dataset
|
7 |
+
import threading, time, uuid, sqlite3, shutil, os, random, asyncio, threading
|
8 |
+
from pathlib import Path
|
9 |
+
from huggingface_hub import CommitScheduler, delete_file, hf_hub_download
|
10 |
+
from gradio_client import Client
|
11 |
+
from detoxify import Detoxify
|
12 |
+
import os
|
13 |
+
import tempfile
|
14 |
+
|
15 |
+
toxicity = Detoxify('original')
|
16 |
+
|
17 |
+
####################################
|
18 |
+
# Constants
|
19 |
+
####################################
|
20 |
+
AVAILABLE_MODELS = {
|
21 |
+
'TripoSR': 'TripSR',
|
22 |
+
'Shape-E': 'shap-e',
|
23 |
+
}
|
24 |
+
|
25 |
+
SPACE_ID = os.getenv('SPACE_ID')
|
26 |
+
MAX_SAMPLE_TXT_LENGTH = 300
|
27 |
+
MIN_SAMPLE_TXT_LENGTH = 10
|
28 |
+
DB_DATASET_ID = os.getenv('DATASET_ID')
|
29 |
+
DB_NAME = "database.db"
|
30 |
+
|
31 |
+
# If /data available => means local storage is enabled => let's use it!
|
32 |
+
DB_PATH = f"/data/{DB_NAME}" if os.path.isdir("/data") else DB_NAME
|
33 |
+
print(f"Using {DB_PATH}")
|
34 |
+
|
35 |
+
####################################
|
36 |
+
# Functions
|
37 |
+
####################################
|
38 |
+
|
39 |
+
def create_db_if_missing():
|
40 |
+
conn = get_db()
|
41 |
+
cursor = conn.cursor()
|
42 |
+
cursor.execute('''
|
43 |
+
CREATE TABLE IF NOT EXISTS model (
|
44 |
+
name TEXT UNIQUE,
|
45 |
+
upvote INTEGER,
|
46 |
+
downvote INTEGER
|
47 |
+
);
|
48 |
+
''')
|
49 |
+
cursor.execute('''
|
50 |
+
CREATE TABLE IF NOT EXISTS vote (
|
51 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
52 |
+
username TEXT,
|
53 |
+
model TEXT,
|
54 |
+
vote INTEGER,
|
55 |
+
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
56 |
+
);
|
57 |
+
''')
|
58 |
+
cursor.execute('''
|
59 |
+
CREATE TABLE IF NOT EXISTS votelog (
|
60 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
61 |
+
username TEXT,
|
62 |
+
chosen TEXT,
|
63 |
+
rejected TEXT,
|
64 |
+
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
65 |
+
);
|
66 |
+
''')
|
67 |
+
|
68 |
+
def get_db():
|
69 |
+
return sqlite3.connect(DB_PATH)
|
70 |
+
|
71 |
+
####################################
|
72 |
+
# Space initialization
|
73 |
+
####################################
|
74 |
+
|
75 |
+
# Download existing DB
|
76 |
+
if not os.path.isfile(DB_PATH):
|
77 |
+
print("Downloading DB...")
|
78 |
+
try:
|
79 |
+
cache_path = hf_hub_download(repo_id=DB_DATASET_ID, repo_type='dataset', filename=DB_NAME)
|
80 |
+
shutil.copyfile(cache_path, DB_PATH)
|
81 |
+
print("Downloaded DB")
|
82 |
+
except Exception as e:
|
83 |
+
print("Error while downloading DB:", e)
|
84 |
+
|
85 |
+
# Create DB table (if doesn't exist)
|
86 |
+
create_db_if_missing()
|
87 |
+
|
88 |
+
# Sync local DB with remote repo every 5 minute (only if a change is detected)
|
89 |
+
scheduler = CommitScheduler(
|
90 |
+
repo_id=DB_DATASET_ID,
|
91 |
+
repo_type="dataset",
|
92 |
+
folder_path=Path(DB_PATH).parent,
|
93 |
+
every=5,
|
94 |
+
allow_patterns=DB_NAME,
|
95 |
+
)
|
96 |
+
|
97 |
+
####################################
|
98 |
+
# Router API
|
99 |
+
####################################
|
100 |
+
router = Client("RamAnanth1/3D-Arena-Router", hf_token=os.getenv('HF_TOKEN'))
|
101 |
+
####################################
|
102 |
+
# Gradio app
|
103 |
+
####################################
|
104 |
+
MUST_BE_LOGGEDIN = "Please login with Hugging Face to participate in the 3D Arena."
|
105 |
+
DESCR = """
|
106 |
+
# ⚔️3D Arena: Benchmarking Image-to-3D models
|
107 |
+
|
108 |
+
Vote to help the community find the best Image-to-3D model!
|
109 |
+
""".strip()
|
110 |
+
|
111 |
+
INSTR = """
|
112 |
+
## 🗳️ Vote
|
113 |
+
|
114 |
+
* Input image to generate a 3D reconstruction.
|
115 |
+
* View the responses of the models, one after the other.
|
116 |
+
* Vote on which model made a better reconstruction.
|
117 |
+
* _Note: Model names are revealed after the vote is cast._
|
118 |
+
|
119 |
+
Note: It may take up to 60 seconds to get a response.
|
120 |
+
""".strip()
|
121 |
+
request = ''
|
122 |
+
if SPACE_ID:
|
123 |
+
request = f"""
|
124 |
+
### Request a model
|
125 |
+
|
126 |
+
Please [create a Discussion](https://huggingface.co/spaces/{SPACE_ID}/discussions/new) to request a model.
|
127 |
+
"""
|
128 |
+
ABOUT = f"""
|
129 |
+
## 📄 About
|
130 |
+
|
131 |
+
The 3D Arena evaluates leading 3D generation model. It is inspired by LMsys's [Chatbot Arena](https://chat.lmsys.org/) and [TTS-Arena](https://huggingface.co/spaces/TTS-AGI/TTS-Arena).
|
132 |
+
|
133 |
+
### The Arena
|
134 |
+
|
135 |
+
The leaderboard allows a user to input an image, for which a 3D reconstruction be synthesized by two models. After viewing each sample, the user can vote on which model works better. Due to the risks of human bias and abuse, model names are revealed only after a vote is submitted.
|
136 |
+
|
137 |
+
{request}
|
138 |
+
|
139 |
+
|
140 |
+
""".strip()
|
141 |
+
LDESC = """
|
142 |
+
## 🏆 Leaderboard
|
143 |
+
|
144 |
+
Vote to help the community find the best Image-to-3D model!
|
145 |
+
|
146 |
+
The leaderboard displays models in descending order of how suitable the models are (based on votes cast by the community).
|
147 |
+
|
148 |
+
Important: In order to help keep results fair, the leaderboard hides results by default until the number of votes passes a threshold. Tick the `Reveal preliminary results` to show models without sufficient votes. Please note that preliminary results may be inaccurate.
|
149 |
+
""".strip()
|
150 |
+
|
151 |
+
def del_db(txt):
|
152 |
+
if not txt.lower() == 'delete db':
|
153 |
+
raise gr.Error('You did not enter "delete db"')
|
154 |
+
|
155 |
+
# Delete local + remote
|
156 |
+
os.remove(DB_PATH)
|
157 |
+
delete_file(path_in_repo=DB_NAME, repo_id=DB_DATASET_ID, repo_type='dataset')
|
158 |
+
|
159 |
+
# Recreate
|
160 |
+
create_db_if_missing()
|
161 |
+
return 'Delete DB'
|
162 |
+
|
163 |
+
theme = gr.themes.Monochrome(
|
164 |
+
primary_hue="indigo",
|
165 |
+
secondary_hue="blue",
|
166 |
+
neutral_hue="slate",
|
167 |
+
radius_size=gr.themes.sizes.radius_sm,
|
168 |
+
font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
|
169 |
+
)
|
170 |
+
css = ".generating {visibility: hidden}"
|
171 |
+
|
172 |
+
model_names = {
|
173 |
+
'TripoSR': 'TripoSR',
|
174 |
+
'Shap-E': 'Shap-E',
|
175 |
+
}
|
176 |
+
model_licenses = {
|
177 |
+
'TripoSR': 'MIT License',
|
178 |
+
'Shap-E': 'MIT License'
|
179 |
+
}
|
180 |
+
model_links = {
|
181 |
+
'TripoSR': 'https://github.com/VAST-AI-Research/TripoSR',
|
182 |
+
'Shap-E': 'https://github.com/openai/shap-e',
|
183 |
+
}
|
184 |
+
|
185 |
+
def model_license(name):
|
186 |
+
print(name)
|
187 |
+
for k, v in AVAILABLE_MODELS.items():
|
188 |
+
if k == name:
|
189 |
+
if v in model_licenses:
|
190 |
+
return model_licenses[v]
|
191 |
+
print('---')
|
192 |
+
return 'Unknown'
|
193 |
+
def get_leaderboard(reveal_prelim = False):
|
194 |
+
conn = get_db()
|
195 |
+
cursor = conn.cursor()
|
196 |
+
sql = 'SELECT name, upvote, downvote FROM model'
|
197 |
+
# if not reveal_prelim: sql += ' WHERE EXISTS (SELECT 1 FROM model WHERE (upvote + downvote) > 750)'
|
198 |
+
if not reveal_prelim: sql += ' WHERE (upvote + downvote) > 500'
|
199 |
+
cursor.execute(sql)
|
200 |
+
data = cursor.fetchall()
|
201 |
+
df = pd.DataFrame(data, columns=['name', 'upvote', 'downvote'])
|
202 |
+
# df['license'] = df['name'].map(model_license)
|
203 |
+
df['name'] = df['name'].replace(model_names)
|
204 |
+
df['votes'] = df['upvote'] + df['downvote']
|
205 |
+
# df['score'] = round((df['upvote'] / df['votes']) * 100, 2) # Percentage score
|
206 |
+
|
207 |
+
## ELO SCORE
|
208 |
+
df['score'] = 1200
|
209 |
+
for i in range(len(df)):
|
210 |
+
for j in range(len(df)):
|
211 |
+
if i != j:
|
212 |
+
expected_a = 1 / (1 + 10 ** ((df['score'][j] - df['score'][i]) / 400))
|
213 |
+
expected_b = 1 / (1 + 10 ** ((df['score'][i] - df['score'][j]) / 400))
|
214 |
+
actual_a = df['upvote'][i] / df['votes'][i]
|
215 |
+
actual_b = df['upvote'][j] / df['votes'][j]
|
216 |
+
df.at[i, 'score'] += 32 * (actual_a - expected_a)
|
217 |
+
df.at[j, 'score'] += 32 * (actual_b - expected_b)
|
218 |
+
df['score'] = round(df['score'])
|
219 |
+
## ELO SCORE
|
220 |
+
df = df.sort_values(by='score', ascending=False)
|
221 |
+
df['order'] = ['#' + str(i + 1) for i in range(len(df))]
|
222 |
+
# df = df[['name', 'score', 'upvote', 'votes']]
|
223 |
+
# df = df[['order', 'name', 'score', 'license', 'votes']]
|
224 |
+
df = df[['order', 'name', 'score', 'votes']]
|
225 |
+
return df
|
226 |
+
|
227 |
+
def mkuuid(uid):
|
228 |
+
if not uid:
|
229 |
+
uid = uuid.uuid4()
|
230 |
+
return uid
|
231 |
+
|
232 |
+
def upvote_model(model, uname):
|
233 |
+
conn = get_db()
|
234 |
+
cursor = conn.cursor()
|
235 |
+
cursor.execute('UPDATE model SET upvote = upvote + 1 WHERE name = ?', (model,))
|
236 |
+
if cursor.rowcount == 0:
|
237 |
+
cursor.execute('INSERT OR REPLACE INTO model (name, upvote, downvote) VALUES (?, 1, 0)', (model,))
|
238 |
+
cursor.execute('INSERT INTO vote (username, model, vote) VALUES (?, ?, ?)', (uname, model, 1,))
|
239 |
+
with scheduler.lock:
|
240 |
+
conn.commit()
|
241 |
+
cursor.close()
|
242 |
+
|
243 |
+
def downvote_model(model, uname):
|
244 |
+
conn = get_db()
|
245 |
+
cursor = conn.cursor()
|
246 |
+
cursor.execute('UPDATE model SET downvote = downvote + 1 WHERE name = ?', (model,))
|
247 |
+
if cursor.rowcount == 0:
|
248 |
+
cursor.execute('INSERT OR REPLACE INTO model (name, upvote, downvote) VALUES (?, 0, 1)', (model,))
|
249 |
+
cursor.execute('INSERT INTO vote (username, model, vote) VALUES (?, ?, ?)', (uname, model, -1,))
|
250 |
+
with scheduler.lock:
|
251 |
+
conn.commit()
|
252 |
+
cursor.close()
|
253 |
+
|
254 |
+
def a_is_better(model1, model2, userid):
|
255 |
+
userid = mkuuid(userid)
|
256 |
+
if model1 and model2:
|
257 |
+
conn = get_db()
|
258 |
+
cursor = conn.cursor()
|
259 |
+
cursor.execute('INSERT INTO votelog (username, chosen, rejected) VALUES (?, ?, ?)', (str(userid), model1, model2,))
|
260 |
+
with scheduler.lock:
|
261 |
+
conn.commit()
|
262 |
+
cursor.close()
|
263 |
+
upvote_model(model1, str(userid))
|
264 |
+
downvote_model(model2, str(userid))
|
265 |
+
return reload(model1, model2, userid, chose_a=True)
|
266 |
+
def b_is_better(model1, model2, userid):
|
267 |
+
userid = mkuuid(userid)
|
268 |
+
if model1 and model2:
|
269 |
+
conn = get_db()
|
270 |
+
cursor = conn.cursor()
|
271 |
+
cursor.execute('INSERT INTO votelog (username, chosen, rejected) VALUES (?, ?, ?)', (str(userid), model2, model1,))
|
272 |
+
with scheduler.lock:
|
273 |
+
conn.commit()
|
274 |
+
cursor.close()
|
275 |
+
upvote_model(model2, str(userid))
|
276 |
+
downvote_model(model1, str(userid))
|
277 |
+
return reload(model1, model2, userid, chose_b=True)
|
278 |
+
def both_bad(model1, model2, userid):
|
279 |
+
userid = mkuuid(userid)
|
280 |
+
if model1 and model2:
|
281 |
+
downvote_model(model1, str(userid))
|
282 |
+
downvote_model(model2, str(userid))
|
283 |
+
return reload(model1, model2, userid)
|
284 |
+
def both_good(model1, model2, userid):
|
285 |
+
userid = mkuuid(userid)
|
286 |
+
if model1 and model2:
|
287 |
+
upvote_model(model1, str(userid))
|
288 |
+
upvote_model(model2, str(userid))
|
289 |
+
return reload(model1, model2, userid)
|
290 |
+
def reload(chosenmodel1=None, chosenmodel2=None, userid=None, chose_a=False, chose_b=False):
|
291 |
+
out = [
|
292 |
+
gr.update(interactive=False, visible=False),
|
293 |
+
gr.update(interactive=False, visible=False)
|
294 |
+
]
|
295 |
+
if chose_a == True:
|
296 |
+
out.append(gr.update(value=f'Your vote: {chosenmodel1}', interactive=False, visible=True))
|
297 |
+
out.append(gr.update(value=f'{chosenmodel2}', interactive=False, visible=True))
|
298 |
+
else:
|
299 |
+
out.append(gr.update(value=f'{chosenmodel1}', interactive=False, visible=True))
|
300 |
+
out.append(gr.update(value=f'Your vote: {chosenmodel2}', interactive=False, visible=True))
|
301 |
+
out.append(gr.update(visible=True))
|
302 |
+
return out
|
303 |
+
|
304 |
+
with gr.Blocks() as leaderboard:
|
305 |
+
gr.Markdown(LDESC)
|
306 |
+
# df = gr.Dataframe(interactive=False, value=get_leaderboard())
|
307 |
+
df = gr.Dataframe(interactive=False, min_width=0, wrap=True, column_widths=[30, 200, 50, 50])
|
308 |
+
with gr.Row():
|
309 |
+
reveal_prelim = gr.Checkbox(label="Reveal preliminary results", info="Show all models, including models with very few human ratings.", scale=1)
|
310 |
+
reloadbtn = gr.Button("Refresh", scale=3)
|
311 |
+
reveal_prelim.input(get_leaderboard, inputs=[reveal_prelim], outputs=[df])
|
312 |
+
leaderboard.load(get_leaderboard, inputs=[reveal_prelim], outputs=[df])
|
313 |
+
reloadbtn.click(get_leaderboard, inputs=[reveal_prelim], outputs=[df])
|
314 |
+
|
315 |
+
def synthandreturn(text):
|
316 |
+
text = text.strip()
|
317 |
+
if len(text) > MAX_SAMPLE_TXT_LENGTH:
|
318 |
+
raise gr.Error(f'You exceeded the limit of {MAX_SAMPLE_TXT_LENGTH} characters')
|
319 |
+
if len(text) < MIN_SAMPLE_TXT_LENGTH:
|
320 |
+
raise gr.Error(f'Please input a text longer than {MIN_SAMPLE_TXT_LENGTH} characters')
|
321 |
+
if (
|
322 |
+
# test toxicity
|
323 |
+
toxicity.predict(text)['toxicity'] > 0.8
|
324 |
+
):
|
325 |
+
print(f'Detected toxic content! "{text}"')
|
326 |
+
raise gr.Error('Your text failed the toxicity test')
|
327 |
+
if not text:
|
328 |
+
raise gr.Error(f'You did not enter any text')
|
329 |
+
# Check language
|
330 |
+
try:
|
331 |
+
if not detect(text) == "en":
|
332 |
+
gr.Warning('Warning: The input text may not be in English')
|
333 |
+
except:
|
334 |
+
pass
|
335 |
+
# Get two random models
|
336 |
+
mdl1, mdl2 = random.sample(list(AVAILABLE_MODELS.keys()), 2)
|
337 |
+
log_text(text)
|
338 |
+
print("[debug] Using", mdl1, mdl2)
|
339 |
+
def predict_and_update_result(text, model, result_storage):
|
340 |
+
try:
|
341 |
+
if model in AVAILABLE_MODELS:
|
342 |
+
result = router.predict(text, AVAILABLE_MODELS[model].lower(), api_name="/synthesize")
|
343 |
+
else:
|
344 |
+
result = router.predict(text, model.lower(), api_name="/synthesize")
|
345 |
+
except:
|
346 |
+
raise gr.Error('Unable to call API, please try again :)')
|
347 |
+
print('Done with', model)
|
348 |
+
try:
|
349 |
+
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
|
350 |
+
audio = AudioSegment.from_file(result)
|
351 |
+
current_sr = audio.frame_rate
|
352 |
+
if current_sr > 24000:
|
353 |
+
audio = audio.set_frame_rate(24000)
|
354 |
+
try:
|
355 |
+
print('Trying to normalize audio')
|
356 |
+
audio = match_target_amplitude(audio, -20)
|
357 |
+
except:
|
358 |
+
print('[WARN] Unable to normalize audio')
|
359 |
+
audio.export(f.name, format="wav")
|
360 |
+
os.unlink(result)
|
361 |
+
result = f.name
|
362 |
+
except:
|
363 |
+
pass
|
364 |
+
|
365 |
+
result_storage[model] = result
|
366 |
+
|
367 |
+
results = {}
|
368 |
+
thread1 = threading.Thread(target=predict_and_update_result, args=(text, mdl1, results))
|
369 |
+
thread2 = threading.Thread(target=predict_and_update_result, args=(text, mdl2, results))
|
370 |
+
thread1.start()
|
371 |
+
thread2.start()
|
372 |
+
thread1.join()
|
373 |
+
thread2.join()
|
374 |
+
|
375 |
+
return (
|
376 |
+
text,
|
377 |
+
"Synthesize",
|
378 |
+
gr.update(visible=True), # r2
|
379 |
+
mdl1, # model1
|
380 |
+
mdl2, # model2
|
381 |
+
gr.update(visible=True, value=results[mdl1]), # aud1
|
382 |
+
gr.update(visible=True, value=results[mdl2]), # aud2
|
383 |
+
gr.update(visible=True, interactive=True),
|
384 |
+
gr.update(visible=True, interactive=True),
|
385 |
+
gr.update(visible=False),
|
386 |
+
gr.update(visible=False),
|
387 |
+
gr.update(visible=False), #nxt round btn
|
388 |
+
)
|
389 |
+
|
390 |
+
def clear_stuff():
|
391 |
+
return "", "Synthesize", gr.update(visible=False), '', '', gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
|
392 |
+
|
393 |
+
with gr.Blocks() as vote:
|
394 |
+
useridstate = gr.State()
|
395 |
+
gr.Markdown(INSTR)
|
396 |
+
with gr.Group():
|
397 |
+
with gr.Row():
|
398 |
+
text = gr.Textbox(container=False, show_label=False, placeholder="Enter text to synthesize", lines=1, max_lines=1, scale=9999999, min_width=0)
|
399 |
+
btn = gr.Button("Synthesize", variant='primary')
|
400 |
+
model1 = gr.Textbox(interactive=False, lines=1, max_lines=1, visible=False)
|
401 |
+
model2 = gr.Textbox(interactive=False, lines=1, max_lines=1, visible=False)
|
402 |
+
with gr.Row(visible=False) as r2:
|
403 |
+
with gr.Column():
|
404 |
+
with gr.Group():
|
405 |
+
aud1 = gr.Audio(interactive=False, show_label=False, show_download_button=False, show_share_button=False, waveform_options={'waveform_progress_color': '#3C82F6'})
|
406 |
+
abetter = gr.Button("A is better", variant='primary')
|
407 |
+
prevmodel1 = gr.Textbox(interactive=False, show_label=False, container=False, value="Vote to reveal model A", text_align="center", lines=1, max_lines=1, visible=False)
|
408 |
+
with gr.Column():
|
409 |
+
with gr.Group():
|
410 |
+
aud2 = gr.Audio(interactive=False, show_label=False, show_download_button=False, show_share_button=False, waveform_options={'waveform_progress_color': '#3C82F6'})
|
411 |
+
bbetter = gr.Button("B is better", variant='primary')
|
412 |
+
prevmodel2 = gr.Textbox(interactive=False, show_label=False, container=False, value="Vote to reveal model B", text_align="center", lines=1, max_lines=1, visible=False)
|
413 |
+
nxtroundbtn = gr.Button('Next round', visible=False)
|
414 |
+
outputs = [text, btn, r2, model1, model2, aud1, aud2, abetter, bbetter, prevmodel1, prevmodel2, nxtroundbtn]
|
415 |
+
btn.click(synthandreturn, inputs=[text], outputs=outputs)
|
416 |
+
nxtroundbtn.click(clear_stuff, outputs=outputs)
|
417 |
+
|
418 |
+
|
419 |
+
nxt_outputs = [abetter, bbetter, prevmodel1, prevmodel2, nxtroundbtn]
|
420 |
+
abetter.click(a_is_better, outputs=nxt_outputs, inputs=[model1, model2, useridstate])
|
421 |
+
bbetter.click(b_is_better, outputs=nxt_outputs, inputs=[model1, model2, useridstate])
|
422 |
+
|
423 |
+
with gr.Blocks() as about:
|
424 |
+
gr.Markdown(ABOUT)
|
425 |
+
with gr.Blocks(theme=theme, css=css, title="3D Arena") as demo:
|
426 |
+
gr.Markdown(DESCR)
|
427 |
+
gr.TabbedInterface([vote, leaderboard, about], ['🗳️ Vote', '🏆 Leaderboard', '📄 About'])
|
428 |
+
|
429 |
+
demo.queue(api_open=False, default_concurrency_limit=40).launch(show_api=False)
|