taskswithcode committed on
Commit
e227e49
·
1 Parent(s): 57eed52
Files changed (2) hide show
  1. app.py +44 -22
  2. doc_app_models.json +5 -5
app.py CHANGED
@@ -34,13 +34,11 @@ INFO_URL = "http://www.taskswithcode.com/stats/"
34
 
35
 
36
  def get_views(action):
37
- print("in get views",action)
38
  ret_val = 0
39
  hostname = socket.gethostname()
40
  ip_address = socket.gethostbyname(hostname)
41
  if ("view_count" not in st.session_state):
42
  try:
43
- print("inside get views")
44
  app_info = {'name': APP_NAME,"action":action,"host":hostname,"ip":ip_address}
45
  res = requests.post(INFO_URL, json = app_info).json()
46
  print(res)
@@ -61,7 +59,8 @@ def get_views(action):
61
 
62
  def construct_model_info_for_display(model_names):
63
  options_arr = []
64
- markdown_str = f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Models evaluated ({len(model_names)})</b></div>"
 
65
  for node in model_names:
66
  options_arr .append(node["name"])
67
  if (node["mark"] == "True"):
@@ -88,20 +87,19 @@ with col:
88
 
89
 
90
  @st.experimental_memo
91
- def load_model(model_name,model_names):
92
  try:
93
  ret_model = None
94
- for node in model_names:
95
- if (model_name.startswith(node["name"])):
96
- obj_class = globals()[node["class"]]
97
- ret_model = obj_class()
98
- ret_model.init_model(node["model"])
99
  assert(ret_model is not None)
100
  except Exception as e:
101
- st.error("Unable to load model:" + model_name + " " + str(e))
102
  pass
103
  return ret_model
104
 
 
105
 
106
  @st.experimental_memo
107
  def cached_compute_similarity(sentences,_model,model_name,main_index):
@@ -117,18 +115,27 @@ def uncached_compute_similarity(sentences,_model,model_name,main_index):
117
  #st.success("Similarity computation complete")
118
  return results
119
 
 
120
  def get_model_info(model_names,model_name):
121
  for node in model_names:
122
  if (model_name == node["name"]):
123
- return node
 
 
124
 
125
- def run_test(model_names,model_name,sentences,display_area,main_index,user_uploaded):
126
  display_area.text("Loading model:" + model_name)
127
- model_info = get_model_info(model_names,model_name)
 
 
 
 
 
 
128
  if ("Note" in model_info):
129
  fail_link = f"{model_info['Note']} [link]({model_info['alt_url']})"
130
  display_area.write(fail_link)
131
- model = load_model(model_name,model_names)
132
  display_area.text("Model " + model_name + " load complete")
133
  try:
134
  if (user_uploaded):
@@ -148,9 +155,10 @@ def run_test(model_names,model_name,sentences,display_area,main_index,user_uploa
148
 
149
 
150
 
151
- def display_results(orig_sentences,main_index,results,response_info,app_mode):
152
  main_sent = f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">{response_info}<br/><br/></div>"
153
- score_text = "cosine_distance" if app_mode == SEM_SIMILARITY else "cosine_distance/score"
 
154
  pivot_name = "main sentence" if app_mode == SEM_SIMILARITY else "query"
155
  main_sent += f"<div style=\"font-size:14px; color: #6f6f6f; text-align: left\">Results sorted by {score_text}. Closest to furthest away from {pivot_name}</div>"
156
  pivot_name = pivot_name[0].upper() + pivot_name[1:]
@@ -172,10 +180,14 @@ def display_results(orig_sentences,main_index,results,response_info,app_mode):
172
 
173
 
174
  def init_session():
175
- st.session_state["download_ready"] = None
176
- st.session_state["model_name"] = "ss_test"
177
- st.session_state["main_index"] = 1
178
- st.session_state["file_name"] = "default"
 
 
 
 
179
 
180
  def app_main(app_mode,example_files,model_name_files):
181
  init_session()
@@ -185,6 +197,7 @@ def app_main(app_mode,example_files,model_name_files):
185
  model_names = json.load(fp)
186
  curr_use_case = use_case[app_mode].split(".")[0]
187
  st.markdown("<h5 style='text-align: center;'>Compare popular/state-of-the-art models for tasks using sentence embeddings</h5>", unsafe_allow_html=True)
 
188
  st.markdown(f"<div style='color: #4f4f4f; text-align: left'>Use cases for sentence embeddings<br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;<a href=\'{use_case_url['1']}\' target='_blank'>{use_case['1']}</a><br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;{use_case['2']}<br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;{use_case['3']}<br/><i>This app illustrates <b>'{curr_use_case}'</b> use case</i></div>", unsafe_allow_html=True)
189
  st.markdown(f"<div style='color: #9f9f9f; text-align: right'>views:&nbsp;{get_views('init')}</div>", unsafe_allow_html=True)
190
 
@@ -207,6 +220,9 @@ def app_main(app_mode,example_files,model_name_files):
207
  selected_model = st.selectbox(label=selection_label,
208
  options = options_arr, index=0, key = "twc_model")
209
  st.write("")
 
 
 
210
  if (app_mode == SEM_SIMILARITY):
211
  main_index = st.number_input('Step 3. Enter index of sentence in file to make it the main sentence',value=1,min_value = 1)
212
  else:
@@ -232,14 +248,20 @@ def app_main(app_mode,example_files,model_name_files):
232
  if (len(sentences) > MAX_INPUT):
233
  st.info(f"Input sentence count exceeds maximum sentence limit. First {MAX_INPUT} out of {len(sentences)} sentences chosen")
234
  sentences = sentences[:MAX_INPUT]
 
 
 
 
235
  st.session_state["model_name"] = selected_model
236
  st.session_state["main_index"] = main_index
237
- results = run_test(model_names,selected_model,sentences,display_area,main_index - 1,(uploaded_file is not None))
238
  display_area.empty()
239
  with display_area.container():
240
  device = 'GPU' if torch.cuda.is_available() else 'CPU'
241
  response_info = f"Computation time on {device}: {time.time() - start:.2f} secs for {len(sentences)} sentences"
242
- display_results(sentences,main_index - 1,results,response_info,app_mode)
 
 
243
  #st.json(results)
244
  st.download_button(
245
  label="Download results as json",
 
34
 
35
 
36
  def get_views(action):
 
37
  ret_val = 0
38
  hostname = socket.gethostname()
39
  ip_address = socket.gethostbyname(hostname)
40
  if ("view_count" not in st.session_state):
41
  try:
 
42
  app_info = {'name': APP_NAME,"action":action,"host":hostname,"ip":ip_address}
43
  res = requests.post(INFO_URL, json = app_info).json()
44
  print(res)
 
59
 
60
  def construct_model_info_for_display(model_names):
61
  options_arr = []
62
+ markdown_str = f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Models evaluated ({len(model_names)})</b><br/><i>These are either state-of-the-art or the most downloaded models on Huggingface</i></div>"
63
+ markdown_str += f"<div style=\"font-size:2px; color: #2f2f2f; text-align: left\"><br/></div>"
64
  for node in model_names:
65
  options_arr .append(node["name"])
66
  if (node["mark"] == "True"):
 
87
 
88
 
@st.experimental_memo
def load_model(model_name,model_class,load_model_name):
    """Instantiate and initialise the model wrapper class named by `model_class`.

    Args:
        model_name: Display name of the model; used only in the error message.
        model_class: Name of a class, resolved through this module's globals().
        load_model_name: Identifier handed to the wrapper's init_model().

    Returns:
        The initialised wrapper instance, or None when the class lookup,
        construction or init_model() call fails. Failures are reported to the
        page with st.error rather than raised.
    """
    ret_model = None
    try:
        candidate = globals()[model_class]()
        candidate.init_model(load_model_name)
        # Publish the instance only after init_model() succeeded, so a
        # half-initialised object is never handed back to the caller.
        ret_model = candidate
    except Exception as e:
        # Deliberately broad: any load failure is surfaced in the UI and the
        # caller receives None instead of an exception.
        st.error("Unable to load model:" + model_name + " " + load_model_name + " " + str(e))
    return ret_model
101
 
102
+
103
 
104
  @st.experimental_memo
105
  def cached_compute_similarity(sentences,_model,model_name,main_index):
 
115
  #st.success("Similarity computation complete")
116
  return results
117
 
118
DEFAULT_HF_MODEL = "sentence-transformers/paraphrase-MiniLM-L6-v2"


def get_model_info(model_names,model_name):
    """Find the entry for `model_name`, falling back to DEFAULT_HF_MODEL.

    Args:
        model_names: Iterable of dicts, each carrying a "name" key.
        model_name: Name to look up.

    Returns:
        A `(node, resolved_name)` tuple. `resolved_name` equals `model_name`
        when it is listed; otherwise it is DEFAULT_HF_MODEL and `node` is the
        default model's entry.

    Raises:
        LookupError: If neither `model_name` nor DEFAULT_HF_MODEL is listed.
            (Recursing on the default here would never terminate.)
    """
    # Try the requested name first, then the default; the first matching
    # entry in list order wins for each candidate.
    for candidate in (model_name, DEFAULT_HF_MODEL):
        for node in model_names:
            if (node["name"] == candidate):
                return node,candidate
    raise LookupError(
        "Model '" + str(model_name) + "' is not listed and the default model '"
        + DEFAULT_HF_MODEL + "' is missing from the model list"
    )
124
+
125
 
126
+ def run_test(model_names,model_name,sentences,display_area,main_index,user_uploaded,custom_model):
127
  display_area.text("Loading model:" + model_name)
128
+ #Note. model_name may get mapped to new name in the call below for custom models
129
+ orig_model_name = model_name
130
+ model_info,model_name = get_model_info(model_names,model_name)
131
+ if (model_name != orig_model_name):
132
+ load_model_name = orig_model_name
133
+ else:
134
+ load_model_name = model_info["model"]
135
  if ("Note" in model_info):
136
  fail_link = f"{model_info['Note']} [link]({model_info['alt_url']})"
137
  display_area.write(fail_link)
138
+ model = load_model(model_name,model_info["class"],load_model_name)
139
  display_area.text("Model " + model_name + " load complete")
140
  try:
141
  if (user_uploaded):
 
155
 
156
 
157
 
158
+ def display_results(orig_sentences,main_index,results,response_info,app_mode,model_name):
159
  main_sent = f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">{response_info}<br/><br/></div>"
160
+ main_sent += f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">Showing results for model:&nbsp;<b>{model_name}</b></div>"
161
+ score_text = "cosine distance" if app_mode == SEM_SIMILARITY else "cosine distance/score"
162
  pivot_name = "main sentence" if app_mode == SEM_SIMILARITY else "query"
163
  main_sent += f"<div style=\"font-size:14px; color: #6f6f6f; text-align: left\">Results sorted by {score_text}. Closest to furthest away from {pivot_name}</div>"
164
  pivot_name = pivot_name[0].upper() + pivot_name[1:]
 
180
 
181
 
def init_session():
    """Seed st.session_state with default values when they are not yet present.

    The presence of the "model_name" key is used as the marker that the
    defaults were already written; when it exists nothing is overwritten and
    a note is printed instead.
    """
    if ("model_name" not in st.session_state):
        st.session_state["download_ready"] = None
        st.session_state["model_name"] = "ss_test"
        st.session_state["main_index"] = 1
        st.session_state["file_name"] = "default"
    else:
        print("Skipping init session")
191
 
192
  def app_main(app_mode,example_files,model_name_files):
193
  init_session()
 
197
  model_names = json.load(fp)
198
  curr_use_case = use_case[app_mode].split(".")[0]
199
  st.markdown("<h5 style='text-align: center;'>Compare popular/state-of-the-art models for tasks using sentence embeddings</h5>", unsafe_allow_html=True)
200
+ st.markdown(f"<p style='font-size:14px; color: #4f4f4f; text-align: center'><i>Or compare your own model with state-of-the-art/popular models</p>", unsafe_allow_html=True)
201
  st.markdown(f"<div style='color: #4f4f4f; text-align: left'>Use cases for sentence embeddings<br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;<a href=\'{use_case_url['1']}\' target='_blank'>{use_case['1']}</a><br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;{use_case['2']}<br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;{use_case['3']}<br/><i>This app illustrates <b>'{curr_use_case}'</b> use case</i></div>", unsafe_allow_html=True)
202
  st.markdown(f"<div style='color: #9f9f9f; text-align: right'>views:&nbsp;{get_views('init')}</div>", unsafe_allow_html=True)
203
 
 
220
  selected_model = st.selectbox(label=selection_label,
221
  options = options_arr, index=0, key = "twc_model")
222
  st.write("")
223
+ custom_model_selection = st.text_input("Model not listed above? Type any Huggingface semantic search model name ", "",key="custom_model")
224
+ hf_link_str = "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><a href='https://huggingface.co/models?pipeline_tag=sentence-similarity' target = '_blank'>List of Huggingface semantic search models</a><br/><br/><br/></div>"
225
+ st.markdown(hf_link_str, unsafe_allow_html=True)
226
  if (app_mode == SEM_SIMILARITY):
227
  main_index = st.number_input('Step 3. Enter index of sentence in file to make it the main sentence',value=1,min_value = 1)
228
  else:
 
248
  if (len(sentences) > MAX_INPUT):
249
  st.info(f"Input sentence count exceeds maximum sentence limit. First {MAX_INPUT} out of {len(sentences)} sentences chosen")
250
  sentences = sentences[:MAX_INPUT]
251
+ if (len(custom_model_selection) != 0):
252
+ run_model = custom_model_selection
253
+ else:
254
+ run_model = selected_model
255
  st.session_state["model_name"] = selected_model
256
  st.session_state["main_index"] = main_index
257
+ results = run_test(model_names,run_model,sentences,display_area,main_index - 1,(uploaded_file is not None),(len(custom_model_selection) != 0))
258
  display_area.empty()
259
  with display_area.container():
260
  device = 'GPU' if torch.cuda.is_available() else 'CPU'
261
  response_info = f"Computation time on {device}: {time.time() - start:.2f} secs for {len(sentences)} sentences"
262
+ if (len(custom_model_selection) != 0):
263
+ st.info("Custom model overrides model selection in step 2 above. So please clear the custom model text box to choose models from step 2")
264
+ display_results(sentences,main_index - 1,results,response_info,app_mode,run_model)
265
  #st.json(results)
266
  st.download_button(
267
  label="Download results as json",
doc_app_models.json CHANGED
@@ -30,7 +30,7 @@
30
  "orig_author_url":"https://github.com/UKPLab",
31
  "orig_author":"Ubiquitous Knowledge Processing Lab",
32
  "sota_info": {
33
- "task":"Over 3.8 million downloads from huggingface",
34
  "sota_link":"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2"
35
  },
36
  "paper_url":"https://arxiv.org/abs/1908.10084",
@@ -42,7 +42,7 @@
42
  "orig_author_url":"https://github.com/UKPLab",
43
  "orig_author":"Ubiquitous Knowledge Processing Lab",
44
  "sota_info": {
45
- "task":"Over 2 million downloads from huggingface",
46
  "sota_link":"https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2"
47
  },
48
  "paper_url":"https://arxiv.org/abs/1908.10084",
@@ -54,7 +54,7 @@
54
  "orig_author_url":"https://github.com/UKPLab",
55
  "orig_author":"Ubiquitous Knowledge Processing Lab",
56
  "sota_info": {
57
- "task":"Over 700,000 downloads from huggingface",
58
  "sota_link":"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens"
59
  },
60
  "paper_url":"https://arxiv.org/abs/1908.10084",
@@ -66,7 +66,7 @@
66
  "orig_author_url":"https://github.com/UKPLab",
67
  "orig_author":"Ubiquitous Knowledge Processing Lab",
68
  "sota_info": {
69
- "task":"Over 500,000 downloads from huggingface",
70
  "sota_link":"https://huggingface.co/sentence-transformers/all-mpnet-base-v2"
71
  },
72
  "paper_url":"https://arxiv.org/abs/1908.10084",
@@ -78,7 +78,7 @@
78
  "orig_author_url":"https://github.com/UKPLab",
79
  "orig_author":"Ubiquitous Knowledge Processing Lab",
80
  "sota_info": {
81
- "task":"Over 500,000 downloads from huggingface",
82
  "sota_link":"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2"
83
  },
84
  "paper_url":"https://arxiv.org/abs/1908.10084",
 
30
  "orig_author_url":"https://github.com/UKPLab",
31
  "orig_author":"Ubiquitous Knowledge Processing Lab",
32
  "sota_info": {
33
+ "task":"Over 3.8 million downloads from Huggingface",
34
  "sota_link":"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2"
35
  },
36
  "paper_url":"https://arxiv.org/abs/1908.10084",
 
42
  "orig_author_url":"https://github.com/UKPLab",
43
  "orig_author":"Ubiquitous Knowledge Processing Lab",
44
  "sota_info": {
45
+ "task":"Over 2 million downloads from Huggingface",
46
  "sota_link":"https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2"
47
  },
48
  "paper_url":"https://arxiv.org/abs/1908.10084",
 
54
  "orig_author_url":"https://github.com/UKPLab",
55
  "orig_author":"Ubiquitous Knowledge Processing Lab",
56
  "sota_info": {
57
+ "task":"Over 700,000 downloads from Huggingface",
58
  "sota_link":"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens"
59
  },
60
  "paper_url":"https://arxiv.org/abs/1908.10084",
 
66
  "orig_author_url":"https://github.com/UKPLab",
67
  "orig_author":"Ubiquitous Knowledge Processing Lab",
68
  "sota_info": {
69
+ "task":"Over 500,000 downloads from Huggingface",
70
  "sota_link":"https://huggingface.co/sentence-transformers/all-mpnet-base-v2"
71
  },
72
  "paper_url":"https://arxiv.org/abs/1908.10084",
 
78
  "orig_author_url":"https://github.com/UKPLab",
79
  "orig_author":"Ubiquitous Knowledge Processing Lab",
80
  "sota_info": {
81
+ "task":"Over 500,000 downloads from Huggingface",
82
  "sota_link":"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2"
83
  },
84
  "paper_url":"https://arxiv.org/abs/1908.10084",