from sentence_transformers import util
from nltk.tokenize import sent_tokenize
import torch
import numpy as np

def compute_sentencewise_scores(model, query_sents, candidate_sents):
    """
    Compute pairwise cosine similarity between query and candidate sentences.
    """
    q_v, c_v = get_embedding(model, query_sents, candidate_sents)
    return util.cos_sim(q_v, c_v)

def get_embedding(model, query_sents, candidate_sents):
    q_v = model.encode(query_sents)
    c_v = model.encode(candidate_sents)
    return q_v, c_v

def get_top_k(score_mat, K=3):
    """
    Pick top K sentences to show
    """
    idx = torch.argsort(score_mat, descending=True)
    picked_sent = idx[:, :K]
    picked_scores = torch.vstack(
        [score_mat[i, picked_sent[i]] for i in range(picked_sent.shape[0])]
    )
    return picked_sent, picked_scores

def get_words(sent):
    """
    Split each sentence into words and record the index in the flattened
    word list where each sentence starts.
    """
    words = [x.split() for x in sent]            # per-sentence word lists
    all_words = [w for ws in words for w in ws]  # flattened word list
    # cumulative word counts give the start index of each sentence
    sent_start_id = [0]
    for ws in words[:-1]:
        sent_start_id.append(sent_start_id[-1] + len(ws))
    assert len(sent_start_id) == len(sent)
    return words, all_words, sent_start_id

def mark_words(words, all_words, sent_start_id, sent_ids, sent_scores):
    """
    For each query sentence, flag the selected candidate words and assign
    each word a relevance score.
    """
    num_query_sent = sent_ids.shape[0]
    num_words = len(all_words)
    output = dict()
    output['all_words'] = all_words
    output['words_by_sentence'] = words
    # for each query sentence, mark the highlight information
    for i in range(num_query_sent):
        is_selected_sent = np.zeros(num_words)
        is_selected_phrase = np.zeros(num_words)
        word_scores = np.zeros(num_words) + 1e-4
        # get sentence selection information
        for sid, sscore in zip(sent_ids[i], sent_scores[i]):
            sid, sscore = int(sid), float(sscore)  # torch scalars -> Python numbers
            start = sent_start_id[sid]
            if sid + 1 < len(sent_start_id):
                end = sent_start_id[sid + 1]
                is_selected_sent[start:end] = 1
                word_scores[start:end] = sscore
            else:
                # last sentence runs to the end of the text
                is_selected_sent[start:] = 1
                word_scores[start:] = sscore
        # TODO get phrase selection information
        output[i] = {
            'is_selected_sent': is_selected_sent,
            'is_selected_phrase': is_selected_phrase,
            'scores': word_scores
        }
    return output

def get_highlight_info(model, text1, text2, K=3):
    """
    Sentence-tokenize the query and candidate texts, score every sentence
    pair, and return word-level highlight information for the candidate.
    """
    sent1 = sent_tokenize(text1)  # query sentences
    sent2 = sent_tokenize(text2)  # candidate sentences
    score_mat = compute_sentencewise_scores(model, sent1, sent2)
    sent_ids, sent_scores = get_top_k(score_mat, K=K)
    # word-level bookkeeping is done on the candidate, which gets highlighted
    words, all_words, sent_start_id = get_words(sent2)
    info = mark_words(words, all_words, sent_start_id, sent_ids, sent_scores)
    return sent_ids, sent_scores, info
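# Usage sketch (illustrative, not part of the original pipeline): ties the
# functions above together. The SentenceTransformer checkpoint name is an
# assumption; any model exposing `.encode()` should work. `sent_tokenize`
# needs NLTK's 'punkt' data (newer NLTK versions may also ask for 'punkt_tab').
def _demo_highlight():
    from sentence_transformers import SentenceTransformer
    import nltk
    nltk.download('punkt', quiet=True)  # one-time tokenizer data download
    model = SentenceTransformer('all-MiniLM-L6-v2')  # assumed checkpoint
    query = "We propose a contrastive objective for dense retrieval."
    candidate = (
        "This paper studies dense passage retrieval. "
        "A contrastive loss is used to train the encoder. "
        "Experiments show strong results on open-domain QA."
    )
    sent_ids, sent_scores, info = get_highlight_info(model, query, candidate, K=2)
    # For query sentence i, info[i]['scores'] holds one score per candidate
    # word and info[i]['is_selected_sent'] flags the words to highlight.
    print(sent_ids, sent_scores)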
## Document-level operations

def predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=20):
    """
    Score each document against the query via cosine similarity of [CLS]
    embeddings. Documents missing a title or abstract are skipped, so the
    caller should ensure every paper has both fields.
    """
    # concatenate title and abstract
    title_abs = []
    for t, a in zip(titles, abstracts):
        if t is not None and a is not None:
            title_abs.append(t + ' [SEP] ' + a)
    num_docs = len(title_abs)
    no_iter = int(np.ceil(num_docs / batch))
    scores = []
    with torch.no_grad():
        # process documents in batches, re-encoding the query alongside each batch
        for i in range(no_iter):
            inputs = tokenizer(
                [query] + title_abs[i * batch:(i + 1) * batch],
                padding=True, truncation=True, return_tensors="pt", max_length=512
            )
            inputs = inputs.to(doc_model.device)
            result = doc_model(**inputs)
            # take the first ([CLS]) token of each sequence as its embedding
            embeddings = result.last_hidden_state[:, 0, :].detach().cpu().numpy()
            # cosine similarity between the query and each document
            q_emb = embeddings[0, :]
            p_emb = embeddings[1:, :]
            nn = np.linalg.norm(q_emb) * np.linalg.norm(p_emb, axis=1)
            scores += list(np.dot(p_emb, q_emb) / nn)
    assert len(scores) == num_docs
    return scores

def compute_overall_score(doc_model, tokenizer, query, papers, batch=5):
    """
    Score the papers against the query and return titles, abstracts, and
    scores sorted by descending relevance.
    """
    titles = [p['title'] for p in papers]
    abstracts = [p['abstract'] for p in papers]
    scores = predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=batch)
    idx_sorted = np.argsort(scores)[::-1]  # descending by score
    titles_sorted = [titles[x] for x in idx_sorted]
    abstracts_sorted = [abstracts[x] for x in idx_sorted]
    scores_sorted = [scores[x] for x in idx_sorted]
    return titles_sorted, abstracts_sorted, scores_sorted
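# Usage sketch (illustrative): the '[SEP]'-joined title+abstract input and
# [CLS] pooling above match the SPECTER recipe, so 'allenai/specter' is a
# plausible checkpoint, but that choice is an assumption; any Hugging Face
# encoder that returns `last_hidden_state` should run.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('allenai/specter')  # assumed checkpoint
    doc_model = AutoModel.from_pretrained('allenai/specter')
    papers = [  # synthetic placeholder data
        {'title': 'Example Paper A', 'abstract': 'A study of dense retrieval methods.'},
        {'title': 'Example Paper B', 'abstract': 'An analysis of convolutional networks.'},
    ]
    titles, abstracts, scores = compute_overall_score(
        doc_model, tokenizer, 'dense retrieval with learned embeddings', papers
    )
    for t, s in zip(titles, scores):
        print(f'{s:.3f}  {t}')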