# paper-matching / score.py
# (repository page header: jskim · "init files" · commit 6eff5e7 · 5.02 kB)
from sentence_transformers import util
from nltk.tokenize import sent_tokenize
import torch
import numpy as np
def compute_sentencewise_scores(model, query_sents, candidate_sents):
    """
    Score query sentences against candidate sentences.

    `query_sents` and `candidate_sents` are lists of sentences. Both are
    embedded through `get_embedding`, and the result of `util.cos_sim` on
    the (query, candidate) embeddings is returned.
    """
    query_vecs, candidate_vecs = get_embedding(model, query_sents, candidate_sents)
    similarity = util.cos_sim(query_vecs, candidate_vecs)
    return similarity
def get_embedding(model, query_sents, candidate_sents):
    """
    Encode both sentence lists with `model.encode`.

    Returns a `(query_embeddings, candidate_embeddings)` tuple; the query
    sentences are encoded first, then the candidate sentences.
    """
    return tuple(
        model.encode(sentences)
        for sentences in (query_sents, candidate_sents)
    )
def get_top_k(score_mat, K=3):
    """
    Pick top K sentences to show

    Args:
        score_mat: 2-D torch tensor of scores. Each row is ranked
            independently along its last axis.
        K: maximum number of columns to keep per row. If a row has fewer
            than K columns, all of them are kept.

    Returns:
        picked_sent: for each row, the column indices of the highest
            scores, in descending score order.
        picked_scores: the scores at those indices, same shape as
            `picked_sent`.
    """
    # Negating the scores turns the ascending argsort into a descending ranking.
    ranked = torch.argsort(-score_mat)
    picked_sent = ranked[:, :K]
    # gather looks the scores up in one vectorized call and, unlike stacking
    # a per-row list, also works when score_mat has zero rows.
    picked_scores = torch.gather(score_mat, 1, picked_sent)
    return picked_sent, picked_scores
def get_words(sent):
    """
    Split each sentence into whitespace-separated words.

    Args:
        sent: list of sentence strings.

    Returns:
        words: list of word lists, one per sentence.
        all_words: flat list of every word across all sentences, in order.
        sent_start_id: for each sentence, the index in `all_words` where
            that sentence's first word sits. Has the same length as `sent`
            (empty when `sent` is empty).
    """
    words = []
    sent_start_id = []  # keep track of the word index where the new sentence starts
    next_start = 0
    for sentence in sent:
        tokens = sentence.split()
        # Record the offset before advancing, so no trailing entry has to be
        # popped afterwards (which failed on an empty sentence list).
        sent_start_id.append(next_start)
        words.append(tokens)
        next_start += len(tokens)
    all_words = [word for tokens in words for word in tokens]
    return words, all_words, sent_start_id
def mark_words(words, all_words, sent_start_id, sent_ids, sent_scores):
    """
    Build per-word highlight information for every query sentence.

    Args:
        words: list of word lists, one per candidate sentence.
        all_words: flat list of all candidate words.
        sent_start_id: for each candidate sentence, the index in
            `all_words` of its first word.
        sent_ids: 2-D array-like (must expose `.shape`); row i holds the
            candidate sentence indices selected for query sentence i.
        sent_scores: scores matching `sent_ids` element for element.

    Returns:
        dict with 'all_words', 'words_by_sentence', and one entry per query
        sentence index i holding numpy arrays of length `len(all_words)`:
        'is_selected_sent' (1 for words in a selected sentence),
        'is_selected_phrase' (currently all zeros), and 'scores' (the
        sentence score for selected words, 1e-4 elsewhere).
    """
    num_query_sent = sent_ids.shape[0]
    num_words = len(all_words)
    num_sents = len(sent_start_id)

    output = dict()
    output['all_words'] = all_words
    output['words_by_sentence'] = words

    # for each query sentence, mark the highlight information
    for i in range(num_query_sent):
        is_selected_sent = np.zeros(num_words)
        is_selected_phrase = np.zeros(num_words)
        word_scores = np.zeros(num_words) + 1e-4

        # get sentence selection information
        for sid, sscore in zip(sent_ids[i], sent_scores[i]):
            sid = int(sid)
            start = sent_start_id[sid]
            # The last sentence has no following start offset, so it runs to
            # the end of the word list.
            end = sent_start_id[sid + 1] if sid + 1 < num_sents else num_words
            is_selected_sent[start:end] = 1
            word_scores[start:end] = float(sscore)

        # TODO get phrase selection information
        output[i] = {
            'is_selected_sent': is_selected_sent,
            'is_selected_phrase': is_selected_phrase,
            'scores': word_scores
        }
    return output
def get_highlight_info(model, text1, text2, K=3):
    """
    Compute sentence-level highlight information between two texts.

    `text1` is treated as the query and `text2` as the candidate. Both are
    split with `sent_tokenize`, every query sentence is scored against every
    candidate sentence, and the top `K` candidate sentences per query
    sentence are marked word by word on the candidate text.

    Returns `(sent_ids, sent_scores, info)`: the outputs of `get_top_k`
    followed by the dict produced by `mark_words`.
    """
    query_sentences = sent_tokenize(text1)
    candidate_sentences = sent_tokenize(text2)

    similarity = compute_sentencewise_scores(
        model, query_sentences, candidate_sentences
    )
    sent_ids, sent_scores = get_top_k(similarity, K=K)

    # Highlights are placed on the candidate text, so only its words are needed.
    cand_words, cand_all_words, cand_start_ids = get_words(candidate_sentences)
    info = mark_words(
        cand_words, cand_all_words, cand_start_ids, sent_ids, sent_scores
    )
    return sent_ids, sent_scores, info
## Document-level operations
def predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=20):
    """
    Score documents against a query by cosine similarity of embeddings.

    Each title/abstract pair is joined as "title [SEP] abstract"; pairs in
    which either part is None are skipped, so the returned scores line up
    with the remaining pairs rather than with the input lists. Documents are
    processed `batch` at a time, and the query is tokenized and embedded
    together with every batch.

    Returns a list with one cosine-similarity score per kept document.
    """
    # concatenate title and abstract
    title_abs = [
        t + ' [SEP] ' + a
        for t, a in zip(titles, abstracts)
        if t is not None and a is not None
    ]
    num_docs = len(title_abs)
    no_iter = int(np.ceil(num_docs / batch))

    scores = []
    with torch.no_grad():
        for batch_idx in range(no_iter):
            start = batch_idx * batch
            stop = start + batch
            # preprocess the input: the query goes first, then this batch's documents
            inputs = tokenizer(
                [query] + title_abs[start:stop],
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            )
            inputs.to(doc_model.device)
            result = doc_model(**inputs)

            # take the first token position of each sequence as its embedding
            embeddings = result.last_hidden_state[:, 0, :].detach().cpu().numpy()

            # compute cosine similarity between the query (row 0) and the documents
            query_emb = embeddings[0, :]
            doc_embs = embeddings[1:, :]
            norm_product = np.linalg.norm(query_emb) * np.linalg.norm(doc_embs, axis=1)
            scores.extend(np.dot(doc_embs, query_emb) / norm_product)

    assert(len(scores) == num_docs)
    return scores
def compute_overall_score(doc_model, tokenizer, query, papers, batch=5):
    """
    Rank papers by their document-level similarity to the query.

    Args:
        doc_model, tokenizer: passed through to `predict_docscore`.
        query: the query text.
        papers: iterable of mappings with 'title' and 'abstract' keys.
        batch: number of papers scored per model call.

    Returns:
        (titles_sorted, abstracts_sorted, scores_sorted), ordered from the
        highest score to the lowest. Papers whose title or abstract is None
        are left out of all three lists.
    """
    titles = []
    abstracts = []
    for p in papers:
        title = p['title']
        abstract = p['abstract']
        # predict_docscore skips pairs with a missing title or abstract, so
        # they are dropped here too; otherwise the score indices below would
        # point at the wrong papers.
        if title is None or abstract is None:
            continue
        titles.append(title)
        abstracts.append(abstract)

    scores = predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=batch)

    # argsort is ascending; reverse it to list the best match first
    idx_sorted = np.argsort(scores)[::-1]
    titles_sorted = [titles[x] for x in idx_sorted]
    abstracts_sorted = [abstracts[x] for x in idx_sorted]
    scores_sorted = [scores[x] for x in idx_sorted]
    return titles_sorted, abstracts_sorted, scores_sorted