ellenhp commited on
Commit
b4df2a0
·
1 Parent(s): 0dca3eb

Update to use new and improved bert model

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. hydra.py +3 -3
app.py CHANGED
@@ -50,7 +50,7 @@ def predict(input_query):
50
 
51
 
52
  textbox = gr.Textbox(label="Query",
53
- placeholder="Where can I get a quick bite to eat?")
54
  label = gr.Label(label="Result", num_top_classes=5)
55
 
56
  gradio_app = gr.Interface(
@@ -59,7 +59,7 @@ gradio_app = gr.Interface(
59
  outputs=[label],
60
  title="Query Classification",
61
  allow_flagging="manual",
62
- flagging_options=["potentially harmful", "wrong classification"],
63
  flagging_callback=flag_callback,
64
  )
65
 
 
50
 
51
 
52
  textbox = gr.Textbox(label="Query",
53
+ placeholder="Quick bite to eat near me")
54
  label = gr.Label(label="Result", num_top_classes=5)
55
 
56
  gradio_app = gr.Interface(
 
59
  outputs=[label],
60
  title="Query Classification",
61
  allow_flagging="manual",
62
+ flagging_options=["correct classification", "incorrect classification"],
63
  flagging_callback=flag_callback,
64
  )
65
 
hydra.py CHANGED
@@ -35,7 +35,7 @@ class Hydra(BertModel):
35
  super().__init__(config)
36
  self.config = config
37
  self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size)
38
- self.classifiers = nn.Linear(config.hidden_size, sum(
39
  [len(group) for group in config.label_groups]))
40
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
41
 
@@ -70,7 +70,7 @@ class Hydra(BertModel):
70
  pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
71
  pooled_output = nn.ReLU()(pooled_output) # (bs, dim)
72
  pooled_output = self.dropout(pooled_output) # (bs, dim)
73
- logits = self.classifiers(pooled_output) # (bs, num_labels)
74
 
75
  loss = None
76
  if labels is not None:
@@ -107,6 +107,6 @@ class Hydra(BertModel):
107
  def to(self, device):
108
  super().to(device)
109
  self.pre_classifier.to(device)
110
- self.classifiers.to(device)
111
  self.dropout.to(device)
112
  return self
 
35
  super().__init__(config)
36
  self.config = config
37
  self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size)
38
+ self.classifier = nn.Linear(config.hidden_size, sum(
39
  [len(group) for group in config.label_groups]))
40
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
41
 
 
70
  pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
71
  pooled_output = nn.ReLU()(pooled_output) # (bs, dim)
72
  pooled_output = self.dropout(pooled_output) # (bs, dim)
73
+ logits = self.classifier(pooled_output) # (bs, num_labels)
74
 
75
  loss = None
76
  if labels is not None:
 
107
  def to(self, device):
108
  super().to(device)
109
  self.pre_classifier.to(device)
110
+ self.classifier.to(device)
111
  self.dropout.to(device)
112
  return self