import streamlit as st import numpy as np import pandas as pd import time import torch import torch.nn as nn from torch import tensor import joblib from dataclasses import dataclass from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification import json from preprocessing import predict_review, data_preprocessing_hard from model_lstm import LSTMClassifier # from BERT_inputs import BertInputs device = 'cpu' classifier_bag = joblib.load('classifier_bag.pkl') classifier_tf = joblib.load('classifier_tf.pkl') BERT_lin_cl = joblib.load('BERT_base_model.pkl') selected_model = st.sidebar.radio("Зачем пришел?", ("Классифиция отзывов лечебных учреждений", "Оценка степени токсичности пользовательского сообщения", "Генерация текста GPT-моделью по пользовательскому prompt")) # Классификация отзыва на поликлиники model_options = ["BagOfWords", "TF-IDF", "LSTM", "BERT-based-ru"] if selected_model == "Классифиция отзывов лечебных учреждений": st.title(""" Приложение классифицирует твой отзыв и подскажет позитивный он или негативный """) st.write(""" Классификация происходит с использованием классических ML моделей, нейросетевой модели LSTM, и, как вариант, с использованием нейросетевой модели Bert-basic-ru для векторизации и линейной регрессии для классификации. """) vectorizer_1 = joblib.load('vectorizer_bag.joblib') vectorizer_2 = joblib.load('vectorizer_tf.joblib') # LSTM with open('vocab_lstm.json', 'r') as file: vocab_to_int = json.load(file) @dataclass class ConfigRNN: vocab_size: int device : str n_layers : int embedding_dim : int hidden_size : int seq_len : int bidirectional : bool or int net_config = ConfigRNN( vocab_size = len(vocab_to_int)+1, device='cpu', n_layers=2, embedding_dim=64, hidden_size=32, seq_len = 100, bidirectional=False) lstm = LSTMClassifier(net_config) lstm.load_state_dict(torch.load('lstm_model.pth', map_location=device)) lstm.to(device) # lstm.eval() # BERT tokenizer = AutoTokenizer.from_pretrained("Geotrend/bert-base-ru-cased") model = AutoModel.from_pretrained("Geotrend/bert-base-ru-cased") # model.eval() MAX_LEN = 200 data = pd.DataFrame({ 'Модель': ["BagOfWords", "TF-IDF", "LSTM", "BERT-based-ru"], 'f1_macro': [0.934, 0.939, 0.009, 0.845] }) st.subheader(""" Немного информации о точности используемых моделей после обучения: """) # st.write(data) st.table(data) user_text_input = st.text_area('Введите ваш отзыв здесь:', '') selected_model_name = st.selectbox('Выберите модель:', model_options, index=0) if st.button('Предсказать'): start_time = time.time() if selected_model_name == "BagOfWords": X = vectorizer_1.transform([data_preprocessing_hard(user_text_input)]) predictions = classifier_bag.predict(X) elif selected_model_name == "TF-IDF": X = vectorizer_2.transform([data_preprocessing_hard(user_text_input)]) predictions = classifier_tf.predict(X) elif selected_model_name == "LSTM": predictions = predict_review(model=lstm, review_text=user_text_input, net_config=net_config, vocab_to_int=vocab_to_int) elif selected_model_name == "BERT-based-ru": tokens = tokenizer.encode(user_text_input, add_special_tokens=True) padded_tokens = tokens + [0] * (MAX_LEN - len(tokens)) input_tensor = tensor(padded_tokens).unsqueeze(0) with torch.no_grad(): outputs = model(input_tensor) X = outputs.last_hidden_state[:,0,:].detach().cpu().numpy() predictions = BERT_lin_cl.predict(X) pass end_time = time.time() prediction_time = end_time - start_time model_message = f'Предсказание модели {selected_model_name}:' if predictions >= 0.5: # st.write(f'{model_message} кажется это положительный комментарий.') gif_url = 'https://media2.giphy.com/media/v1.Y2lkPTc5MGI3NjExOTdnYjJ1eTE0bjRuMGptcjhpdTk2YTYzeXEzMzlidWFsamY2bW8wZyZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/LUg1GEjapflW7Vg6B9/giphy.gif' st.image(gif_url, caption="Позитивный коментарий") else: # st.write(f'{model_message} кажется это негативный комментарий.') gif_url = 'https://i.gifer.com/LdC3.gif' st.image(gif_url, caption="Негативный коментарий") st.write(f'Время предсказания: {prediction_time:.4f} секунд') # Оценка степени токсичности пользовательского сообщения elif selected_model == "Оценка степени токсичности пользовательского сообщения": st.title(""" Приложение классифицирует токсичный комментарий или нет """) st.write(""" Классификация происходит с использованием нейросетевой модели rubert-tiny-toxicity. """) # Toxicity model_t_checkpoint = 'cointegrated/rubert-tiny-toxicity' tokenizer_t = AutoTokenizer.from_pretrained(model_t_checkpoint) model_t = AutoModelForSequenceClassification.from_pretrained(model_t_checkpoint) def text2toxicity(text, aggregate=True): with torch.no_grad(): inputs = tokenizer_t(text, return_tensors='pt', truncation=True, padding=True).to(model_t.device) proba = torch.sigmoid(model_t(**inputs).logits).cpu().numpy() if isinstance(text, str): proba = proba[0] if aggregate: return 1 - proba.T[0] * (1 - proba.T[-1]) return proba user_text_input = st.text_area('Введите ваш отзыв здесь:') if st.button('Предсказать'): start_time = time.time() proba = text2toxicity(user_text_input, True) end_time = time.time() prediction_time = end_time - start_time model_message = f'Предсказание модели:' if proba >= 0.5: # st.write(f' Кажется это токсичный комментарий.') gif_url = "https://media1.giphy.com/media/cInbau65cwPWUeGTIZ/giphy.gif?cid=6c09b952seqdtvky8yn2uq6bt3kvo1vu5sdzpkdznjvmtxsh&ep=v1_internal_gif_by_id&rid=giphy.gif&ct=s" st.image(gif_url, caption="ТОКСИК") else: # st.write(f' Кажется это не токсичный комментарий.') gif_url = 'https://i.gifer.com/origin/51/518fbbf9cf32763122f9466d3c686bb3_w200.gif' st.image(gif_url, caption="МИЛОТА") st.write(f'Время предсказания: {prediction_time:.4f} секунд') # Генерация текста GPT-моделью elif selected_model == "Генерация текста GPT-моделью по пользовательскому prompt": st.title(""" Приложение генерирует текст по Вашему промту """) st.write(""" Для генерации текста используется предобученная сеть GPT. """) uploaded_img = st.sidebar.file_uploader('Загрузи свое космофото', type=["jpg", "png", "jpeg"]) if uploaded_img is not None: input_img = io.imread(uploaded_img) else: input_img = io.imread('/Users/id/Documents/strlit/cv_project/Segm.jpg')