ellenhp committed
Commit bc83430 · 1 Parent(s): d1bb5e7

Check in space

Files changed (2)
  1. app.py +66 -0
  2. hydra.py +112 -0
app.py ADDED
@@ -0,0 +1,66 @@
+ from gradio.components import Component
+ import torch
+ from hydra import Hydra
+ from transformers import AutoTokenizer
+ import gradio as gr
+ import os
+ from typing import Any, Optional
+
+ model_name = "ellenhp/query2osm-bert-v1"
+ tokenizer = AutoTokenizer.from_pretrained(model_name, padding=True)
+ model = Hydra.from_pretrained(model_name).to('cpu')
+
+
+ class DatasetSaver(gr.FlaggingCallback):
+     inner: Optional[gr.HuggingFaceDatasetSaver] = None
+
+     def __init__(self, inner):
+         self.inner = inner
+
+     def setup(self, components: list[Component], flagging_dir: str):
+         self.inner.setup(components, flagging_dir)
+
+     def flag(self,
+              flag_data: list[Any],
+              flag_option: str = "",
+              username: str | None = None):
+         flag_data = [flag_data[0], {"label": flag_data[1]['label']}]
+         self.inner.flag(flag_data, flag_option, None)
+
+
+ HF_TOKEN = os.getenv('HF_TOKEN')
+ if HF_TOKEN is not None:
+     hf_writer = gr.HuggingFaceDatasetSaver(
+         HF_TOKEN, "osm-queries-crowdsourced", True, "data.csv", False)
+ else:
+     hf_writer = None
+
+
+ flag_callback = DatasetSaver(hf_writer)
+
+
+ def predict(input_query):
+     with torch.no_grad():
+         print(input_query)
+         input_text = input_query.strip().lower()
+         inputs = tokenizer(input_text, return_tensors="pt")
+         outputs = model.forward(inputs.input_ids)
+         return {classification[0]: classification[1] for classification in outputs.classifications[0]}
+
+
+ textbox = gr.Textbox(label="Query",
+                      placeholder="Where can I get a quick bite to eat?")
+ label = gr.Label(label="Result", num_top_classes=5)
+
+ gradio_app = gr.Interface(
+     predict,
+     inputs=[textbox],
+     outputs=[label],
+     title="Query Classification",
+     allow_flagging="manual",
+     flagging_options=["potentially harmful", "wrong classification"],
+     flagging_callback=flag_callback,
+ )
+
+ if __name__ == "__main__":
+     gradio_app.launch()
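
For context: `predict` returns the first label group's (label, probability) pairs as the {class_name: confidence} dict that gr.Label expects. A minimal smoke test, run outside Gradio, assuming hydra.py (below) is importable and the checkpoint downloads; this snippet is illustrative and not part of the commit:

    # Illustrative only; not part of the commit.
    result = predict("Where can I get a quick bite to eat?")
    # gr.Label renders {class_name: confidence}; print the top five by hand.
    for name, score in sorted(result.items(), key=lambda kv: kv[1], reverse=True)[:5]:
        print(f"{name}: {score:.3f}")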
hydra.py ADDED
@@ -0,0 +1,112 @@
+ from transformers import BertConfig, BertModel
+ import torch.nn as nn
+ import torch
+ from typing import Optional, Union, Tuple, List
+ from transformers.modeling_outputs import SequenceClassifierOutput
+ from torch.nn import CrossEntropyLoss
+
+
+ class HydraConfig(BertConfig):
+     model_type = "hydra"
+     label_groups = None
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+     def num_labels(self):
+         return sum([len(group) for group in self.label_groups])
+
+     def distilbert_config(self):
+         return BertConfig(**self.__dict__)
+
+
+ class HydraSequenceClassifierOutput(SequenceClassifierOutput):
+     classifications: List[dict]
+
+     def __init__(self, classifications=None, **kwargs):
+         super().__init__(**kwargs)
+         self.classifications = classifications
+
+
+ class Hydra(BertModel):
+     config_class = HydraConfig
+
+     def __init__(self, config: HydraConfig):
+         super().__init__(config)
+         self.config = config
+         self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size)
+         self.classifiers = nn.Linear(config.hidden_size, sum(
+             [len(group) for group in config.label_groups]))
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+         self.embeddings.requires_grad_(False)
+
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         distilbert_output = super().forward(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict
+         )
+         hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
+         pooled_output = hidden_state[:, 0]  # (bs, dim)
+         pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
+         pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)
+         pooled_output = self.dropout(pooled_output)  # (bs, dim)
+         logits = self.classifiers(pooled_output)  # (bs, num_labels)
+
+         loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(logits, labels)
+
+         if not return_dict:
+             output = (logits,) + distilbert_output[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         classifications = []
+         if logits.shape[0] == 1:
+             offset = 0
+             for group in self.config.label_groups:
+                 inverted = {group[pair]: pair for pair in group}
+                 softmax = nn.Softmax(dim=1)
+                 output = softmax(logits[:, offset:offset + len(group)])
+                 classification = []
+                 for i, val in enumerate(output[0]):
+                     classification.append((inverted[i], val.item()))
+                 classification.sort(key=lambda x: x[1], reverse=True)
+                 classifications.append(classification)
+                 offset += len(group)
+
+         return HydraSequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=distilbert_output.hidden_states,
+             attentions=distilbert_output.attentions,
+             classifications=classifications
+         )
+
+     def to(self, device):
+         super().to(device)
+         self.pre_classifier.to(device)
+         self.classifiers.to(device)
+         self.dropout.to(device)
+         return self
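
A note on the label_groups convention hydra.py relies on: each entry in HydraConfig.label_groups is a dict mapping a label name to its index within that group's slice of the shared classifier head, whose total width is the sum of all group sizes. The concrete groups ship with the ellenhp/query2osm-bert-v1 config, so the values below are hypothetical, for illustration only:

    # Hypothetical label_groups; the real groups live in the checkpoint's config.
    label_groups = [
        {"food": 0, "shopping": 1, "transit": 2},  # first head: logits[:, 0:3]
        {"is_query": 0, "not_query": 1},           # second head: logits[:, 3:5]
    ]
    # Hydra.forward() softmaxes each slice independently, inverts each mapping
    # to {index: label}, and returns per-group (label, probability) lists sorted
    # by descending probability; app.py displays classifications[0].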