paper-matching / score.py
jskim's picture
better output formatting (removing spaces around punctuations)
5b4e16a
raw
history blame
11.3 kB
from sentence_transformers import util
from nltk.tokenize import sent_tokenize
from nltk import word_tokenize, pos_tag
import torch
import numpy as np
import tqdm
def compute_sentencewise_scores(model, query_sents, candidate_sents):
    """
    Score the query sentences against the candidate sentences.

    Input: sentence encoder, list of query sentences, list of candidate sentences
    Output: whatever util.cos_sim returns for the two sets of embeddings
    (presumably a query-by-candidate similarity matrix -- confirm against
    the sentence_transformers documentation)
    """
    query_vecs, cand_vecs = get_embedding(model, query_sents, candidate_sents)
    similarity = util.cos_sim(query_vecs, cand_vecs)
    return similarity
def get_embedding(model, query_sents, candidate_sents):
    """
    Encode both sentence lists with the same model.

    Input: object exposing encode(), list of query sentences, list of candidate sentences
    Output: (query embeddings, candidate embeddings); the query is encoded first
    """
    return model.encode(query_sents), model.encode(candidate_sents)
def get_top_k(score_mat, K=3):
    """
    Pick top K sentences to show

    Input: 2-D torch tensor of scores (one row per query sentence), number of picks K
    Output: (picked_sent, picked_scores) -- for every row, the column indices of
    the K highest scores (best first) and the scores found at those indices.
    A row with fewer than K columns simply yields all of its columns.
    """
    # negate so that the ascending sort lists the highest score first
    idx = torch.argsort(-score_mat)
    picked_sent = idx[:, :K]
    # gather reads every row's picks in one call; unlike stacking a Python
    # list of rows it also works when score_mat has no rows at all
    picked_scores = torch.gather(score_mat, 1, picked_sent)
    return picked_sent, picked_scores
def get_words(sent):
    """
    Input: list of sentences
    Output:
      words         - list of word lists, one per sentence
      all_words     - the same words flattened into a single list
      sent_start_id - for each sentence, the index in all_words of its first word

    An empty sentence list gives three empty lists.
    """
    # tokenize each sentence exactly once
    words = [word_tokenize(x) for x in sent]
    all_words = [item for sublist in words for item in sublist]

    # running word count = index where the next sentence starts
    sent_start_id = []
    offset = 0
    for w in words:
        sent_start_id.append(offset)
        offset += len(w)
    return words, all_words, sent_start_id
def get_match_phrase(w1, w2, method='pos'):
    """
    Input: list of words for query (w1) and candidate (w2) text
    Output: (mask1, mask2) -- numpy arrays with one entry per word of w1 / w2,
    set to 1 where a matching word was found and 0 elsewhere

    With method='pos', a candidate word matches when its POS tag is in the
    list below and the same word (compared case-insensitively) occurs in the
    query; the first occurrence in the query is the one that gets marked.
    Any other method value returns all-zero masks.
    """
    mask1 = np.zeros(len(w1))
    mask2 = np.zeros(len(w2))
    # nothing can match: skip the tagger entirely
    if method != 'pos' or len(w1) == 0 or len(w2) == 0:
        return mask1, mask2

    # POS tags that should be considered for matching phrase
    include = {
        'NN',
        'NNS',
        'NNP',
        'NNPS',
        'LS',
        'SYM',
        'FW'
    }

    # lower-cased query word -> index of its first occurrence, so that both
    # sides are compared in the same case (capitalised query words used to
    # be unmatchable because only the candidate side was lower-cased)
    first_pos = {}
    for j, q in enumerate(w1):
        first_pos.setdefault(q.lower(), j)

    for i, (w, p) in enumerate(pos_tag(w2)):
        j = first_pos.get(w.lower())
        if j is not None and p in include:
            mask2[i] = 1
            mask1[j] = 1
    return mask1, mask2
def remove_spaces(words, attrs):
    """
    Make the output more readable by removing unnecessary spacings from the tokenizer

    Input: list of tokens and a same-length sequence of per-token attributes
    Output: (word_out, attr_out) -- merged tokens and one attribute per merged token

    Merging rules:
      1. , . % ) : ? ; and 's are appended to the token before them
      2. ( is prepended to the token after it
      3. quotes alternate: an opening quote is prepended to the next token,
         the closing quote is appended to the previous one
      4. a single quote right after a word ending in 's' is treated as a
         possessive and appended to that word
    A merged token keeps the attribute of the word it was attached to.
    A token with nothing to attach to (e.g. punctuation at the very start,
    or "(" at the very end) is kept as a token of its own.

    NOTE(review): NLTK's Treebank-style tokenizer is documented to rewrite
    double quotes as `` and ''; if the tokens come from word_tokenize they
    would not reach the '"' branch below -- confirm.
    """
    assert(len(words) == len(attrs))
    word_out, attr_out = [], []

    def glue_to_previous(i):
        # append token i to the last emitted token; returns the next index
        if word_out:
            word_out[-1] = word_out[-1] + words[i]
        else:
            # nothing emitted yet: keep the token on its own instead of failing
            word_out.append(words[i])
            attr_out.append(attrs[i])
        return i + 1

    def glue_to_next(i):
        # prepend token i to token i+1 and emit both as one; returns the next index
        if i + 1 < len(words):
            word_out.append(words[i] + words[i + 1])
            attr_out.append(attrs[i + 1])
            return i + 2
        # token i is the last one: keep it on its own instead of failing
        word_out.append(words[i])
        attr_out.append(attrs[i])
        return i + 1

    stick_left = (',', '.', '%', ')', ':', '?', ';', "'s")
    idx, single_q, double_q = 0, 0, 0
    while idx < len(words):
        tok = words[idx]
        if tok in stick_left:
            idx = glue_to_previous(idx)
        elif tok == "(":
            idx = glue_to_next(idx)
        elif tok == '"':
            double_q += 1
            if double_q == 2:
                # this is closing quote: stick to word before
                idx = glue_to_previous(idx)
                double_q = 0
            else:
                # this is opening quote: stick to the word after
                idx = glue_to_next(idx)
        elif tok == "'":
            single_q += 1
            if single_q == 2:
                # this is closing quote: stick to word before
                idx = glue_to_previous(idx)
                single_q = 0
            elif idx > 0 and words[idx - 1].endswith('s'):
                # possessive quote: stick to the word before, reset counter.
                # Only looked up when a previous word exists, so index 0 no
                # longer wraps around to the last word.
                idx = glue_to_previous(idx)
                single_q = 0
            else:
                # this is opening quote: stick to the word after
                idx = glue_to_next(idx)
        else:
            word_out.append(tok)
            attr_out.append(attrs[idx])
            idx += 1
    return word_out, attr_out
def mark_words(query_sents, words, all_words, sent_start_id, sent_ids, sent_scores):
    """
    Mark the candidate words to highlight, both at sentence and at phrase level

    Input:
      query_sents   - list of query sentences
      words         - candidate words grouped by sentence
      all_words     - flat list of candidate words
      sent_start_id - index in all_words where each candidate sentence starts
      sent_ids      - selected candidate sentence ids, one row per query sentence
      sent_scores   - scores paired element-wise with sent_ids
    Output: dict with 'all_words', 'words_by_sentence' and, for every query
    sentence index i, a dict of three arrays of length len(all_words):
    'is_selected_sent', 'is_selected_phrase' and 'scores'.
    """
    total_words = len(all_words)
    n_starts = len(sent_start_id)
    output = {'all_words': all_words, 'words_by_sentence': words}

    # one highlight record per query sentence
    for i in range(sent_ids.shape[0]):
        query_words = word_tokenize(query_sents[i])
        is_selected_sent = np.zeros(total_words)
        is_selected_phrase = np.zeros(total_words)
        word_scores = np.zeros(total_words)

        # compile information for each candidate sentence selected for this query
        for sid, sscore in zip(sent_ids[i], sent_scores[i]):
            begin = sent_start_id[sid]
            # the last sentence has no successor: it runs to the end of all_words
            end = sent_start_id[sid + 1] if sid + 1 < n_starts else total_words
            is_selected_sent[begin:end] = 1
            word_scores[begin:end] = sscore
            _, phrase_mask = get_match_phrase(query_words, all_words[begin:end])
            is_selected_phrase[begin:end] = phrase_mask

        # words that are both in a selected sentence and in a matched phrase
        # get a negative score (presumably shown in a different color in gradio)
        word_scores[is_selected_sent + is_selected_phrase == 2] = -0.5

        output[i] = {
            'is_selected_sent': is_selected_sent,
            'is_selected_phrase': is_selected_phrase,
            'scores': word_scores
        }
    return output
def get_highlight_info(model, text1, text2, K=None, top_pair_num=5):
    """
    Get highlight information from two texts

    Input:
      model        - sentence encoder passed on to compute_sentencewise_scores
      text1        - query text
      text2        - candidate text
      K            - number of candidate sentences to select per query sentence;
                     when None, a third of the candidate's sentence count is
                     used, but never less than one
      top_pair_num - how many of the best (query, candidate) sentence pairs to report
    Output: (sent_ids, sent_scores, info, top_pairs_info)
      sent_ids, sent_scores - the selection returned by get_top_k
      info                  - per-word highlight data returned by mark_words
      top_pairs_info        - dict keyed 0, 1, ... (best pair first) with the
                              sentences, their (word, value) highlight lists,
                              the pair score and the sentence index pair
    """
    sent1 = sent_tokenize(text1)  # query
    sent2 = sent_tokenize(text2)  # candidate
    if K is None:
        # a candidate shorter than three sentences would otherwise get K = 0,
        # i.e. no selected sentence and therefore no highlight at all
        K = max(1, len(sent2) // 3)

    score_mat = compute_sentencewise_scores(model, sent1, sent2)
    sent_ids, sent_scores = get_top_k(score_mat, K=K)
    words2, all_words2, sent_start_id2 = get_words(sent2)
    info = mark_words(sent1, words2, all_words2, sent_start_id2, sent_ids, sent_scores)

    # get top sentence pairs from the query and candidate, best first.
    # Reversing before slicing keeps the count correct even for top_pair_num = 0.
    best_flat = np.argsort(np.array(sent_scores).ravel())[::-1][:top_pair_num]
    rows, cols = np.unravel_index(best_flat, sent_scores.shape)
    # list of (score, (sent_id_query, sent_id_candidate))
    top_pairs = [
        (sent_scores[i, j].item(), (i, sent_ids[i, j].item()))
        for i, j in zip(rows, cols)
    ]

    # convert top_pairs to corresponding highlights format for GRadio Interpretation component
    sc = 0.5
    top_pairs_info = dict()
    for count, (s, (sidq, sidc)) in enumerate(top_pairs):
        q_sent = sent1[sidq]
        c_sent = sent2[sidc]
        q_words = word_tokenize(q_sent)
        c_words = word_tokenize(c_sent)
        mask1, mask2 = get_match_phrase(q_words, c_words)
        mask1 *= -sc  # mark matching phrases as blue (-1: darkest)
        mask2 *= -sc  # mark matching phrases as blue

        # spacing
        q_words, mask1 = remove_spaces(q_words, mask1)
        c_words, mask2 = remove_spaces(c_words, mask2)

        top_pairs_info[count] = {
            'query': {
                'original': q_sent,
                'interpretation': list(zip(q_words, mask1))
            },
            'candidate': {
                'original': c_sent,
                'interpretation': list(zip(c_words, mask2))
            },
            'score': s,
            'sent_idx': (sidq, sidc)
        }
    return sent_ids, sent_scores, info, top_pairs_info
### Document-level operations
def predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=20):
    """
    Compute a document score for each paper against the query.

    Input:
      doc_model, tokenizer - document encoder and its tokenizer
      query                - query text
      titles, abstracts    - parallel lists; a pair is skipped when either is None
      batch                - number of papers encoded per forward pass
    Output: list with one cosine similarity per kept paper, in input order
    """
    # concatenate title and abstract
    title_abs = [
        t + ' [SEP] ' + a
        for t, a in zip(titles, abstracts)
        if t is not None and a is not None
    ]
    num_docs = len(title_abs)
    no_iter = int(np.ceil(num_docs / batch))

    scores = []
    with torch.no_grad():
        for step in tqdm.tqdm(range(no_iter)):
            lo, hi = step * batch, (step + 1) * batch
            # the query goes first so that row 0 of the output is its embedding
            inputs = tokenizer(
                [query] + title_abs[lo:hi],
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            )
            inputs.to(doc_model.device)
            result = doc_model(**inputs)

            # take the first token of every sequence as its embedding
            embeddings = result.last_hidden_state[:, 0, :].detach().cpu().numpy()

            # cosine similarity between the query (row 0) and the papers (remaining rows)
            q_emb, p_emb = embeddings[0, :], embeddings[1:, :]
            denom = np.linalg.norm(q_emb) * np.linalg.norm(p_emb, axis=1)
            scores.extend(np.dot(p_emb, q_emb) / denom)

    assert(len(scores) == num_docs)
    return scores
def compute_document_score(doc_model, tokenizer, query, papers, batch=5):
    """
    Score papers against the query and return them ordered by descending score.

    Input: document encoder, tokenizer, query text, list of paper dicts
    (read keys: 'title', 'abstract', 'url'), batch size for predict_docscore
    Output: (titles, abstracts, urls, scores), all sorted by descending score;
    papers whose title or abstract is None are left out
    """
    usable = [p for p in papers if p['title'] is not None and p['abstract'] is not None]
    titles = [p['title'] for p in usable]
    abstracts = [p['abstract'] for p in usable]
    urls = [p['url'] for p in usable]

    scores = predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=batch)
    assert(len(scores) == len(abstracts))

    # ascending argsort, reversed -> best paper first
    order = np.argsort(scores)[::-1]
    titles_sorted = [titles[k] for k in order]
    abstracts_sorted = [abstracts[k] for k in order]
    urls_sorted = [urls[k] for k in order]
    scores_sorted = [scores[k] for k in order]
    return titles_sorted, abstracts_sorted, urls_sorted, scores_sorted