"""Streamlit demo that labels news articles as reliable or unreliable.

The app loads the DistilBERT sequence classifier
``khizon/distilbert-unreliable-news-eng-4L`` and either scores a random article
from ``test_sub.jsonl`` (reporting whether the prediction matches its label) or
scores a title/content pair typed in by the user.
"""
import json
import os
import random
import subprocess

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# Class index -> display name. Index 1 is "reliable", matching the ``label_1``
# dummy column and the classifier's second logit.
LABEL_NAMES = ('unreliable', 'reliable')
# One-hot label columns produced by ``pd.get_dummies(..., columns=['label'])``.
LABEL_COLUMNS = ['label_0', 'label_1']
# Token budget for the (title, content) pair fed to the classifier.
MAX_LENGTH = 512
# In the article preview, a word is masked when randint(0, 99) exceeds this
# value, i.e. roughly 54% of words (the first word is never masked).
MASK_THRESHOLD = 45


@st.cache(allow_output_mutation=True)
def init_model():
    """Load the tokenizer and the 2-class classifier, cached across reruns.

    ``allow_output_mutation=True`` skips Streamlit's output-mutation check,
    which would otherwise hash the large model on every rerun.

    Returns:
        ``(tokenizer, model)``.
    """
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
    model = DistilBertForSequenceClassification.from_pretrained(
        'khizon/distilbert-unreliable-news-eng-4L', num_labels=2
    )
    return tokenizer, model


def download_dataset():
    """Download the dataset folder from Google Drive into ``data/nela_gt_2018_site_split``.

    Raises:
        FileNotFoundError: if the ``gdown`` executable is not on ``PATH``.
        subprocess.CalledProcessError: if ``gdown`` exits with an error.
    """
    url = 'https://drive.google.com/drive/folders/11mRvsHAkggFEJvG4axH4mmWI6FHMQp7X?usp=sharing'
    data = 'data/nela_gt_2018_site_split'
    # Pass an argument list (no shell) so the '?' in the URL is never treated
    # as a glob pattern, and fail loudly instead of ignoring the exit status.
    subprocess.run(['gdown', '--folder', url, '-O', data], check=True)


@st.cache(allow_output_mutation=True)
def jsonl_to_df(file_path):
    """Read a JSON Lines file into a flat DataFrame.

    Each non-blank line is parsed as one JSON object. Nested objects are
    flattened into dotted column names by ``pd.json_normalize``. An empty file
    yields an empty DataFrame.
    """
    # Iterate the file rather than using str.splitlines(): splitlines() also
    # breaks on characters such as U+2028 that may legally appear raw inside a
    # JSON string, which would corrupt the record.
    with open(file_path, encoding='utf-8') as f:
        records = [json.loads(line) for line in f if line.strip()]
    return pd.json_normalize(records)


@st.cache
def load_test_df():
    """Load ``test_sub.jsonl`` and one-hot encode its ``label`` column.

    Both ``label_0`` and ``label_1`` are guaranteed to exist (filled with 0 if
    a class is absent from the file), because ``predict`` reads both.
    """
    test_df = jsonl_to_df('test_sub.jsonl')
    test_df = pd.get_dummies(test_df, columns=['label'])
    for column in LABEL_COLUMNS:
        if column not in test_df.columns:
            test_df[column] = 0
    return test_df


def _encode(tokenizer, title, content):
    """Tokenize a (title, content) pair into padded PyTorch tensors."""
    # ' [SEP] ' is prepended to the content in addition to the separator that
    # add_special_tokens=True places between the pair.
    # NOTE(review): presumably this mirrors the training preprocessing; keep it
    # unless the model is retrained.
    return tokenizer.encode_plus(
        title,
        ' [SEP] ' + content,
        add_special_tokens=True,
        max_length=MAX_LENGTH,
        return_token_type_ids=False,
        padding='max_length',
        truncation='only_second',  # only the article body is ever truncated
        return_attention_mask=True,
        return_tensors='pt',
    )


def _logits(model, encoding):
    """Run one forward pass without recording an autograd graph."""
    with torch.no_grad():
        return model(**encoding)['logits']


def _label_name(index):
    """Map a class index (int or 0-d tensor) to 'unreliable'/'reliable'."""
    return LABEL_NAMES[int(index)]


def predict(model, tokenizer, data):
    """Classify one row of ``load_test_df()`` and check it against its label.

    Not cached: ``st.cache`` would have to hash the model, tokenizer and row on
    every call, which costs more than simply re-running inference.

    Args:
        model: classifier returned by ``init_model``.
        tokenizer: tokenizer returned by ``init_model``.
        data: a row with ``title``, ``content``, ``label_0`` and ``label_1``.

    Returns:
        ``(pred_label, correct)`` as produced by ``correct_preds``.
    """
    # A row of a mixed-dtype DataFrame is an object-dtype Series, which
    # torch.tensor cannot convert; cast through NumPy to float32 first.
    labels = torch.from_numpy(np.asarray(data[LABEL_COLUMNS], dtype=np.float32))
    logits = _logits(model, _encode(tokenizer, data['title'], data['content']))
    return correct_preds(logits, labels)


def predict_new(model, tokenizer, title, content):
    """Classify a user-supplied article.

    Returns:
        ``'reliable'`` or ``'unreliable'``.
    """
    logits = _logits(model, _encode(tokenizer, title, content))
    # argmax over logits equals argmax over softmax(logits): softmax is monotonic.
    return _label_name(torch.argmax(logits, dim=1)[0])


def correct_preds(preds, labels):
    """Compare the predicted class of a single example with its target.

    Args:
        preds: logits of shape ``(1, 2)`` (one example, two classes).
        labels: label vector of shape ``(2,)``; its argmax is the true class.

    Returns:
        ``(pred_label, correct)``: ``pred_label`` is ``'reliable'`` or
        ``'unreliable'``; ``correct`` is True when it matches ``labels``.
    """
    # softmax is monotonic, so argmax over the raw logits picks the same class.
    pred_idx = int(torch.argmax(preds, dim=1)[0])
    true_idx = int(torch.argmax(labels, dim=0))
    return _label_name(pred_idx), pred_idx == true_idx


def _mask_words(text, threshold=MASK_THRESHOLD):
    """Replace random words (never the first) with '▒' blocks of equal length."""
    masked = []
    for position, word in enumerate(text.split()):
        # One draw per word, including the first, which is always kept.
        hide = random.randint(0, 99) > threshold
        masked.append('▒' * len(word) if hide and position > 0 else word)
    return ' '.join(masked)


def _show_random_article(df, model, tokenizer):
    """Classify a random test article and render it with most words masked."""
    if df.empty:
        st.warning('No test articles available.')
        return

    sample = df.iloc[np.random.randint(0, len(df))]
    prediction, correct = predict(model, tokenizer, sample)
    label = 'reliable' if sample['label_1'] > sample['label_0'] else 'unreliable'

    st.header(sample['title'])
    show_prediction = st.success if correct else st.error
    show_prediction(f'Prediction: {prediction}')
    st.caption(f'Source: {sample["source"]} ({label})')
    st.markdown(_mask_words(sample['content']))


def main():
    """Render the Streamlit UI: random test article or user-entered article."""
    df = load_test_df()
    tokenizer, model = init_model()

    st.title("Unreliable News classifier")
    mode = st.radio('', ('Test article', 'Input own article'))

    if mode == 'Test article':
        if st.button('Get random article'):
            _show_random_article(df, model, tokenizer)
    else:
        title = st.text_input('Article title', 'Test title')
        content = st.text_area('Article content', 'Lorem ipsum')
        if st.button('Submit'):
            pred = predict_new(model, tokenizer, title, content)
            st.markdown(f'Prediction: {pred}')


if __name__ == '__main__':
    main()