Text Classification
Transformers
Safetensors
English
HHEMv2Config
custom_code
File size: 2,662 Bytes
e2b6d9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f7b340
 
 
 
 
 
 
 
 
e2b6d9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch

from peft import PeftModel
from transformers import PreTrainedModel, AutoConfig, T5ForTokenClassification, AutoModel, AutoTokenizer, AutoModelForTokenClassification

from .configuration_hhem_v2 import HHEMv2Config

class HHEMv2Model(PreTrainedModel):
    """Bare ``PreTrainedModel`` tied to ``HHEMv2Config`` via ``config_class``.

    It defines no submodules and no ``forward`` of its own; the scoring logic
    lives in ``HHEMv2ForSequenceClassification``. An earlier draft that wrapped
    a ``T5ForTokenClassification`` backbone here was left only as commented-out
    code and has been dropped.
    """

    config_class = HHEMv2Config

    def __init__(self, config):
        # Delegate entirely to PreTrainedModel, which stores ``config`` on the
        # instance and performs its standard setup.
        PreTrainedModel.__init__(self, config)

class HHEMv2ForSequenceClassification(PreTrainedModel):
    """Score text pairs with a T5 token-classification backbone.

    Each pair is rendered into ``config.prompt`` (via its ``{text1}`` and
    ``{text2}`` placeholders), tokenized with the tokenizer of
    ``config.foundation``, and run through ``self.t5``. Only the logits at the
    first token position are used as the sequence-level prediction.
    """

    config_class = HHEMv2Config

    def __init__(self, config=None):
        """Build the backbone and tokenizer described by ``config``.

        Args:
            config: An ``HHEMv2Config``. When omitted, a fresh default
                ``HHEMv2Config()`` is created for this instance.
        """
        # A ``HHEMv2Config()`` default argument would be constructed once at
        # import time and shared (and mutable through ``model.config``) by every
        # instance created without a config; build a fresh one per instance.
        if config is None:
            config = HHEMv2Config()
        super().__init__(config)
        # The backbone is built from the foundation model's architecture config
        # only; its weights are presumably filled in later by ``from_pretrained``
        # or ``populate`` (TODO confirm).
        self.t5 = T5ForTokenClassification(
            AutoConfig.from_pretrained(config.foundation)
        )
        self.prompt = config.prompt
        self.tokenizer = AutoTokenizer.from_pretrained(config.foundation)

    @property
    def tokenzier(self):
        """Backward-compatible alias for the misspelled original attribute name."""
        return self.tokenizer

    @tokenzier.setter
    def tokenzier(self, value):
        self.tokenizer = value

    def populate(self, model: AutoModel):
        """Replace the backbone with an already-loaded model.

        This method should only be called by Vectara employees preparing the
        model for publishing. Users do not need to call this method.

        Args:
            model: Model stored as ``self.t5``. It presumably must return
                token-level ``logits`` of shape (batch, seq, num_labels), since
                ``forward`` slices ``logits[:, 0, :]`` (TODO confirm).
        """
        self.t5 = model

    # TODO: Figure out how to publish only the adapter yet still able to do end-to-end pulling and inference.
    # def populate_lora(self, checkpoint: str):
    #     base_model = AutoModelForTokenClassification.from_pretrained(self.config.foundation)
    #     combined_model = PeftModel.from_pretrained(base_model, checkpoint, is_trainable=False)
    #     self.t5 = combined_model

    def forward(self, **kwargs):
        """Run the backbone and keep only the first-token logits.

        Shaped this way to cope with the ``text-classification`` pipeline. It
        always switches ``self.t5`` to eval mode and runs under
        ``torch.no_grad()``, so it is inference-only.

        Args:
            **kwargs: Forwarded unchanged to ``self.t5`` (e.g. ``input_ids``,
                ``attention_mask``).

        Returns:
            The backbone's output object with ``logits`` replaced by
            ``logits[:, 0, :]``.
        """
        self.t5.eval()
        with torch.no_grad():
            outputs = self.t5(**kwargs)
        outputs.logits = outputs.logits[:, 0, :]
        return outputs

    def predict(self, text_pairs):
        """Return the class-1 probability for each text pair.

        Args:
            text_pairs: Iterable of pairs; ``pair[0]`` fills ``{text1}`` and
                ``pair[1]`` fills ``{text2}`` in ``self.prompt``.

        Returns:
            1-D tensor with one probability per pair (softmax over the
            first-token logits, column 1), on the model's device. An empty
            tensor when ``text_pairs`` is empty.
        """
        prompts = [
            self.prompt.format(text1=pair[0], text2=pair[1]) for pair in text_pairs
        ]
        if not prompts:
            # Tokenizing/running an empty batch fails; there is nothing to score.
            return torch.empty(0, device=self.device)
        inputs = self.tokenizer(prompts, return_tensors='pt', padding=True)
        # Tokenizer tensors are created on CPU; move them next to the weights so
        # predict() also works after the model has been moved (e.g. to CUDA).
        inputs = inputs.to(self.device)
        logits = self.forward(**inputs).logits
        probs = torch.softmax(logits, dim=-1)
        return probs[:, 1]  # the probability of class 1