vera / app.py
import gradio as gr
import torch
import transformers
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class Interactive:
    def __init__(self):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained('liujch1998/cd-pi')
        self.model = transformers.AutoModelForSeq2SeqLM.from_pretrained('liujch1998/cd-pi').to(device)
        # Scoring head whose weight and bias are recovered from reserved rows of the
        # model's shared embedding matrix.
        self.linear = torch.nn.Linear(self.model.shared.embedding_dim, 1).to(device)
        self.linear.weight = torch.nn.Parameter(self.model.shared.weight[32099, :].unsqueeze(0))  # (1, D)
        self.linear.bias = torch.nn.Parameter(self.model.shared.weight[32098, 0].unsqueeze(0))  # (1,)
        self.model.eval()
        # Temperature used to calibrate the confidence score.
        self.t = 2.2247

    def run(self, statement):
        input_ids = self.tokenizer.batch_encode_plus([statement], return_tensors='pt', padding='longest').input_ids.to(device)
        with torch.no_grad():
            # Only the encoder is needed for scoring; calling the full seq2seq model
            # without decoder inputs would fail and would not expose last_hidden_state.
            output = self.model.encoder(input_ids=input_ids)
            last_hidden_state = output.last_hidden_state.to(device)  # (B=1, L, D)
            hidden = last_hidden_state[0, -1, :]  # (D)
            logit = self.linear(hidden).squeeze(-1)  # ()
            logit_calibrated = logit / self.t
            score = logit.sigmoid()
            score_calibrated = logit_calibrated.sigmoid()
        return {
            'logit': logit.item(),
            'logit_calibrated': logit_calibrated.item(),
            'score': score.item(),
            'score_calibrated': score_calibrated.item(),
        }

interactive = Interactive()

def predict(statement, model):
    # `model` mirrors the read-only model-name textbox in the UI and is not used for inference.
    result = interactive.run(statement)
    return {
        'True': result['score_calibrated'],
        'False': 1 - result['score_calibrated'],
    }
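
# A minimal sketch of querying the scorer directly, bypassing the Gradio UI
# (the statement below is only an illustrative input, not part of the demo):
#   result = interactive.run('The sun rises in the east.')
#   print(result['score_calibrated'])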

examples = [
    'If A sits next to B and B sits next to C, then A must sit next to C.',
    'If A sits next to B and B sits next to C, then A might not sit next to C.',
]
input_statement = gr.Dropdown(choices=examples, label='Statement:')
input_model = gr.Textbox(label='Commonsense statement verification model:', value='liujch1998/cd-pi', interactive=False)
output = gr.Label(num_top_classes=2)
description = '''This is a demo for a commonsense statement verification model. Under development.'''

gr.Interface(
    fn=predict,
    inputs=[input_statement, input_model],
    outputs=output,
    title="cd-pi Demo",
    description=description,
).launch()