"""Utils for processing and encoding text."""

import torch



def lemmatize_verbs(verbs: list) -> list:
    """Lemmatize a list of verbs to their base form (e.g. 'running' -> 'run')."""
    from nltk.stem import WordNetLemmatizer  # lazy import; requires the WordNet corpus
    wnl = WordNetLemmatizer()
    return [wnl.lemmatize(verb, 'v') for verb in verbs]


def lemmatize_adverbs(adverbs: list) -> list:
    """Lemmatize a list of adverbs to their base form ('r' is WordNet's adverb tag)."""
    from nltk.stem import WordNetLemmatizer  # lazy import; requires the WordNet corpus
    wnl = WordNetLemmatizer()
    return [wnl.lemmatize(adverb, 'r') for adverb in adverbs]


class SentenceEncoder:
    """Encodes sentences into fixed-size embeddings with a pretrained RoBERTa model."""

    def __init__(self, model_name="roberta-base"):
        from transformers import RobertaTokenizer, RobertaModel  # lazy import
        if model_name != 'roberta-base':
            # Only roberta-base is supported; fail loudly instead of leaving
            # self.tokenizer / self.model undefined.
            raise ValueError(f"Unsupported model_name: {model_name!r}")
        self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
        self.model = RobertaModel.from_pretrained(model_name)
    
    def encode_sentence(self, sentence):
        """Encodes a single sentence; returns a (1, hidden_size) tensor."""
        inputs = self.tokenizer.encode_plus(
            sentence, add_special_tokens=True, return_tensors='pt',
        )
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Use the first (<s>) token embedding as the sentence representation;
        # mean pooling over last_hidden_state is a common alternative:
        # sentence_embedding = torch.mean(outputs.last_hidden_state, dim=1)
        sentence_embedding = outputs.last_hidden_state[:, 0, :]
        return sentence_embedding
    
    def encode_sentences(self, sentences):
        """Encodes a list of sentences in one padded batch; returns a (batch, hidden_size) tensor."""
        tokenized_input = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors='pt',
        )
        with torch.no_grad():
            outputs = self.model(**tokenized_input)
        # First-token embedding for each sentence in the batch.
        embeddings = outputs.last_hidden_state[:, 0, :]
        return embeddings
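

# Minimal usage sketch (added for illustration, not part of the original module):
# assumes nltk's WordNet data and the Hugging Face "roberta-base" weights are available.
if __name__ == "__main__":
    print(lemmatize_verbs(["running", "ate"]))   # -> ['run', 'eat']
    encoder = SentenceEncoder()
    single = encoder.encode_sentence("A quick example sentence.")
    batch = encoder.encode_sentences(["First sentence.", "Second sentence."])
    print(single.shape, batch.shape)             # torch.Size([1, 768]) torch.Size([2, 768])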