nlp / app.py
ElijahDi's picture
Rename Streamlit app.py to app.py
76ca7ff verified
raw
history blame
8.51 kB
import streamlit as st
import numpy as np
import pandas as pd
import time
import torch
import torch.nn as nn
from torch import tensor
import joblib
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
import json
from preprocessing import predict_review, data_preprocessing_hard
from model_lstm import LSTMClassifier
# from BERT_inputs import BertInputs
device = 'cpu'
classifier_bag = joblib.load('classifier_bag.pkl')
classifier_tf = joblib.load('classifier_tf.pkl')
BERT_lin_cl = joblib.load('BERT_base_model.pkl')
selected_model = st.sidebar.radio("Зачем пришел?", ("Классифиция отзывов лечебных учреждений",
"Оценка степени токсичности пользовательского сообщения",
"Генерация текста GPT-моделью по пользовательскому prompt"))
# Классификация отзыва на поликлиники
model_options = ["BagOfWords", "TF-IDF", "LSTM", "BERT-based-ru"]
if selected_model == "Классифиция отзывов лечебных учреждений":
st.title("""
Приложение классифицирует твой отзыв и подскажет позитивный он или негативный
""")
st.write("""
Классификация происходит с использованием классических ML моделей, нейросетевой модели LSTM,
и, как вариант, с использованием нейросетевой модели Bert-basic-ru для векторизации и линейной
регрессии для классификации.
""")
vectorizer_1 = joblib.load('vectorizer_bag.joblib')
vectorizer_2 = joblib.load('vectorizer_tf.joblib')
# LSTM
with open('vocab_lstm.json', 'r') as file:
vocab_to_int = json.load(file)
@dataclass
class ConfigRNN:
vocab_size: int
device : str
n_layers : int
embedding_dim : int
hidden_size : int
seq_len : int
bidirectional : bool or int
net_config = ConfigRNN(
vocab_size = len(vocab_to_int)+1,
device='cpu',
n_layers=2,
embedding_dim=64,
hidden_size=32,
seq_len = 100,
bidirectional=False)
lstm = LSTMClassifier(net_config)
lstm.load_state_dict(torch.load('lstm_model.pth', map_location=device))
lstm.to(device)
# lstm.eval()
# BERT
tokenizer = AutoTokenizer.from_pretrained("Geotrend/bert-base-ru-cased")
model = AutoModel.from_pretrained("Geotrend/bert-base-ru-cased")
# model.eval()
MAX_LEN = 200
data = pd.DataFrame({
'Модель': ["BagOfWords", "TF-IDF", "LSTM", "BERT-based-ru"],
'f1_macro': [0.934, 0.939, 0.009, 0.845]
})
st.subheader("""
Немного информации о точности используемых моделей после обучения:
""")
# st.write(data)
st.table(data)
user_text_input = st.text_area('Введите ваш отзыв здесь:', '')
selected_model_name = st.selectbox('Выберите модель:', model_options, index=0)
if st.button('Предсказать'):
start_time = time.time()
if selected_model_name == "BagOfWords":
X = vectorizer_1.transform([data_preprocessing_hard(user_text_input)])
predictions = classifier_bag.predict(X)
elif selected_model_name == "TF-IDF":
X = vectorizer_2.transform([data_preprocessing_hard(user_text_input)])
predictions = classifier_tf.predict(X)
elif selected_model_name == "LSTM":
predictions = predict_review(model=lstm, review_text=user_text_input, net_config=net_config,
vocab_to_int=vocab_to_int)
elif selected_model_name == "BERT-based-ru":
tokens = tokenizer.encode(user_text_input, add_special_tokens=True)
padded_tokens = tokens + [0] * (MAX_LEN - len(tokens))
input_tensor = tensor(padded_tokens).unsqueeze(0)
with torch.no_grad():
outputs = model(input_tensor)
X = outputs.last_hidden_state[:,0,:].detach().cpu().numpy()
predictions = BERT_lin_cl.predict(X)
pass
end_time = time.time()
prediction_time = end_time - start_time
model_message = f'Предсказание модели {selected_model_name}:'
if predictions >= 0.5:
# st.write(f'{model_message} кажется это положительный комментарий.')
gif_url = 'https://media2.giphy.com/media/v1.Y2lkPTc5MGI3NjExOTdnYjJ1eTE0bjRuMGptcjhpdTk2YTYzeXEzMzlidWFsamY2bW8wZyZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/LUg1GEjapflW7Vg6B9/giphy.gif'
st.image(gif_url, caption="Позитивный коментарий")
else:
# st.write(f'{model_message} кажется это негативный комментарий.')
gif_url = 'https://i.gifer.com/LdC3.gif'
st.image(gif_url, caption="Негативный коментарий")
st.write(f'Время предсказания: {prediction_time:.4f} секунд')
# Оценка степени токсичности пользовательского сообщения
elif selected_model == "Оценка степени токсичности пользовательского сообщения":
st.title("""
Приложение классифицирует токсичный комментарий или нет
""")
st.write("""
Классификация происходит с использованием нейросетевой модели rubert-tiny-toxicity.
""")
# Toxicity
model_t_checkpoint = 'cointegrated/rubert-tiny-toxicity'
tokenizer_t = AutoTokenizer.from_pretrained(model_t_checkpoint)
model_t = AutoModelForSequenceClassification.from_pretrained(model_t_checkpoint)
def text2toxicity(text, aggregate=True):
with torch.no_grad():
inputs = tokenizer_t(text, return_tensors='pt', truncation=True, padding=True).to(model_t.device)
proba = torch.sigmoid(model_t(**inputs).logits).cpu().numpy()
if isinstance(text, str):
proba = proba[0]
if aggregate:
return 1 - proba.T[0] * (1 - proba.T[-1])
return proba
user_text_input = st.text_area('Введите ваш отзыв здесь:')
if st.button('Предсказать'):
start_time = time.time()
proba = text2toxicity(user_text_input, True)
end_time = time.time()
prediction_time = end_time - start_time
model_message = f'Предсказание модели:'
if proba >= 0.5:
# st.write(f' Кажется это токсичный комментарий.')
gif_url = "https://media1.giphy.com/media/cInbau65cwPWUeGTIZ/giphy.gif?cid=6c09b952seqdtvky8yn2uq6bt3kvo1vu5sdzpkdznjvmtxsh&ep=v1_internal_gif_by_id&rid=giphy.gif&ct=s"
st.image(gif_url, caption="ТОКСИК")
else:
# st.write(f' Кажется это не токсичный комментарий.')
gif_url = 'https://i.gifer.com/origin/51/518fbbf9cf32763122f9466d3c686bb3_w200.gif'
st.image(gif_url, caption="МИЛОТА")
st.write(f'Время предсказания: {prediction_time:.4f} секунд')
# Генерация текста GPT-моделью
elif selected_model == "Генерация текста GPT-моделью по пользовательскому prompt":
st.title("""
Приложение генерирует текст по Вашему промту
""")
st.write("""
Для генерации текста используется предобученная сеть GPT.
""")
uploaded_img = st.sidebar.file_uploader('Загрузи свое космофото', type=["jpg", "png", "jpeg"])
if uploaded_img is not None:
input_img = io.imread(uploaded_img)
else:
input_img = io.imread('/Users/id/Documents/strlit/cv_project/Segm.jpg')