Spaces:
Build error
Build error
taskswithcode
committed on
Commit
·
aea620e
1
Parent(s):
0f0b570
Added
Browse files- app.py +292 -0
- clus_app_examples.json +5 -0
- clus_app_models.json +90 -0
- imdb_sent.txt +62 -0
- larger_test.txt +52 -0
- long_form_logo_with_icon.png +0 -0
- requirements.txt +4 -0
- run.sh +2 -0
- small_test.txt +30 -0
- twc_clustering.py +93 -0
- twc_embeddings.py +407 -0
app.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import sys
|
3 |
+
import streamlit as st
|
4 |
+
import string
|
5 |
+
from io import StringIO
|
6 |
+
import pdb
|
7 |
+
import json
|
8 |
+
from twc_embeddings import HFModel,SimCSEModel,SGPTModel,CausalLMModel,SGPTQnAModel
|
9 |
+
from twc_clustering import TWCClustering
|
10 |
+
import torch
|
11 |
+
import requests
|
12 |
+
import socket
|
13 |
+
|
14 |
+
|
15 |
+
MAX_INPUT = 100
|
16 |
+
|
17 |
+
SEM_SIMILARITY="1"
|
18 |
+
DOC_RETRIEVAL="2"
|
19 |
+
CLUSTERING="3"
|
20 |
+
|
21 |
+
|
22 |
+
use_case = {"1":"Finding similar phrases/sentences","2":"Retrieving semantically matching information to a query. It may not be a factual match","3":"Clustering"}
|
23 |
+
use_case_url = {"1":"https://huggingface.co/spaces/taskswithcode/semantic_similarity","2":"https://huggingface.co/spaces/taskswithcode/semantic_search","3":""}
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
from transformers import BertTokenizer, BertForMaskedLM
|
28 |
+
|
29 |
+
|
30 |
+
APP_NAME = "hf/semantic_clustering"
|
31 |
+
INFO_URL = "http://www.taskswithcode.com/stats/"
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
def get_views(action):
    """Return the app's view count as a comma-grouped string and report `action`.

    On the first call in a Streamlit session the count is requested from
    INFO_URL and stored in ``st.session_state["view_count"]``. Later calls
    reuse the stored value and, unless `action` is ``"init"``, post the
    action to INFO_URL again without using the reply.

    Statistics reporting is best effort: a failed host lookup, an
    unreachable or slow stats service, or a malformed reply never raises
    out of this function. A failed first fetch stores and shows a count of 0.

    Args:
        action: label sent to the stats service (the callers in this file
            pass "init" and "submit").

    Returns:
        The view count formatted with thousands separators, e.g. "1,234".
    """
    timeout_secs = 5  # never let the stats service stall page rendering
    hostname = socket.gethostname()
    try:
        ip_address = socket.gethostbyname(hostname)
    except OSError:
        # Hostname may not resolve inside a container; report without an IP.
        ip_address = ""
    app_info = {'name': APP_NAME, "action": action, "host": hostname, "ip": ip_address}

    if "view_count" not in st.session_state:
        try:
            res = requests.post(INFO_URL, json=app_info, timeout=timeout_secs).json()
            print(res)
            data = res["count"]
        except Exception:
            # Network error, non-JSON body or missing "count": fall back to 0.
            data = 0
        st.session_state["view_count"] = data
        return "{:,}".format(data)

    ret_val = st.session_state["view_count"]
    if action != "init":
        try:
            requests.post(INFO_URL, json=app_info, timeout=timeout_secs)
        except Exception as e:
            # Reporting only; the results are already on screen.
            print("Unable to report action to stats service:", e)
    return "{:,}".format(ret_val)
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
def construct_model_info_for_display(model_names):
    """Build the model selectbox options and the HTML summary of the models.

    Args:
        model_names: list of model description dicts (as loaded from the
            models JSON file). Every entry contributes its "name" to the
            options; only entries whose "mark" is the string "True" are
            described in the HTML.

    Returns:
        A tuple ``(options_arr, markdown_str)``: the list of model names in
        input order, and one HTML string with the model descriptions
        followed by the upload notes.
    """
    fragments = [
        f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Models evaluated ({len(model_names)})</b><br/><i>These are either state-of-the-art or the most downloaded models on Huggingface</i></div>",
        "<div style=\"font-size:2px; color: #2f2f2f; text-align: left\"><br/></div>",
    ]
    options_arr = []
    for node in model_names:
        options_arr.append(node["name"])
        if node["mark"] != "True":
            continue
        fragments.append(f"<div style=\"font-size:16px; color: #5f5f5f; text-align: left\"> • Model: <a href=\'{node['paper_url']}\' target='_blank'>{node['name']}</a><br/> Code released by: <a href=\'{node['orig_author_url']}\' target='_blank'>{node['orig_author']}</a><br/> Model info: <a href=\'{node['sota_info']['sota_link']}\' target='_blank'>{node['sota_info']['task']}</a></div>")
        if "Note" in node:
            # Optional warning line with a link to an alternative.
            fragments.append(f"<div style=\"font-size:16px; color: #a91212; text-align: left\"> {node['Note']}<a href=\'{node['alt_url']}\' target='_blank'>link</a></div>")
        fragments.append("<div style=\"font-size:16px; color: #5f5f5f; text-align: left\"><br/></div>")

    fragments.append("<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><b>Note:</b><br/>• Uploaded files are loaded into non-persistent memory for the duration of the computation. They are not cached</div>")
    limit = "{:,}".format(MAX_INPUT)
    fragments.append(f"<div style=\"font-size:12px; color: #9f9f9f; text-align: left\">• User uploaded file has a maximum limit of {limit} sentences.</div>")
    return options_arr, "".join(fragments)
|
77 |
+
|
78 |
+
|
79 |
+
st.set_page_config(page_title='TWC - Compare popular/state-of-the-art models for tasks using sentence embeddings', page_icon="logo.jpg", layout='centered', initial_sidebar_state='auto',
|
80 |
+
menu_items={
|
81 |
+
'About': 'This app was created by taskswithcode. http://taskswithcode.com'
|
82 |
+
|
83 |
+
})
|
84 |
+
col,pad = st.columns([85,15])
|
85 |
+
|
86 |
+
with col:
|
87 |
+
st.image("long_form_logo_with_icon.png")
|
88 |
+
|
89 |
+
|
90 |
+
@st.experimental_memo
def load_model(model_name,model_class,load_model_name):
    """Instantiate the wrapper class named `model_class` and load its weights.

    Args:
        model_name: display name, used only in the error message.
        model_class: name of a class visible in this module's globals
            (the models JSON supplies it in its "class" field).
        load_model_name: identifier passed to the wrapper's ``init_model``.

    Returns:
        The initialised model wrapper, or None when the class lookup,
        construction or ``init_model`` fails. The failure is shown with
        ``st.error`` rather than raised.
    """
    try:
        obj_class = globals()[model_class]
        ret_model = obj_class()
        ret_model.init_model(load_model_name)
    except Exception as e:
        st.error("Unable to load model:" + model_name + " " + load_model_name + " " + str(e))
        # Never hand back a wrapper whose init_model() did not complete.
        return None
    return ret_model
|
102 |
+
|
103 |
+
|
104 |
+
|
105 |
+
@st.experimental_memo
def cached_compute_similarity(sentences,_model,model_name,threshold,_cluster):
    """Embed `sentences` with `_model` and cluster them, memoized by Streamlit.

    `model_name` is not used in the body; presumably it is there so the
    memo key distinguishes models, since underscore-prefixed arguments are
    left out of Streamlit's hashing -- confirm against the Streamlit docs.

    Returns:
        Whatever ``_cluster.cluster`` returns for the computed embeddings.
    """
    embedded = _model.compute_embeddings(sentences, is_file=False)
    texts, vectors = embedded
    return _cluster.cluster(None, texts, vectors, threshold)
|
110 |
+
|
111 |
+
|
112 |
+
def uncached_compute_similarity(sentences,_model,model_name,threshold,cluster):
    """Embed `sentences` with `_model` and cluster them, without caching.

    A Streamlit spinner is shown while the work runs. `model_name` is
    accepted for signature parity with the cached variant and is not used.

    Returns:
        Whatever ``cluster.cluster`` returns for the computed embeddings.
    """
    with st.spinner('Computing vectors for sentences'):
        texts, vectors = _model.compute_embeddings(sentences, is_file=False)
        clustered = cluster.cluster(None, texts, vectors, threshold)
    return clustered
|
118 |
+
|
119 |
+
# Entry used for any model name that is not listed in the models file
# (for example a custom Hugging Face model typed in by the user).
DEFAULT_HF_MODEL = "sentence-transformers/paraphrase-MiniLM-L6-v2"


def get_model_info(model_names,model_name):
    """Look up the description of `model_name`, falling back to the default.

    Args:
        model_names: list of model description dicts, each with a "name".
        model_name: name to look for.

    Returns:
        ``(node, name)``: the matching dict and `model_name` when it is
        listed; otherwise the DEFAULT_HF_MODEL dict and DEFAULT_HF_MODEL.

    Raises:
        LookupError: neither `model_name` nor DEFAULT_HF_MODEL is listed.
            (The previous recursive fallback never terminated in this case.)
    """
    fallback = None
    for node in model_names:
        if model_name == node["name"]:
            return node, model_name
        if fallback is None and node["name"] == DEFAULT_HF_MODEL:
            fallback = node
    if fallback is None:
        raise LookupError(
            "Model '" + str(model_name) + "' is not listed and the default model '"
            + DEFAULT_HF_MODEL + "' is missing from the model list"
        )
    return fallback, DEFAULT_HF_MODEL
|
125 |
+
|
126 |
+
|
127 |
+
def run_test(model_names,model_name,sentences,display_area,threshold,user_uploaded,custom_model):
    """Load the requested model and cluster `sentences` with it.

    Args:
        model_names: list of model description dicts from the models file.
        model_name: model chosen by the user. A name that is not listed is
            mapped to the default entry by ``get_model_info`` and is then
            used as the identifier to load.
        sentences: list of input sentences.
        display_area: Streamlit placeholder used for progress messages.
        threshold: z-score threshold forwarded to the clustering call.
        user_uploaded: True when the sentences come from an uploaded file;
            those runs use the uncached computation.
        custom_model: accepted for the caller's signature; not used here.

    Returns:
        The clustering results. On any failure the error is shown and the
        script run is halted with ``st.stop()``.
    """
    display_area.text("Loading model:" + model_name)
    #Note. model_name may get mapped to new name in the call below for custom models
    orig_model_name = model_name
    model_info,model_name = get_model_info(model_names,model_name)
    if model_name != orig_model_name:
        load_model_name = orig_model_name
    else:
        load_model_name = model_info["model"]
    if "Note" in model_info:
        fail_link = f"{model_info['Note']} [link]({model_info['alt_url']})"
        display_area.write(fail_link)
    model = load_model(model_name,model_info["class"],load_model_name)
    if model is None:
        # load_model has already shown the reason; do not claim success or
        # continue into a confusing prediction error.
        st.stop()
        return {}
    display_area.text("Model " + model_name + " load complete")
    try:
        if user_uploaded:
            results = uncached_compute_similarity(sentences,model,model_name,threshold,st.session_state["cluster"])
        else:
            display_area.text("Computing vectors for sentences")
            results = cached_compute_similarity(sentences,model,model_name,threshold,st.session_state["cluster"])
        display_area.text("Similarity computation complete")
        return results
    except Exception as e:
        st.error("Some error occurred during prediction" + str(e))
        st.stop()
    # Only reached if st.stop() returns instead of halting the run.
    return {}
|
154 |
+
|
155 |
+
|
156 |
+
|
157 |
+
|
158 |
+
|
159 |
+
def display_results(orig_sentences,results,response_info,app_mode,model_name):
    """Render the clusters as HTML and stage the JSON download.

    Each cluster is shown as its pivot sentence followed by its neighbours
    and their distances, numbered from 1. The same data is serialised to
    JSON in ``st.session_state["download_ready"]``, and a "submit" action
    is reported through ``get_views``.

    Sentences and the model name may come from the user (an uploaded file,
    the custom-model text box), so they are HTML-escaped before being put
    into markup rendered with ``unsafe_allow_html=True``. The JSON download
    keeps the raw sentence text.

    Args:
        orig_sentences: the input sentences, indexed by the positions that
            `results` refers to.
        results: dict with "clusters" (each holding "pivot_index" and a
            "neighs" mapping of sentence index to distance) and "info"
            (holding "mean", "std" and "zscores").
        response_info: status line shown above the results.
        app_mode: accepted for the caller's signature; not used here.
        model_name: name of the model that produced the results.
    """
    import html  # local import keeps this block self-contained

    clusters = results["clusters"]
    info = results["info"]
    main_sent = f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">{response_info}<br/><br/></div>"
    main_sent += f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">Showing results for model: <b>{html.escape(str(model_name))}</b></div>"
    score_text = "cosine distance"
    main_sent += f"<div style=\"font-size:14px; color: #6f6f6f; text-align: left\">Clustering by {score_text}. <b>{len(clusters)} clusters</b>. mean:{info['mean']:.2f} std:{info['std']:.2f} threshold hints:{str(info['zscores'])}</div>"

    body_sent = []
    download_data = {}
    for cluster_no, cluster in enumerate(clusters, start=1):
        pivot_index = cluster["pivot_index"]
        pivot_sent = orig_sentences[pivot_index]
        pivot_index += 1  # shown 1-based
        children = {}
        download_data[cluster_no] = {"pivot": {"pivot_index": pivot_index, "sent": pivot_sent, "children": children}}
        body_sent.append(f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\">{pivot_index}] {html.escape(pivot_sent)} <b><i>(Cluster {cluster_no})</i></b> </div>")
        for child_index, cosine_dist in cluster["neighs"].items():
            sentence = orig_sentences[child_index]
            body_sent.append(f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\">{child_index + 1}] {html.escape(sentence)} <b>{cosine_dist:.2f}</b></div>")
            children[sentence] = f"{cosine_dist:.2f}"
        # Blank spacer line between clusters.
        body_sent.append(f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"> </div>")

    main_sent = main_sent + "\n" + '\n'.join(body_sent)
    st.markdown(main_sent,unsafe_allow_html=True)
    st.session_state["download_ready"] = json.dumps(download_data,indent=4)
    get_views("submit")
|
187 |
+
|
188 |
+
|
189 |
+
def init_session():
    """Seed ``st.session_state`` with defaults on the first run of a session.

    The presence of "model_name" marks an already initialised session; in
    that case nothing is changed and a note is printed.
    """
    state = st.session_state
    if "model_name" in state:
        print("Skipping init session")
        return
    state["model_name"] = "ss_test"
    state["download_ready"] = None
    state["threshold"] = 1.5
    state["file_name"] = "default"
    # Built last, after the plain defaults are in place.
    state["cluster"] = TWCClustering()
|
199 |
+
|
200 |
+
def _split_sentences(text):
    """Return the lines of `text`, one sentence per entry.

    A single trailing newline does not produce an empty last entry, and a
    final line without a newline terminator is kept. Windows line endings
    are treated like "\\n".
    """
    lines = text.replace("\r\n", "\n").split("\n")
    if lines[-1] == "":
        lines.pop()
    return lines


def app_main(app_mode,example_files,model_name_files):
    """Render the Streamlit page and run clustering when the form is submitted.

    Args:
        app_mode: key into ``use_case`` ("1", "2" or "3") naming the use
            case this app illustrates.
        example_files: path of a JSON file mapping a display label to
            ``{"name": <text file path>}``.
        model_name_files: path of a JSON file with the list of model
            descriptions.

    The input is an uploaded text file if there is one, otherwise the
    selected example file, with one sentence per line and at most MAX_INPUT
    sentences used. Any exception raised while building the form or
    computing results is shown with ``st.error`` and the run is stopped.
    """
    init_session()
    with open(example_files, encoding="utf-8") as fp:
        example_file_names = json.load(fp)
    with open(model_name_files, encoding="utf-8") as fp:
        model_names = json.load(fp)
    curr_use_case = use_case[app_mode].split(".")[0]
    st.markdown("<h5 style='text-align: center;'>Compare popular/state-of-the-art models for tasks using sentence embeddings</h5>", unsafe_allow_html=True)
    st.markdown(f"<p style='font-size:14px; color: #4f4f4f; text-align: center'><i>Or compare your own model with state-of-the-art/popular models</p>", unsafe_allow_html=True)
    st.markdown(f"<div style='color: #4f4f4f; text-align: left'>Use cases for sentence embeddings<br/> • <a href=\'{use_case_url['1']}\' target='_blank'>{use_case['1']}</a><br/> • <a href=\'{use_case_url['2']}\' target='_blank'>{use_case['2']}</a><br/> • {use_case['3']}<br/><i>This app illustrates <b>'{curr_use_case}'</b> use case</i></div>", unsafe_allow_html=True)
    st.markdown(f"<div style='color: #9f9f9f; text-align: right'>views: {get_views('init')}</div>", unsafe_allow_html=True)

    try:
        with st.form('twc_form'):
            step1_line = "Step 1. Upload text file(one sentence in a line) or choose an example text file below"
            if app_mode == DOC_RETRIEVAL:
                step1_line += ". The first line is treated as the query"
            uploaded_file = st.file_uploader(step1_line, type=".txt")

            selected_file_index = st.selectbox(label=f'Example files ({len(example_file_names)})',
                                               options=list(example_file_names), index=0, key="twc_file")
            st.write("")
            options_arr,markdown_str = construct_model_info_for_display(model_names)
            selection_label = 'Step 2. Select Model'
            selected_model = st.selectbox(label=selection_label,
                                          options=options_arr, index=0, key="twc_model")
            st.write("")
            custom_model_selection = st.text_input("Model not listed above? Type any Huggingface semantic search model name ", "",key="custom_model")
            hf_link_str = "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><a href='https://huggingface.co/models?pipeline_tag=sentence-similarity' target = '_blank'>List of Huggingface semantic search models</a><br/><br/><br/></div>"
            st.markdown(hf_link_str, unsafe_allow_html=True)
            threshold = st.number_input('Step 3. Choose a zscore threshold (number of std devs from mean)',value=st.session_state["threshold"],min_value = 0.0,step=.01)
            st.write("")
            submit_button = st.form_submit_button('Run')

            st.empty()  # placeholder kept so the page layout is unchanged
            display_area = st.empty()
            if submit_button:
                start = time.time()
                use_custom_model = len(custom_model_selection) != 0
                if uploaded_file is not None:
                    st.session_state["file_name"] = uploaded_file.name
                    text = uploaded_file.getvalue().decode("utf-8")
                else:
                    example_path = example_file_names[selected_file_index]["name"]
                    st.session_state["file_name"] = example_path
                    # Context manager so the handle is closed after reading.
                    with open(example_path, encoding="utf-8") as example_fp:
                        text = example_fp.read()
                # Keeps the last sentence even when the file lacks a final newline.
                sentences = _split_sentences(text)
                if len(sentences) > MAX_INPUT:
                    st.info(f"Input sentence count exceeds maximum sentence limit. First {MAX_INPUT} out of {len(sentences)} sentences chosen")
                    sentences = sentences[:MAX_INPUT]
                run_model = custom_model_selection if use_custom_model else selected_model
                st.session_state["model_name"] = selected_model
                st.session_state["threshold"] = threshold
                results = run_test(model_names,run_model,sentences,display_area,threshold,(uploaded_file is not None),use_custom_model)
                display_area.empty()
                with display_area.container():
                    device = 'GPU' if torch.cuda.is_available() else 'CPU'
                    response_info = f"Computation time on {device}: {time.time() - start:.2f} secs for {len(sentences)} sentences"
                    if use_custom_model:
                        st.info("Custom model overrides model selection in step 2 above. So please clear the custom model text box to choose models from step 2")
                    display_results(sentences,results,response_info,app_mode,run_model)

        # The download button is created after the form block has closed.
        download_ready = st.session_state["download_ready"]
        st.download_button(
            label="Download results as json",
            data=download_ready if download_ready is not None else "",
            disabled=download_ready is None,
            file_name=(st.session_state["model_name"] + "_" + str(st.session_state["threshold"]) + "_" + '_'.join(st.session_state["file_name"].split(".")[:-1]) + ".json").replace("/","_"),
            mime='text/json',
            key="download"
        )

    except Exception as e:
        st.error("Some error occurred during loading" + str(e))
        st.stop()

    st.markdown(markdown_str, unsafe_allow_html=True)
|
284 |
+
|
285 |
+
|
286 |
+
|
287 |
+
if __name__ == "__main__":
|
288 |
+
#print("command line input:",len(sys.argv),str(sys.argv))
|
289 |
+
#app_main(sys.argv[1],sys.argv[2],sys.argv[3])
|
290 |
+
#app_main("1","sim_app_examples.json","sim_app_models.json")
|
291 |
+
app_main("3","clus_app_examples.json","clus_app_models.json")
|
292 |
+
|
clus_app_examples.json
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"Machine learning terms (phrases test)": {"name":"small_test.txt"},
|
3 |
+
"Customer feedback mixed with noise":{"name":"larger_test.txt"},
|
4 |
+
"Movie reviews": {"name":"imdb_sent.txt"}
|
5 |
+
}
|
clus_app_models.json
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
|
3 |
+
{ "name":"sentence-transformers/all-MiniLM-L6-v2",
|
4 |
+
"model":"sentence-transformers/all-MiniLM-L6-v2",
|
5 |
+
"fork_url":"https://github.com/taskswithcode/sentence_similarity_hf_model",
|
6 |
+
"orig_author_url":"https://github.com/UKPLab",
|
7 |
+
"orig_author":"Ubiquitous Knowledge Processing Lab",
|
8 |
+
"sota_info": {
|
9 |
+
"task":"Over 3.8 million downloads from Huggingface",
|
10 |
+
"sota_link":"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2"
|
11 |
+
},
|
12 |
+
"paper_url":"https://arxiv.org/abs/1908.10084",
|
13 |
+
"mark":"True",
|
14 |
+
"class":"HFModel"},
|
15 |
+
{ "name":"sentence-transformers/paraphrase-MiniLM-L6-v2",
|
16 |
+
"model":"sentence-transformers/paraphrase-MiniLM-L6-v2",
|
17 |
+
"fork_url":"https://github.com/taskswithcode/sentence_similarity_hf_model",
|
18 |
+
"orig_author_url":"https://github.com/UKPLab",
|
19 |
+
"orig_author":"Ubiquitous Knowledge Processing Lab",
|
20 |
+
"sota_info": {
|
21 |
+
"task":"Over 2 million downloads from Huggingface",
|
22 |
+
"sota_link":"https://huggingface.co/sentence-transformers/paraphrase-MiniLM-L6-v2"
|
23 |
+
},
|
24 |
+
"paper_url":"https://arxiv.org/abs/1908.10084",
|
25 |
+
"mark":"True",
|
26 |
+
"class":"HFModel"},
|
27 |
+
{ "name":"sentence-transformers/bert-base-nli-mean-tokens",
|
28 |
+
"model":"sentence-transformers/bert-base-nli-mean-tokens",
|
29 |
+
"fork_url":"https://github.com/taskswithcode/sentence_similarity_hf_model",
|
30 |
+
"orig_author_url":"https://github.com/UKPLab",
|
31 |
+
"orig_author":"Ubiquitous Knowledge Processing Lab",
|
32 |
+
"sota_info": {
|
33 |
+
"task":"Over 700,000 downloads from Huggingface",
|
34 |
+
"sota_link":"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens"
|
35 |
+
},
|
36 |
+
"paper_url":"https://arxiv.org/abs/1908.10084",
|
37 |
+
"mark":"True",
|
38 |
+
"class":"HFModel"},
|
39 |
+
{ "name":"sentence-transformers/all-mpnet-base-v2",
|
40 |
+
"model":"sentence-transformers/all-mpnet-base-v2",
|
41 |
+
"fork_url":"https://github.com/taskswithcode/sentence_similarity_hf_model",
|
42 |
+
"orig_author_url":"https://github.com/UKPLab",
|
43 |
+
"orig_author":"Ubiquitous Knowledge Processing Lab",
|
44 |
+
"sota_info": {
|
45 |
+
"task":"Over 500,000 downloads from Huggingface",
|
46 |
+
"sota_link":"https://huggingface.co/sentence-transformers/all-mpnet-base-v2"
|
47 |
+
},
|
48 |
+
"paper_url":"https://arxiv.org/abs/1908.10084",
|
49 |
+
"mark":"True",
|
50 |
+
"class":"HFModel"},
|
51 |
+
{ "name":"sentence-transformers/all-MiniLM-L12-v2",
|
52 |
+
"model":"sentence-transformers/all-MiniLM-L12-v2",
|
53 |
+
"fork_url":"https://github.com/taskswithcode/sentence_similarity_hf_model",
|
54 |
+
"orig_author_url":"https://github.com/UKPLab",
|
55 |
+
"orig_author":"Ubiquitous Knowledge Processing Lab",
|
56 |
+
"sota_info": {
|
57 |
+
"task":"Over 500,000 downloads from Huggingface",
|
58 |
+
"sota_link":"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2"
|
59 |
+
},
|
60 |
+
"paper_url":"https://arxiv.org/abs/1908.10084",
|
61 |
+
"mark":"True",
|
62 |
+
"class":"HFModel"},
|
63 |
+
|
64 |
+
{ "name":"SGPT-125M",
|
65 |
+
"model":"Muennighoff/SGPT-125M-weightedmean-nli-bitfit",
|
66 |
+
"fork_url":"https://github.com/taskswithcode/sgpt",
|
67 |
+
"orig_author_url":"https://github.com/Muennighoff",
|
68 |
+
"orig_author":"Niklas Muennighoff",
|
69 |
+
"sota_info": {
|
70 |
+
"task":"#1 in multiple information retrieval & search tasks(smaller variant)",
|
71 |
+
"sota_link":"https://paperswithcode.com/paper/sgpt-gpt-sentence-embeddings-for-semantic"
|
72 |
+
},
|
73 |
+
"paper_url":"https://arxiv.org/abs/2202.08904v5",
|
74 |
+
"mark":"True",
|
75 |
+
"class":"SGPTModel"},
|
76 |
+
{ "name":"SIMCSE-base" ,
|
77 |
+
"model":"princeton-nlp/sup-simcse-roberta-base",
|
78 |
+
"fork_url":"https://github.com/taskswithcode/SimCSE",
|
79 |
+
"orig_author_url":"https://github.com/princeton-nlp",
|
80 |
+
"orig_author":"Princeton Natural Language Processing",
|
81 |
+
"sota_info": {
|
82 |
+
"task":"Within top 10 in multiple semantic textual similarity tasks(smaller variant)",
|
83 |
+
"sota_link":"https://paperswithcode.com/paper/simcse-simple-contrastive-learning-of"
|
84 |
+
},
|
85 |
+
"paper_url":"https://arxiv.org/abs/2104.08821v4",
|
86 |
+
"mark":"True",
|
87 |
+
"class":"SimCSEModel","sota_link":"https://paperswithcode.com/sota/semantic-textual-similarity-on-sick"}
|
88 |
+
|
89 |
+
|
90 |
+
]
|
imdb_sent.txt
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"A rating of ""1"" does not begin to express how dull, depressing and relentlessly bad this movie is."
|
2 |
+
Hated it with all my being. Worst movie ever. Mentally- scarred. Help me. It was that bad.TRUST ME!!!
|
3 |
+
"Long, boring, blasphemous. Never have I been so glad to see ending credits roll."
|
4 |
+
This film made John Glover a star. Alan Raimy is one of the most compelling character that I have ever seen on film. And I mean that sport.
|
5 |
+
"Were I not with friends, and so cheap, I would have walked out. It failed miserably as satire and didn't even have the redemption of camp."
|
6 |
+
For pure gothic vampire cheese nothing can compare to the Subspecies films. I highly recommend each and every one of them.
|
7 |
+
"A great film in its genre, the direction, acting, most especially the casting of the film makes it even more powerful. A must see."
|
8 |
+
"This is a terrible movie, don't waste your money on it. Don't even watch it for free. That's all I have to say."
|
9 |
+
I wouldn't rent this one even on dollar rental night.
|
10 |
+
"More suspenseful, more subtle, much, much more disturbing...."
|
11 |
+
This is a good film. This is very funny. Yet after this film there were no good Ernest films!
|
12 |
+
A touching movie. It is full of emotions and wonderful acting. I could have sat through it a second time.
|
13 |
+
"Great movie - especially the music - Etta James - ""At Last"". This speaks volumes when you have finally found that special someone."
|
14 |
+
If you've ever had a mad week-end out with your mates then you'll appreciate this film. Excellent fun and a laugh a minute.
|
15 |
+
"I think it's one of the greatest movies which are ever made, and I've seen many... The book is better, but it's still a very good movie!"
|
16 |
+
Brilliant and moving performances by Tom Courtenay and Peter Finch.
|
17 |
+
The characters are unlikeable and the script is awful. It's a waste of the talents of Deneuve and Auteuil.
|
18 |
+
You've got to be kidding. This movie sucked for the sci-fi fans. I would only recommend watching this only if you think Armageddon was good.
|
19 |
+
Ten minutes of people spewing gallons of pink vomit. Recurring scenes of enormous piles of dog excrement - need one say more???
|
20 |
+
"As usual, Sean Connery does a great job. Lawrence Fishburn is good, but I have a hard time not seeing him as Ike Turner."
|
21 |
+
This movie is terrible but it has some good effects.
|
22 |
+
You'd better choose Paul Verhoeven's even if you have watched it.
|
23 |
+
"Brilliant. Ranks along with Citizen Kane, The Matrix and Godfathers. Must see, at least for basset in her early days. Watch it."
|
24 |
+
"I don't know why I like this movie so well, but I never get tired of watching it."
|
25 |
+
The one-liners fly so fast in this movie that you can watch it over and over and still catch new ones. By far one of the best of this genre.
|
26 |
+
"Don't waste your time and money on it. It's not quite as bad as ""Adrenalin"", by the same director but that's not saying much."
|
27 |
+
"Read the book, forget the movie!"
|
28 |
+
This is a great movie. Too bad it is not available on home video.
|
29 |
+
"Very intelligent language usage of Ali, which you musn't miss! In one word: (eeh sentence...) Wicked, so keep it real and pass it on!"
|
30 |
+
Primary plot!Primary direction!Poor interpretation.
|
31 |
+
"If you like Pauly Shore, you'll love Son in Law. If you hate Pauly Shore, then, well...I liked it!"
|
32 |
+
Just love the interplay between two great characters of stage & screen - Veidt & Barrymore
|
33 |
+
"This movie will always be a Broadway and Movie classic, as long as there are still people who sing, dance, and act."
|
34 |
+
This is the greatest movie ever. If you have written it off with out ever seeing it. You must give it a second try.
|
35 |
+
"What a script, what a story, what a mess!"
|
36 |
+
"I caught this film late at night on HBO. Talk about wooden acting, unbelievable plot, et al. Very little going in its favor. Skip it."
|
37 |
+
This is without a doubt the worst movie I have ever seen. It is not funny. It is not interesting and should not have been made.
|
38 |
+
Ming The Merciless does a little Bardwork and a movie most foul!
|
39 |
+
This is quite possibly the worst sequel ever made. The script is unfunny and the acting stinks. The exact opposite of the original.
|
40 |
+
"This is the definitive movie version of Hamlet. Branagh cuts nothing, but there are no wasted moments."
|
41 |
+
My favorite movie. What a great story this really was. I'd just like to be able to buy a copy of it but this does not seem possible.
|
42 |
+
"Comment this movie is impossible. Is terrible, very improbable, bad interpretation e direction. Not look!!!!!"
|
43 |
+
"Brilliant movie. The drawings were just amazing. Too bad it ended before it begun. I´ve waited 21 years for a sequel, but nooooo!!!"
|
44 |
+
a mesmerizing film that certainly keeps your attention... Ben Daniels is fascinating (and courageous) to watch.
|
45 |
+
"This is a very cool movie. The ending of the movie is a bit more defined than the play's ending, but either way it is still a good movie."
|
46 |
+
"Without a doubt, one of Tobe Hoppor's best! Epic storytellng, great special effects, and The Spacegirl (vamp me baby!)."
|
47 |
+
I hope this group of film-makers never re-unites.
|
48 |
+
Unwatchable. You can't even make it past the first three minutes. And this is coming from a huge Adam Sandler fan!!1
|
49 |
+
"One of the funniest movies made in recent years. Good characterization, plot and exceptional chemistry make this one a classic"
|
50 |
+
"Add this little gem to your list of holiday regulars. It is sweet, funny, and endearing"
|
51 |
+
"no comment - stupid movie, acting average or worse... screenplay - no sense at all... SKIP IT!"
|
52 |
+
"If you haven't seen this, it's terrible. It is pure trash. I saw this about 17 years ago, and I'm still screwed up from it."
|
53 |
+
Absolutely fantastic! Whatever I say wouldn't do this underrated movie the justice it deserves. Watch it now! FANTASTIC!
|
54 |
+
"As a big fan of Tiny Toon Adventures, I loved this movie!!! It was so funny!!! It really captured how cartoons spent their summers."
|
55 |
+
Widow hires a psychopath as a handyman. Sloppy film noir thriller which doesn't make much of its tension promising set-up. (3/10)
|
56 |
+
The Fiendish Plot of Dr. Fu Manchu (1980). This is hands down the worst film I've ever seen. What a sad way for a great comedian to go out.
|
57 |
+
"Obviously written for the stage. Lightweight but worthwhile. How can you go wrong with Ralph Richardson, Olivier and Merle Oberon."
|
58 |
+
This movie turned out to be better than I had expected it to be. Some parts were pretty funny. It was nice to have a movie with a new plot.
|
59 |
+
This movie is terrible. It's about some no brain surfin dude that inherits some company. Does Carrot Top have no shame?
|
60 |
+
Adrian Pasdar is excellent is this film. He makes a fascinating woman.
|
61 |
+
"An unfunny, unworthy picture which is an undeserving end to Peter Sellers' career. It is a pity this movie was ever made."
|
62 |
+
"The plot was really weak and confused. This is a true Oprah flick. (In Oprah's world, all men are evil and all women are victims.)"
|
larger_test.txt
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
do u really want me to unistall this app...whenever i open this app, u ask for review...how many times i have to give review-feedback...
|
2 |
+
I don't like how it asks to give a review everytime I open the app
|
3 |
+
Stop asking for review everytime I open the app..it's pathetic..the updated version sucks
|
4 |
+
If i already provided the review for this application but why this application is asking for reviews every time when I am opening the application so improve this feature. This feature is very irritating. Apart from that overall experience is very good.
|
5 |
+
as you guys bother me so much for the review even i gave my opinion already but every time i open the app it ask for review so i gave it 1 star , previsiousally it was 4 star.
|
6 |
+
repeatedly asking to rate the app...
|
7 |
+
Very irritating. Everytime i open app it asked for review hence giving 2 instead of 4
|
8 |
+
stop asking for ratings every time when open the app. I had rated this app 5 star but now every time app asking for give rating, its disgusting. so I'll give only one star
|
9 |
+
I swear if i see that feedback ad one more time im gonna uninstall this app and start using another one else
|
10 |
+
I'm am downgrading my rating because the app is good and I also gave it 5 satar but why I am getting unnecessary pop up to give it review please fix it
|
11 |
+
No rating ... worsted app ... please playstore delete this app
|
12 |
+
Much bad experience . when I used to open the app it requires feedback every time.
|
13 |
+
I already rated it then why always it pop up... Its irritate me a lot everytime when I open this app... Plz fix this
|
14 |
+
This app any time ask me for rating i hate this
|
15 |
+
Very Good app but asks to rate it all the time....all these popups are annoying when you are in hurry
|
16 |
+
I'm already rated this app. And now from one week and adove this app is asking for rating please solve the problem as soon as possible. Thank you
|
17 |
+
The app is too good but it send me notification again and again to rate it that's why am I giving one star to it
|
18 |
+
Constantly asks me to rate the app! So annoying.
|
19 |
+
Today again i am go to rating this app due to its ad less and best interface with good features again more
|
20 |
+
App is very disturbing .. very bad app
|
21 |
+
If I don't want to rate it's my personal choice, so why this app gives notification every single time,, it's quite frustrating therefore 1star Other wise app is best for it's work
|
22 |
+
For frustrating me every time to rate your app
|
23 |
+
Super exalent app can you pls reply to my comment how is my review so thanks to provide this app thanks
|
24 |
+
Vey bad app so disturbance ... All time get notification about rating... That allready done
|
25 |
+
Earlier I had given 5 stars to this app but even after giving review in this app, it speaks to rate now, so I removed 5 stars in edit review and put 1 star, now this app will be happy.
|
26 |
+
Every time I open the app it asking rating. I rated 4 before now I de rate to 1."
|
27 |
+
I love this app. So damn good
|
28 |
+
This app rocks!!!!!!
|
29 |
+
This app totally sucks
|
30 |
+
Wow what a useful app
|
31 |
+
I cant live without this app!
|
32 |
+
Shit! This app rocks. I can never imagine going out without using this app. So damn useful!
|
33 |
+
Elon musk is the founder of SpaceX
|
34 |
+
Parasites suck blood out of deer
|
35 |
+
A review of his conduct revealed he violated the rules everytime he downloaded movies
|
36 |
+
My god. If only I could rate this app 100 stars for its excellence
|
37 |
+
The board conducted a review and determined electons were fair
|
38 |
+
WTF!
|
39 |
+
Crossing the chasm is a great book review that is often quoted by readers
|
40 |
+
Why am I seeing everything in double like I m drink - is my vision going bad???
|
41 |
+
Expolanets keep going round and round their stars many times a day
|
42 |
+
I have recommended this app to so many friends and they love it too
|
43 |
+
The sale of electric cars has gone up since the increse in gas prices
|
44 |
+
Stable diffusion app is the rage on the internet with multiple people either downloading on the laptop and trying it or useing the web interface
|
45 |
+
OpenAI is trying to make money by exposin their NLP apps through an API
|
46 |
+
Co:here, Ai21 and other are trying to emulate OpenAIs business model and exposing NLP apps that depend on LLMs through metered APIs
|
47 |
+
Serverless GPUs have emerged as a new business model catering to end users who want to host apps
|
48 |
+
Cerberas, Sambanova are betting on models to grow larger and harder to train on traditional GPUS
|
49 |
+
Nvidia released Hopper series as a successor to A100 series
|
50 |
+
Oh my god! I am done with this app!
|
51 |
+
Oh my god! I love this sweet puppy. He rounds around the chair so many times
|
52 |
+
I plan to write a nasty review for this shitty movie
|
long_form_logo_with_icon.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers
|
2 |
+
scipy
|
3 |
+
torch
|
4 |
+
sentencepiece
|
run.sh
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# Launch the Streamlit app on port 80; the three quoted values are passed through to app.py as script arguments.
# NOTE(review): this commit adds clus_app_examples.json / clus_app_models.json, not sim_app_*.json - confirm these file names against how app.py reads its arguments.
streamlit run app.py --server.port 80 "1" "sim_app_examples.json" "sim_app_models.json"
|
2 |
+
|
small_test.txt
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
machine learning
|
2 |
+
Transformers have become the staple architecture for deep learning models
|
3 |
+
NLP
|
4 |
+
Diffusion models
|
5 |
+
natural language processing
|
6 |
+
deep learning
|
7 |
+
Deep Learning
|
8 |
+
Support vector machines
|
9 |
+
random forests
|
10 |
+
probability distribution
|
11 |
+
Cross entropy loss
|
12 |
+
Kullback leibler divergence
|
13 |
+
Shannon entropy
|
14 |
+
Activation functions
|
15 |
+
ATM
|
16 |
+
deep fakes
|
17 |
+
AGI
|
18 |
+
AI
|
19 |
+
deep trouble
|
20 |
+
artificial intelligence
|
21 |
+
deep diving
|
22 |
+
artificial snow
|
23 |
+
shallow waters
|
24 |
+
deep end
|
25 |
+
RELU
|
26 |
+
sigmoid
|
27 |
+
GELU
|
28 |
+
RNN
|
29 |
+
CNN
|
30 |
+
Gaussian
|
twc_clustering.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from scipy.spatial.distance import cosine
|
2 |
+
import argparse
|
3 |
+
import json
|
4 |
+
import pdb
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import numpy as np
|
8 |
+
import time
|
9 |
+
from collections import OrderedDict
|
10 |
+
|
11 |
+
|
12 |
+
class TWCClustering:
    """Greedy, threshold-based clustering over a cosine-similarity matrix.

    The similarity cut-off is expressed as a z-score: two items are considered
    neighbours when their cosine similarity is at least
    ``mean + threshold * std`` of all entries of the pairwise similarity matrix.
    """

    def __init__(self):
        print("In Zscore Clustering")

    def compute_matrix(self, embeddings):
        """Return the pairwise cosine-similarity matrix of ``embeddings``.

        Args:
            embeddings: Anything ``np.array`` accepts; treated as one vector
                per row.

        Returns:
            Matrix of inner products between the L2-normalised rows.
        """
        print("Computing similarity matrix ...)")
        embeddings = np.array(embeddings)
        start = time.time()
        vec_a = embeddings.T  # one column per embedding
        norms = np.linalg.norm(vec_a, axis=0)  # L2 norm of every embedding
        # An all-zero embedding would divide by zero and poison the matrix
        # mean/std with NaN; dividing by 1 leaves it as a zero vector instead.
        norms = np.where(norms == 0, 1.0, norms)
        vec_a = (vec_a / norms).T  # back to one row per embedding
        similarity_matrix = np.inner(vec_a, vec_a)
        end = time.time()
        time_val = (end - start) * 1000
        print(f"Similarity matrix computation complete. Time taken:{(time_val/(1000*60)):.2f} minutes")
        return similarity_matrix

    def get_terms_above_threshold(self, matrix, embeddings, pivot_index, threshold):
        """Collect items at or after ``pivot_index`` that are similar to the pivot.

        Only indices in ``[pivot_index, len(embeddings))`` are scanned.

        Returns:
            List of ``{"index": i}`` dicts for every scanned ``i`` with
            ``matrix[pivot_index][i] >= threshold``.
        """
        return [
            {"index": run_index}
            for run_index in range(pivot_index, len(embeddings))
            if matrix[pivot_index][run_index] >= threshold
        ]

    def update_picked_dict(self, picked_dict, in_dict):
        """Mark every key of ``in_dict`` as picked (value 1) in ``picked_dict``."""
        for key in in_dict:
            picked_dict[key] = 1

    def find_pivot_subgraph(self, pivot_index, arr, matrix, threshold):
        """Choose the best-connected candidate as the cluster centre.

        For every candidate in ``arr`` the similarities to the other candidates
        that reach ``threshold`` are summed; the candidate with the largest sum
        becomes the centre (ties keep the earliest candidate).

        Args:
            pivot_index: Index the candidate list was built from.
            arr: Candidates as produced by ``get_terms_above_threshold``.
            matrix: Pairwise similarity matrix.
            threshold: Minimum similarity for two candidates to count as linked.

        Returns:
            Dict with ``pivot_index`` (chosen centre), ``orig_index`` (the
            original pivot) and ``neighs`` (an ``OrderedDict`` mapping neighbour
            index to similarity, most similar first).
        """
        center_index = pivot_index
        center_score = 0
        center_dict = {}
        candidate_indices = [entry["index"] for entry in arr]
        for node_i_index in candidate_indices:
            running_score = 0
            temp_dict = {}
            for node_j_index in candidate_indices:
                similarity = matrix[node_i_index][node_j_index]
                if similarity < threshold:
                    continue
                running_score += similarity
                temp_dict[node_j_index] = similarity
            if running_score > center_score:
                center_index = node_i_index
                center_dict = temp_dict
                center_score = running_score
        sorted_d = OrderedDict(sorted(center_dict.items(), key=lambda kv: kv[1], reverse=True))
        return {"pivot_index": center_index, "orig_index": pivot_index, "neighs": sorted_d}

    def cluster(self, output_file, texts, embeddings, threshold=1.5):
        """Cluster ``embeddings`` by z-score-thresholded cosine similarity.

        Args:
            output_file: Not used by this method; kept for interface compatibility.
            texts: Not used by this method; kept for interface compatibility.
            embeddings: One embedding vector per item.
            threshold: Number of standard deviations above the mean similarity
                required for two items to be neighbours.

        Returns:
            Dict with ``clusters`` (one ``find_pivot_subgraph`` result per
            cluster) and ``info`` (``mean``, ``std`` and the list ``zscores`` of
            similarity values at 0, 1, 2, ... standard deviations above the
            mean that are still below 1).
        """
        matrix = self.compute_matrix(embeddings)
        mean = np.mean(matrix)
        std = np.std(matrix)
        zscores = []
        inc = 0
        value = mean
        while value < 1:
            zscores.append(round(value, 2))
            # With zero spread the value can never advance towards 1, so the
            # loop would never terminate (e.g. a single embedding whose
            # self-similarity rounds just below 1.0).
            if not std > 0:
                break
            inc += 1
            value = mean + inc * std
        print("In clustering:", round(std, 2), zscores)
        cluster_dict = {}
        cluster_dict["clusters"] = []
        picked_dict = {}

        # The cut-off does not depend on the item, so compute it once.
        zscore = mean + threshold * std
        for i in range(len(embeddings)):
            if i in picked_dict:
                continue
            arr = self.get_terms_above_threshold(matrix, embeddings, i, zscore)
            cluster_info = self.find_pivot_subgraph(i, arr, matrix, zscore)
            self.update_picked_dict(picked_dict, cluster_info["neighs"])
            cluster_dict["clusters"].append(cluster_info)
        cluster_dict["info"] = {"mean": mean, "std": std, "zscores": zscores}
        return cluster_dict
|
92 |
+
|
93 |
+
|
twc_embeddings.py
ADDED
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModel, AutoTokenizer
|
2 |
+
from transformers import AutoModelForCausalLM
|
3 |
+
from scipy.spatial.distance import cosine
|
4 |
+
import argparse
|
5 |
+
import json
|
6 |
+
import pdb
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
def read_text(input_file):
    """Read ``input_file`` and return its lines without newline characters.

    Blank lines inside the file are preserved. Only the empty element produced
    by a trailing newline is dropped, so a file whose last line has no newline
    still returns that last line.

    Args:
        input_file: Path of the text file to read.

    Returns:
        List of lines, in file order.
    """
    with open(input_file, encoding="utf-8") as fp:
        arr = fp.read().split("\n")
    # split() yields a trailing "" only when the text ends with a newline
    # (or the file is empty); real content in the last slot is kept.
    if arr and arr[-1] == "":
        arr.pop()
    return arr
|
13 |
+
|
14 |
+
|
15 |
+
class CausalLMModel:
    """Ranks documents against a query using a causal language model.

    The first input text is the query and the remaining texts are documents.
    Each document is placed into a fixed prompt and scored by the summed
    log-probability the model assigns to the query tokens that follow it.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.debug = False
        print("In CausalLMModel Constructor")

    def init_model(self, model_name=None):
        """Load tokenizer and causal LM; defaults to ``EleutherAI/gpt-neo-125M``.

        Args:
            model_name: Hugging Face model identifier, or ``None`` for the default.
        """
        if self.debug:
            print("Init model", model_name)
        # For best performance: EleutherAI/gpt-j-6B
        if model_name is None:
            model_name = "EleutherAI/gpt-neo-125M"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.model.eval()
        self.prompt = 'Documents are searched to find matches with the same content.\nThe document "{}" is a good search result for "'

    def compute_embeddings(self, input_data, is_file):
        """Score every document against the query.

        Args:
            input_data: Path to a text file (when ``is_file`` is ``True``) or a
                list of texts; element 0 is the query, the rest are documents.
            is_file: Whether ``input_data`` is a file path.

        Returns:
            Tuple ``(texts, scores)`` where ``scores[i]`` is the summed
            log-probability of the query tokens given the prompt built from
            ``texts[i + 1]``.
        """
        if self.debug:
            print("Computing embeddings for:", input_data[:20])
        model = self.model
        tokenizer = self.tokenizer

        texts = read_text(input_data) if is_file == True else input_data
        query = texts[0]
        docs = texts[1:]

        # The query is the same for every document, so tokenize it once.
        continuation_enc = tokenizer.encode(query, add_special_tokens=False)
        continuation_len = len(continuation_enc)
        continuation_ids = torch.tensor(continuation_enc).unsqueeze(-1)

        scores = []
        for doc in docs:
            context = self.prompt.format(doc)
            context_enc = tokenizer.encode(context, add_special_tokens=False)
            # Slice off the last token, as we take its probability from the one before
            model_input = torch.tensor(context_enc + continuation_enc[:-1])
            input_len, = model_input.shape

            # Inference only: skip building an autograd graph for every document.
            with torch.no_grad():
                # [seq_len] -> [seq_len, vocab]
                logprobs = torch.nn.functional.log_softmax(model(model_input)[0], dim=-1).cpu()
            # [seq_len, vocab] -> [continuation_len, vocab]
            logprobs = logprobs[input_len - continuation_len:]
            # Gather the log probabilities of the continuation tokens -> [continuation_len]
            logprobs = torch.gather(logprobs, 1, continuation_ids).squeeze(-1)
            score = torch.sum(logprobs)
            scores.append(score.tolist())
        return texts, scores

    def output_results(self, output_file, texts, scores, main_index=0):
        """Map each document to its score, best first, optionally saving as JSON.

        Args:
            output_file: Path to write the JSON result to, or ``None`` to skip.
            texts: Texts as returned by ``compute_embeddings`` (query first).
            scores: One score per document (``len(texts) - 1`` values).
            main_index: Index of the text printed as the input sentence in
                debug mode.

        Returns:
            Dict from document text to score, ordered by descending score.
            Identical document texts share one entry.
        """
        cosine_dict = {}
        docs = texts[1:]
        if self.debug:
            print("Total sentences", len(texts))
        assert(len(scores) == len(docs))
        for i in range(len(docs)):
            cosine_dict[docs[i]] = scores[i]

        if self.debug:
            print("Input sentence:", texts[main_index])
        sorted_dict = dict(sorted(cosine_dict.items(), key=lambda item: item[1], reverse=True))
        if self.debug:
            for key in sorted_dict:
                print("Document score for \"%s\" is: %.3f" % (key[:100], sorted_dict[key]))
        if output_file is not None:
            with open(output_file, "w") as fp:
                fp.write(json.dumps(sorted_dict, indent=0))
        return sorted_dict
|
88 |
+
|
89 |
+
|
90 |
+
class SGPTQnAModel:
    """Asymmetric query/document embeddings with an SGPT "specb" model.

    The first input text is the query and the remaining texts are documents.
    Queries are wrapped in ``[ ]`` tokens and documents in ``{ }`` tokens before
    position-weighted mean pooling of the last hidden state.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.debug = False
        print("In SGPT Q&A Constructor")

    def init_model(self, model_name=None):
        """Load tokenizer and model and cache the special bracket token ids.

        Args:
            model_name: Hugging Face model identifier; ``None`` selects
                ``Muennighoff/SGPT-125M-weightedmean-msmarco-specb-bitfit``.
        """
        if self.debug:
            print("Init model", model_name)
        if model_name is None:
            model_name = "Muennighoff/SGPT-125M-weightedmean-msmarco-specb-bitfit"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()
        self.SPECB_QUE_BOS = self.tokenizer.encode("[", add_special_tokens=False)[0]
        self.SPECB_QUE_EOS = self.tokenizer.encode("]", add_special_tokens=False)[0]

        self.SPECB_DOC_BOS = self.tokenizer.encode("{", add_special_tokens=False)[0]
        self.SPECB_DOC_EOS = self.tokenizer.encode("}", add_special_tokens=False)[0]

    def tokenize_with_specb(self, texts, is_query):
        """Tokenize ``texts`` and wrap each sequence in query or document brackets.

        Args:
            texts: List of strings to tokenize.
            is_query: ``True`` to use the query brackets, ``False`` for the
                document brackets.

        Returns:
            Padded batch (PyTorch tensors) as returned by ``tokenizer.pad``.
        """
        # Tokenize without padding
        batch_tokens = self.tokenizer(texts, padding=False, truncation=True)
        bos, eos = (
            (self.SPECB_QUE_BOS, self.SPECB_QUE_EOS)
            if is_query
            else (self.SPECB_DOC_BOS, self.SPECB_DOC_EOS)
        )
        # Add special brackets & pay attention to them
        for seq, att in zip(batch_tokens["input_ids"], batch_tokens["attention_mask"]):
            seq.insert(0, bos)
            seq.append(eos)
            att.insert(0, 1)
            att.append(1)
        # Add padding
        batch_tokens = self.tokenizer.pad(batch_tokens, padding=True, return_tensors="pt")
        return batch_tokens

    def get_weightedmean_embedding(self, batch_tokens, model):
        """Position-weighted mean pooling of ``model``'s last hidden state.

        Token ``i`` (1-based) gets weight ``i``; padded positions are masked out.

        Args:
            batch_tokens: Mapping with at least ``attention_mask``; it is
                unpacked into the ``model`` call.
            model: Callable returning an object with ``last_hidden_state`` of
                shape [bs, seq_len, hid_dim].

        Returns:
            Tensor of shape [bs, hid_dim].
        """
        # Get the embeddings
        with torch.no_grad():
            # Get hidden state of shape [bs, seq_len, hid_dim].
            # Use the model that was passed in rather than silently ignoring it.
            last_hidden_state = model(**batch_tokens, output_hidden_states=True, return_dict=True).last_hidden_state

        # Get weights of shape [bs, seq_len, hid_dim]
        weights = (
            torch.arange(start=1, end=last_hidden_state.shape[1] + 1)
            .unsqueeze(0)
            .unsqueeze(-1)
            .expand(last_hidden_state.size())
            .float().to(last_hidden_state.device)
        )

        # Get attn mask of shape [bs, seq_len, hid_dim]
        input_mask_expanded = (
            batch_tokens["attention_mask"]
            .unsqueeze(-1)
            .expand(last_hidden_state.size())
            .float()
        )

        # Perform weighted mean pooling across seq_len: bs, seq_len, hidden_dim -> bs, hidden_dim
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded * weights, dim=1)
        sum_mask = torch.sum(input_mask_expanded * weights, dim=1)

        embeddings = sum_embeddings / sum_mask

        return embeddings

    def compute_embeddings(self, input_data, is_file):
        """Embed the query (first text) and the documents (remaining texts).

        Args:
            input_data: Path to a text file (when ``is_file`` is ``True``) or a
                list of texts.
            is_file: Whether ``input_data`` is a file path.

        Returns:
            Tuple ``(texts, (query_embeddings, doc_embeddings))``.
        """
        if self.debug:
            print("Computing embeddings for:", input_data[:20])

        texts = read_text(input_data) if is_file == True else input_data

        queries = [texts[0]]
        docs = texts[1:]
        query_embeddings = self.get_weightedmean_embedding(self.tokenize_with_specb(queries, is_query=True), self.model)
        doc_embeddings = self.get_weightedmean_embedding(self.tokenize_with_specb(docs, is_query=False), self.model)
        return texts, (query_embeddings, doc_embeddings)

    def output_results(self, output_file, texts, embeddings, main_index=0):
        """Rank documents by cosine similarity to the query.

        Args:
            output_file: Path to write the JSON result to, or ``None`` to skip.
            texts: Texts as returned by ``compute_embeddings`` (query first).
            embeddings: Pair ``(query_embeddings, doc_embeddings)``.
            main_index: Index of the text printed as the input sentence in
                debug mode.

        Returns:
            Dict from document text to cosine similarity, most similar first.
            Identical document texts share one entry.
        """
        # Calculate cosine similarities
        # Cosine similarities are in [-1, 1]. Higher means more similar
        query_embeddings = embeddings[0]
        doc_embeddings = embeddings[1]
        cosine_dict = {}
        docs = texts[1:]
        if self.debug:
            print("Total sentences", len(texts))
        for i in range(len(docs)):
            cosine_dict[docs[i]] = 1 - cosine(query_embeddings[0], doc_embeddings[i])

        if self.debug:
            print("Input sentence:", texts[main_index])
        sorted_dict = dict(sorted(cosine_dict.items(), key=lambda item: item[1], reverse=True))
        if self.debug:
            for key in sorted_dict:
                print("Cosine similarity with \"%s\" is: %.3f" % (key, sorted_dict[key]))
        if output_file is not None:
            with open(output_file, "w") as fp:
                fp.write(json.dumps(sorted_dict, indent=0))
        return sorted_dict
|
202 |
+
|
203 |
+
|
204 |
+
class SimCSEModel:
    """Sentence embeddings taken from the pooler output of a SimCSE checkpoint."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.debug = False
        print("In SimCSE constructor")

    def init_model(self, model_name=None):
        """Load tokenizer and model.

        Args:
            model_name: Hugging Face model identifier; ``None`` selects
                ``princeton-nlp/sup-simcse-roberta-large``.
        """
        if model_name is None:
            model_name = "princeton-nlp/sup-simcse-roberta-large"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Explicit inference mode, matching the other model wrappers in this file.
        self.model.eval()

    def compute_embeddings(self, input_data, is_file):
        """Embed every input text.

        Args:
            input_data: Path to a text file (when ``is_file`` is ``True``) or a
                list of texts.
            is_file: Whether ``input_data`` is a file path.

        Returns:
            Tuple ``(texts, embeddings)`` where ``embeddings`` is the model's
            ``pooler_output`` for the batch.
        """
        texts = read_text(input_data) if is_file == True else input_data
        inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            embeddings = self.model(**inputs, output_hidden_states=True, return_dict=True).pooler_output
        return texts, embeddings

    def output_results(self, output_file, texts, embeddings, main_index=0):
        """Rank all texts by cosine similarity to ``texts[main_index]``.

        Args:
            output_file: Path to write the JSON result to, or ``None`` to skip.
            texts: The embedded texts.
            embeddings: One embedding per text, in the same order.
            main_index: Index of the reference text.

        Returns:
            Dict from text to cosine similarity, most similar first. Identical
            texts share one entry.
        """
        # Calculate cosine similarities
        # Cosine similarities are in [-1, 1]. Higher means more similar
        cosine_dict = {}
        for i in range(len(texts)):
            cosine_dict[texts[i]] = 1 - cosine(embeddings[main_index], embeddings[i])

        sorted_dict = dict(sorted(cosine_dict.items(), key=lambda item: item[1], reverse=True))
        if self.debug:
            for key in sorted_dict:
                print("Cosine similarity with \"%s\" is: %.3f" % (key, sorted_dict[key]))
        if output_file is not None:
            with open(output_file, "w") as fp:
                fp.write(json.dumps(sorted_dict, indent=0))
        return sorted_dict
|
242 |
+
|
243 |
+
|
244 |
+
|
245 |
+
class SGPTModel:
    """Symmetric sentence embeddings from an SGPT checkpoint.

    Embeddings are the position-weighted mean of the last hidden state, with
    padded positions masked out.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.debug = False
        print("In SGPT Constructor")

    def init_model(self, model_name=None):
        """Load tokenizer and model; ``None`` selects the 125M NLI checkpoint.

        The packages download the weights automatically.
        """
        if self.debug:
            print("Init model", model_name)
        checkpoint = model_name
        if checkpoint is None:
            checkpoint = "Muennighoff/SGPT-125M-weightedmean-nli-bitfit"
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.model = AutoModel.from_pretrained(checkpoint)
        # Switch to inference mode so dropout, if the checkpoint has any, is off.
        self.model.eval()

    def compute_embeddings(self, input_data, is_file):
        """Embed every input text.

        Args:
            input_data: Path to a text file (when ``is_file`` is ``True``) or a
                list of texts.
            is_file: Whether ``input_data`` is a file path.

        Returns:
            Tuple ``(texts, embeddings)`` with one pooled vector per text.
        """
        if self.debug:
            print("Computing embeddings for:", input_data[:20])

        if is_file == True:
            texts = read_text(input_data)
        else:
            texts = input_data

        encoded = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

        with torch.no_grad():
            # [bs, seq_len, hid_dim]
            hidden = self.model(**encoded, output_hidden_states=True, return_dict=True).last_hidden_state

        # Position weights 1..seq_len, broadcast to the hidden-state shape.
        seq_len = hidden.shape[1]
        position_weights = torch.arange(start=1, end=seq_len + 1)
        position_weights = position_weights.unsqueeze(0).unsqueeze(-1)
        position_weights = position_weights.expand(hidden.size()).float().to(hidden.device)

        # Attention mask broadcast to the hidden-state shape.
        mask = encoded["attention_mask"].unsqueeze(-1)
        mask = mask.expand(hidden.size()).float()

        # Weighted mean over seq_len: [bs, seq_len, hid_dim] -> [bs, hid_dim]
        weighted_sum = torch.sum(hidden * mask * position_weights, dim=1)
        weight_total = torch.sum(mask * position_weights, dim=1)

        return texts, weighted_sum / weight_total

    def output_results(self, output_file, texts, embeddings, main_index=0):
        """Rank all texts by cosine similarity to ``texts[main_index]``.

        Similarities lie in [-1, 1]; higher means more similar. The result is
        optionally written to ``output_file`` as JSON.

        Returns:
            Dict from text to cosine similarity, most similar first.
        """
        if self.debug:
            print("Total sentences", len(texts))

        similarity_by_text = {
            sentence: 1 - cosine(embeddings[main_index], embeddings[position])
            for position, sentence in enumerate(texts)
        }

        if self.debug:
            print("Input sentence:", texts[main_index])

        ranked_pairs = sorted(similarity_by_text.items(), key=lambda pair: pair[1], reverse=True)
        sorted_dict = dict(ranked_pairs)

        if self.debug:
            for sentence, similarity in sorted_dict.items():
                print("Cosine similarity with \"%s\" is: %.3f" % (sentence, similarity))

        if output_file is not None:
            payload = json.dumps(sorted_dict, indent=0)
            with open(output_file, "w") as fp:
                fp.write(payload)
        return sorted_dict
|
328 |
+
|
329 |
+
|
330 |
+
|
331 |
+
|
332 |
+
|
333 |
+
class HFModel:
    """Sentence embeddings from a Hugging Face encoder.

    Embeddings are the attention-masked mean of the token embeddings,
    L2-normalised.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.debug = False
        print("In HF Constructor")

    def init_model(self, model_name=None):
        """Load tokenizer and model; ``None`` selects ``all-MiniLM-L6-v2``.

        The packages download the weights automatically.
        """
        checkpoint = model_name
        if checkpoint is None:
            checkpoint = "sentence-transformers/all-MiniLM-L6-v2"
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        self.model = AutoModel.from_pretrained(checkpoint)
        self.model.eval()

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over the positions selected by the mask.

        Args:
            model_output: Model output whose first element holds the token
                embeddings.
            attention_mask: Mask with one entry per token position.

        Returns:
            Masked mean over dimension 1 of the token embeddings.
        """
        token_embeddings = model_output[0]
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        masked_sum = torch.sum(token_embeddings * mask, 1)
        # Clamp so a fully masked row does not divide by zero.
        token_counts = torch.clamp(mask.sum(1), min=1e-9)
        return masked_sum / token_counts

    def compute_embeddings(self, input_data, is_file):
        """Embed every input text.

        Args:
            input_data: Path to a text file (when ``is_file`` is ``True``) or a
                list of texts.
            is_file: Whether ``input_data`` is a file path.

        Returns:
            Tuple ``(texts, sentence_embeddings)`` with L2-normalised rows.
        """
        if is_file == True:
            texts = read_text(input_data)
        else:
            texts = input_data

        encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt')

        # Token embeddings, computed without tracking gradients.
        with torch.no_grad():
            model_output = self.model(**encoded_input)

        pooled = self.mean_pooling(model_output, encoded_input['attention_mask'])
        sentence_embeddings = F.normalize(pooled, p=2, dim=1)
        return texts, sentence_embeddings

    def output_results(self, output_file, texts, embeddings, main_index=0):
        """Rank all texts by cosine similarity to ``texts[main_index]``.

        Similarities lie in [-1, 1]; higher means more similar. The result is
        optionally written to ``output_file`` as JSON.

        Returns:
            Dict from text to cosine similarity, most similar first.
        """
        similarity_by_text = {
            sentence: 1 - cosine(embeddings[main_index], embeddings[position])
            for position, sentence in enumerate(texts)
        }

        ranked_pairs = sorted(similarity_by_text.items(), key=lambda pair: pair[1], reverse=True)
        sorted_dict = dict(ranked_pairs)

        if self.debug:
            for sentence, similarity in sorted_dict.items():
                print("Cosine similarity with \"%s\" is: %.3f" % (sentence, similarity))

        if output_file is not None:
            payload = json.dumps(sorted_dict, indent=0)
            with open(output_file, "w") as fp:
                fp.write(payload)
        return sorted_dict
|
394 |
+
|
395 |
+
|
396 |
+
|
397 |
+
if __name__ == '__main__':
    # Command-line entry point: embed the sentences of an input file with
    # HFModel and write the similarities to the first sentence as JSON.
    cli = argparse.ArgumentParser(
        description='SGPT model for sentence embeddings ',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    cli.add_argument(
        '-input',
        action="store",
        dest="input",
        required=True,
        help="Input file with sentences",
    )
    cli.add_argument(
        '-output',
        action="store",
        dest="output",
        default="output.txt",
        help="Output file with results",
    )
    cli.add_argument(
        '-model',
        action="store",
        dest="model",
        default="sentence-transformers/all-MiniLM-L6-v2",
        help="model name",
    )

    options = cli.parse_args()
    obj = HFModel()
    obj.init_model(options.model)
    texts, embeddings = obj.compute_embeddings(options.input, is_file=True)
    results = obj.output_results(options.output, texts, embeddings)
|