|
import streamlit as st |
|
import numpy as np |
|
import pandas as pd |
|
import time |
|
import torch |
|
import torch.nn as nn |
|
from torch import tensor |
|
|
|
import joblib |
|
from dataclasses import dataclass |
|
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification |
|
import json |
|
|
|
from preprocessing import predict_review, data_preprocessing_hard |
|
from model_lstm import LSTMClassifier |
|
|
|
|
|
device = 'cpu' |
|
classifier_bag = joblib.load('classifier_bag.pkl') |
|
classifier_tf = joblib.load('classifier_tf.pkl') |
|
BERT_lin_cl = joblib.load('BERT_base_model.pkl') |
|
|
|
selected_model = st.sidebar.radio("Зачем пришел?", ("Классифиция отзывов лечебных учреждений", |
|
"Оценка степени токсичности пользовательского сообщения", |
|
"Генерация текста GPT-моделью по пользовательскому prompt")) |
|
|
|
|
|
model_options = ["BagOfWords", "TF-IDF", "LSTM", "BERT-based-ru"] |
|
if selected_model == "Классифиция отзывов лечебных учреждений": |
|
st.title(""" |
|
Приложение классифицирует твой отзыв и подскажет позитивный он или негативный |
|
""") |
|
st.write(""" |
|
Классификация происходит с использованием классических ML моделей, нейросетевой модели LSTM, |
|
и, как вариант, с использованием нейросетевой модели Bert-basic-ru для векторизации и линейной |
|
регрессии для классификации. |
|
""") |
|
vectorizer_1 = joblib.load('vectorizer_bag.joblib') |
|
vectorizer_2 = joblib.load('vectorizer_tf.joblib') |
|
|
|
|
|
with open('vocab_lstm.json', 'r') as file: |
|
vocab_to_int = json.load(file) |
|
|
|
@dataclass |
|
class ConfigRNN: |
|
vocab_size: int |
|
device : str |
|
n_layers : int |
|
embedding_dim : int |
|
hidden_size : int |
|
seq_len : int |
|
bidirectional : bool or int |
|
|
|
net_config = ConfigRNN( |
|
vocab_size = len(vocab_to_int)+1, |
|
device='cpu', |
|
n_layers=2, |
|
embedding_dim=64, |
|
hidden_size=32, |
|
seq_len = 100, |
|
bidirectional=False) |
|
|
|
lstm = LSTMClassifier(net_config) |
|
lstm.load_state_dict(torch.load('lstm_model.pth', map_location=device)) |
|
lstm.to(device) |
|
|
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("Geotrend/bert-base-ru-cased") |
|
model = AutoModel.from_pretrained("Geotrend/bert-base-ru-cased") |
|
|
|
MAX_LEN = 200 |
|
|
|
data = pd.DataFrame({ |
|
'Модель': ["BagOfWords", "TF-IDF", "LSTM", "BERT-based-ru"], |
|
'f1_macro': [0.934, 0.939, 0.009, 0.845] |
|
}) |
|
|
|
st.subheader(""" |
|
Немного информации о точности используемых моделей после обучения: |
|
""") |
|
|
|
st.table(data) |
|
user_text_input = st.text_area('Введите ваш отзыв здесь:', '') |
|
selected_model_name = st.selectbox('Выберите модель:', model_options, index=0) |
|
|
|
if st.button('Предсказать'): |
|
start_time = time.time() |
|
|
|
if selected_model_name == "BagOfWords": |
|
X = vectorizer_1.transform([data_preprocessing_hard(user_text_input)]) |
|
predictions = classifier_bag.predict(X) |
|
|
|
elif selected_model_name == "TF-IDF": |
|
X = vectorizer_2.transform([data_preprocessing_hard(user_text_input)]) |
|
predictions = classifier_tf.predict(X) |
|
|
|
elif selected_model_name == "LSTM": |
|
predictions = predict_review(model=lstm, review_text=user_text_input, net_config=net_config, |
|
vocab_to_int=vocab_to_int) |
|
|
|
elif selected_model_name == "BERT-based-ru": |
|
tokens = tokenizer.encode(user_text_input, add_special_tokens=True) |
|
padded_tokens = tokens + [0] * (MAX_LEN - len(tokens)) |
|
input_tensor = tensor(padded_tokens).unsqueeze(0) |
|
with torch.no_grad(): |
|
outputs = model(input_tensor) |
|
X = outputs.last_hidden_state[:,0,:].detach().cpu().numpy() |
|
predictions = BERT_lin_cl.predict(X) |
|
pass |
|
|
|
end_time = time.time() |
|
prediction_time = end_time - start_time |
|
|
|
model_message = f'Предсказание модели {selected_model_name}:' |
|
if predictions >= 0.5: |
|
|
|
gif_url = 'https://media2.giphy.com/media/v1.Y2lkPTc5MGI3NjExOTdnYjJ1eTE0bjRuMGptcjhpdTk2YTYzeXEzMzlidWFsamY2bW8wZyZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/LUg1GEjapflW7Vg6B9/giphy.gif' |
|
st.image(gif_url, caption="Позитивный коментарий") |
|
else: |
|
|
|
gif_url = 'https://i.gifer.com/LdC3.gif' |
|
st.image(gif_url, caption="Негативный коментарий") |
|
st.write(f'Время предсказания: {prediction_time:.4f} секунд') |
|
|
|
|
|
|
|
|
|
elif selected_model == "Оценка степени токсичности пользовательского сообщения": |
|
st.title(""" |
|
Приложение классифицирует токсичный комментарий или нет |
|
""") |
|
|
|
st.write(""" |
|
Классификация происходит с использованием нейросетевой модели rubert-tiny-toxicity. |
|
""") |
|
|
|
|
|
model_t_checkpoint = 'cointegrated/rubert-tiny-toxicity' |
|
tokenizer_t = AutoTokenizer.from_pretrained(model_t_checkpoint) |
|
model_t = AutoModelForSequenceClassification.from_pretrained(model_t_checkpoint) |
|
|
|
def text2toxicity(text, aggregate=True): |
|
with torch.no_grad(): |
|
inputs = tokenizer_t(text, return_tensors='pt', truncation=True, padding=True).to(model_t.device) |
|
proba = torch.sigmoid(model_t(**inputs).logits).cpu().numpy() |
|
if isinstance(text, str): |
|
proba = proba[0] |
|
if aggregate: |
|
return 1 - proba.T[0] * (1 - proba.T[-1]) |
|
return proba |
|
|
|
user_text_input = st.text_area('Введите ваш отзыв здесь:') |
|
|
|
if st.button('Предсказать'): |
|
start_time = time.time() |
|
proba = text2toxicity(user_text_input, True) |
|
end_time = time.time() |
|
prediction_time = end_time - start_time |
|
|
|
model_message = f'Предсказание модели:' |
|
if proba >= 0.5: |
|
|
|
gif_url = "https://media1.giphy.com/media/cInbau65cwPWUeGTIZ/giphy.gif?cid=6c09b952seqdtvky8yn2uq6bt3kvo1vu5sdzpkdznjvmtxsh&ep=v1_internal_gif_by_id&rid=giphy.gif&ct=s" |
|
st.image(gif_url, caption="ТОКСИК") |
|
|
|
else: |
|
|
|
gif_url = 'https://i.gifer.com/origin/51/518fbbf9cf32763122f9466d3c686bb3_w200.gif' |
|
st.image(gif_url, caption="МИЛОТА") |
|
st.write(f'Время предсказания: {prediction_time:.4f} секунд') |
|
|
|
|
|
|
|
|
|
elif selected_model == "Генерация текста GPT-моделью по пользовательскому prompt": |
|
st.title(""" |
|
Приложение генерирует текст по Вашему промту |
|
""") |
|
|
|
st.write(""" |
|
Для генерации текста используется предобученная сеть GPT. |
|
""") |
|
uploaded_img = st.sidebar.file_uploader('Загрузи свое космофото', type=["jpg", "png", "jpeg"]) |
|
if uploaded_img is not None: |
|
input_img = io.imread(uploaded_img) |
|
else: |
|
input_img = io.imread('/Users/id/Documents/strlit/cv_project/Segm.jpg') |
|
|