Spaces:
Runtime error
Runtime error
# coding=utf8 | |
from transformers import AutoModel, AutoTokenizer, AutoConfig | |
import torch | |
import torch.nn as nn | |
from torch.utils.data import Dataset, DataLoader | |
import streamlit as st | |
import gdown | |
import numpy as np | |
import pandas as pd | |
import collections | |
from string import punctuation | |
class CONFIG:
    """Static settings read by the model, the tokenisation step and the app."""

    # --- model params ---
    model = "deepset/xlm-roberta-large-squad2"
    # Hyperparameters to be tuned, following the guide from huggingface.
    max_input_length = 384
    doc_stride = 128

    # --- checkpoint locations ---
    model_checkpoint = "pytorch_model.pth"
    trained_model_url = "https://drive.google.com/uc?id=16Vp918RglyLEFEyDlFuRD1HeNZ8SI7P5"
    trained_model_output_fp = "trained_pytorch.pth"

    # --- sample questions/contexts loaded by load_model() ---
    sample_df_fp = "sample_qa.json"
# model class | |
class ChaiModel(nn.Module):
    """Question-answering network: a pretrained backbone plus a 2-unit linear head.

    The head produces two values per token, which are returned separately as
    the start logits and the end logits.
    """

    def __init__(self, model_config):
        """Load the backbone named by ``CONFIG.model``; size the head from ``model_config.hidden_size``."""
        super().__init__()
        self.backbone = AutoModel.from_pretrained(CONFIG.model)
        self.linear = nn.Linear(model_config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        """Return ``(start_logits, end_logits)`` for the given batch.

        Shapes per the original author's notes: the backbone's first output is
        (batchsize, sequencelength, hidden_dim) and each returned tensor is
        (batchsize, sequencelength).
        """
        # First element of the backbone output is the per-token representation.
        hidden_states = self.backbone(input_ids, attention_mask=attention_mask)[0]
        span_logits = self.linear(hidden_states)  # last axis has size 2
        # Peel the last axis apart: index 0 -> start, index 1 -> end.
        start_logits, end_logits = span_logits.unbind(dim=-1)
        return start_logits, end_logits
# dataset class | |
class ChaiDataset(Dataset):
    """Dataset wrapper around a list of tokenised feature dicts.

    In training mode each item carries the label positions; otherwise it
    carries the metadata (sequence ids, example id, raw texts) instead.
    """

    def __init__(self, dataset, is_train=True):
        super().__init__()
        self.dataset = dataset  # list of features
        self.is_train = is_train

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the feature at ``index`` with its numeric fields as ``torch.long`` tensors."""
        record = self.dataset[index]

        # Fields present in both modes.
        item = {
            key: torch.tensor(record[key], dtype=torch.long)
            for key in ('input_ids', 'attention_mask', 'offset_mapping')
        }

        if self.is_train:
            for key in ('start_position', 'end_position'):
                item[key] = torch.tensor(record[key], dtype=torch.long)
            return item

        # Inference mode: pass metadata through untouched.
        item['sequence_ids'] = record['sequence_ids']
        item['id'] = record['example_id']
        item['context'] = record['context']
        item['question'] = record['question']
        return item
def _tokenize_row(row, tokenizer):
    """Tokenise one question/context pair into overlapping, padded windows."""
    # tokenizer parameters can be found here
    # https://huggingface.co/transformers/internal/tokenization_utils.html#transformers.tokenization_utils_base.PreTrainedTokenizerBase
    return tokenizer(
        row['question'],
        row['context'],
        padding='max_length',
        max_length=CONFIG.max_input_length,
        truncation='only_second',  # only the context (second text) is truncated
        stride=CONFIG.doc_stride,
        return_overflowing_tokens=True,  # keep the overflow as extra windows
        return_offsets_mapping=True,  # token -> character span mapping
    )


def _locate_answer(offset_map, sequence_ids, start_char, end_char):
    """Return inclusive ``(start_token, end_token)`` or ``None`` if the window misses the answer."""
    # First and last token that belong to the context (sequence id 1).
    first_ctx = sequence_ids.index(1)
    last_ctx = len(sequence_ids) - 1 - sequence_ids[::-1].index(1)

    # The answer must lie entirely inside this window's context span.
    if offset_map[first_ctx][0] > start_char or offset_map[last_ctx][1] < end_char:
        return None

    # Walk forward past every token starting at or before the answer start;
    # the token just before the stopping point is the start token.
    start_token = first_ctx
    while start_token < len(offset_map) and offset_map[start_token][0] <= start_char:
        start_token += 1

    # Walk backward past every token ending at or after the answer end;
    # the token just after the stopping point is the (inclusive) end token.
    end_token = last_ctx
    while offset_map[end_token][1] >= end_char:
        end_token -= 1

    return start_token - 1, end_token + 1


def _train_features(row, tokenizer):
    """Build the labelled window features for one training row."""
    tokenized = _tokenize_row(row, tokenizer)
    # One 'overflow_to_sample_mapping' entry exists per produced window.
    n_windows = len(tokenized.pop("overflow_to_sample_mapping"))
    offset_mappings = tokenized.pop("offset_mapping")

    # pd.isna also copes with None / object values, where np.isnan raises TypeError.
    has_answer = not pd.isna(row["answer_start"])

    features = []
    for j in range(n_windows):
        input_ids = tokenized["input_ids"][j]
        offset_map = offset_mappings[j]
        feature = dict(
            input_ids=input_ids,
            attention_mask=tokenized["attention_mask"][j],
            sliced_text=tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids)),
            offset_mapping=offset_map,
            fold=row['fold'],
        )
        # Most of the time cls_index is 0
        cls_index = input_ids.index(tokenizer.cls_token_id)

        span = None
        if has_answer:
            start_char = row["answer_start"]
            end_char = start_char + len(row["answer_text"])
            # sequence_ids looks like: None, 0, 0, ... None, None, 1, 1, ...
            span = _locate_answer(offset_map, tokenized.sequence_ids(j), start_char, end_char)

        if span is None:
            # No answer (or not in this window): both labels point at the CLS token.
            feature['start_position'] = cls_index
            feature['end_position'] = cls_index
            feature['tokenized_answer'] = ""
            feature['answer_text'] = ""
        else:
            feature['start_position'], feature['end_position'] = span
            # end_position is inclusive, hence the +1 when slicing.
            answer_ids = input_ids[feature['start_position']:feature['end_position'] + 1]
            feature['tokenized_answer'] = tokenizer.convert_tokens_to_string(
                tokenizer.convert_ids_to_tokens(answer_ids)
            )
            feature['answer_text'] = row['answer_text']
        features.append(feature)
    return features


def _inference_features(row, tokenizer):
    """Build the unlabelled window features for one inference row."""
    tokenized = _tokenize_row(row, tokenizer)
    n_windows = len(tokenized.pop("overflow_to_sample_mapping"))
    offset_mappings = tokenized.pop("offset_mapping")
    return [
        dict(
            input_ids=tokenized["input_ids"][j],
            attention_mask=tokenized["attention_mask"][j],
            offset_mapping=offset_mappings[j],
            example_id=row['id'],
            context=row['context'],
            question=row['question'],
            # Special tokens report None; store 0 so the list holds only ints.
            sequence_ids=[0 if value is None else value for value in tokenized.sequence_ids(j)],
        )
        for j in range(n_windows)
    ]


def break_long_context(df, tokenizer, train=True):
    """Split every row of ``df`` into fixed-length tokenised windows.

    Args:
        df: DataFrame with ``question`` and ``context`` columns. Training mode
            also reads ``answer_start``, ``answer_text`` and ``fold``;
            inference mode also reads ``id``.
        tokenizer: Tokenizer called on (question, context); it must return
            overflowing windows and offset mappings.
        train: When true, each window is labelled with ``start_position`` /
            ``end_position`` (the CLS index when the window has no answer).

    Returns:
        A flat list of feature dicts, one per window, in row order.
    """
    build_features = _train_features if train else _inference_features
    full_set = []
    for i in range(len(df)):
        full_set.extend(build_features(df.iloc[i], tokenizer))
    return full_set
def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size=20, max_answer_length=30):
    """Turn per-window start/end logits into one answer string per example.

    Args:
        examples: DataFrame with ``id`` and ``context`` columns.
        features: Sequence of window dicts with ``example_id``,
            ``sequence_ids`` and ``offset_mapping``. It is not modified.
        raw_predictions: Pair ``(all_start_logits, all_end_logits)``, each
            indexable by feature position.
        n_best_size: How many top start and top end indices to combine.
        max_answer_length: Maximum answer length in tokens.

    Returns:
        ``OrderedDict`` mapping each example id to its best answer text,
        or ``""`` when no candidate span qualifies. No null-answer (CLS)
        score is computed: the best-scoring valid span is always returned.
    """
    all_start_logits, all_end_logits = raw_predictions

    # Group feature positions by the *position* of their example in `examples`.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    predictions = collections.OrderedDict()
    print(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    context_index = 1  # sequence id of context tokens
    # Enumerate positions explicitly: iterrows() yields index *labels*, which
    # only coincide with positions for a default RangeIndex.
    for example_index, (_, example) in enumerate(examples.iterrows()):
        context = example["context"]
        valid_answers = []

        for feature_index in features_per_example.get(example_index, []):
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            sequence_ids = features[feature_index]["sequence_ids"]

            # Hide offsets of non-context tokens. Built as a local list so the
            # caller's features are left untouched.
            offset_mapping = [
                (o if sequence_ids[k] == context_index else None)
                for k, o in enumerate(features[feature_index]["offset_mapping"])
            ]

            # Highest-scoring indices first.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()

            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip indices outside the window or outside the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    # Don't consider answers with a length that is either < 0 or > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue
                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            "score": start_logits[start_index] + end_logits[end_index],
                            "text": context[start_char:end_char],
                        }
                    )

        if valid_answers:
            # max() keeps the first of equally-scored candidates, like a stable sort.
            best_answer = max(valid_answers, key=lambda x: x["score"])
        else:
            best_answer = {"text": "", "score": 0.0}
        predictions[example["id"]] = best_answer["text"]

    return predictions
def download_finetuned_model():
    """Download the fine-tuned checkpoint to ``CONFIG.trained_model_output_fp``."""
    source_url = CONFIG.trained_model_url
    destination = CONFIG.trained_model_output_fp
    gdown.download(url=source_url, output=destination, quiet=False)
def get_prediction(context: str, question: str, model, tokenizer) -> str:
    """Answer ``question`` from ``context`` and return the extracted text.

    The stripped inputs are windowed with ``break_long_context``, scored by
    ``model`` one window at a time, and reduced to a single answer by
    ``postprocess_qa_predictions``. Leading and trailing punctuation is
    stripped from the result.
    """
    # One-row frame so inference uses the same data format as training.
    single_df = pd.DataFrame({"id": [1], "context": [context.strip()], "question": [question.strip()]})
    windows = break_long_context(single_df, tokenizer, train=False)

    # Batch size 1 to prevent OOM.
    loader = DataLoader(
        ChaiDataset(windows, is_train=False),
        batch_size=1,
        shuffle=False,
        drop_last=False,
    )

    start_rows, end_rows = [], []
    with torch.no_grad():
        for batch in loader:
            start_logit, end_logit = model(batch['input_ids'], batch['attention_mask'])
            start_rows.append(start_logit.to("cpu").numpy())
            end_rows.append(end_logit.to("cpu").numpy())

    # Stack the per-window rows into one array per logit type.
    stacked_logits = (np.vstack(start_rows), np.vstack(end_rows))
    answers = postprocess_qa_predictions(single_df, windows, stacked_logits)

    # Exactly one example was submitted, so take its answer.
    answer_text = next(iter(answers.values()))
    return answer_text.strip(punctuation)
def load_model():
    """Load the fine-tuned model, its tokenizer and the sample question frame.

    The checkpoint is downloaded only when it is not already on disk, so
    repeated calls do not re-fetch it.

    Returns:
        ``(model, tokenizer, sample_df)``: the model in eval mode with weights
        loaded onto the CPU, the tokenizer for ``CONFIG.model``, and the
        DataFrame read from ``CONFIG.sample_df_fp``.

    Raises:
        FileNotFoundError: If the checkpoint is still missing after the
            download attempt.
    """
    import os  # local import: only needed for the checkpoint-exists check

    checkpoint_path = CONFIG.trained_model_output_fp
    if not os.path.isfile(checkpoint_path):
        download_finetuned_model()
        if not os.path.isfile(checkpoint_path):
            raise FileNotFoundError(
                f"Could not download model checkpoint to {checkpoint_path!r}"
            )
        print("Downloaded pretrained model")

    config = AutoConfig.from_pretrained(CONFIG.model)
    model = ChaiModel(config)
    # map_location keeps loading working on CPU-only hosts.
    model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(CONFIG.model)
    sample_df = pd.read_json(CONFIG.sample_df_fp)
    return model, tokenizer, sample_df
# Streamlit re-executes this script on every widget interaction. Wrap
# load_model in Streamlit's resource cache (name depends on the installed
# version) so the model is loaded once per process; fall back to a plain call
# when no such cache is available.
_resource_cache = getattr(st, "cache_resource", None) or getattr(st, "experimental_singleton", None)
_load_model_once = _resource_cache(load_model) if _resource_cache is not None else load_model
model, tokenizer, sample_df = _load_model_once()

## initialize session_state
for _key in ("context", "question", "answer"):
    if _key not in st.session_state:
        st.session_state[_key] = ""

## Layout
st.sidebar.title("Hindi/Tamil Extractive Question Answering")
st.sidebar.markdown("---")
random_button = st.sidebar.button("Random")
st.sidebar.write("Randomly Generates a Hindi/Tamil Context and Question")
st.sidebar.markdown("---")
answer_button = st.sidebar.button("Answer!")

if random_button:
    # Fill the inputs with one random sample row and clear any previous answer.
    sample = sample_df.sample(1)
    st.session_state['context'] = sample['context'].item()
    st.session_state['question'] = sample['question'].item()
    st.session_state['answer'] = ""

if answer_button:
    # if question or context is empty text
    if not st.session_state['context'] or not st.session_state['question']:
        st.session_state['answer'] = " "
    else:
        st.session_state['answer'] = get_prediction(
            st.session_state['context'],
            st.session_state['question'],
            model,
            tokenizer,
        )

# Text areas are rendered last so they show values set by the buttons above;
# whatever the user typed is stored back into session_state.
st.session_state["context"] = st.text_area("Context", value=st.session_state['context'], height=300)
with st.container():
    col_1, col_2 = st.columns(2)
    with col_1:
        st.session_state['question'] = st.text_area("Question", value=st.session_state['question'], height=200)
    with col_2:
        st.text_area("Answer", value=st.session_state['answer'], height=200)