import streamlit as st from joblib import load from transformers import BertTokenizer, BertForSequenceClassification import torch from tensorflow.keras.models import load_model import tensorflow as tf from tensorflow.keras.preprocessing.text import Tokenizer from tensorflow.keras.preprocessing.sequence import pad_sequences import time from transformers import AutoTokenizer, AutoModelForSequenceClassification model_checkpoint = 'cointegrated/rubert-tiny-toxicity' toxicity_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) toxicity_model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint) clf = load('my_model_filename.pkl') vectorizer = load('tfidf_vectorizer.pkl') scaler = load('scaler.joblib') tukinazor = load('tokenizer.pkl') rnn_model = load_model('path_to_my_model.h5') bert_model = BertForSequenceClassification.from_pretrained('my_bert_model') tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased') device = torch.device("cuda" if torch.cuda.is_available() else "cpu") bert_model = bert_model.to(device) labels = ["не токсичный", "оскорбляющий", "непристойный", "угрожающий", "опасный"] def text2toxicity(text, aggregate=True): """ Calculate toxicity of a text (if aggregate=True) or a vector of toxicity aspects (if aggregate=False)""" with torch.no_grad(): inputs = toxicity_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(toxicity_model.device) proba = torch.sigmoid(toxicity_model(**inputs).logits).cpu().numpy() if isinstance(text, str): proba = proba[0] if aggregate: return 1 - proba.T[0] * (1 - proba.T[-1]) else: result = {} for label, prob in zip(labels, proba): result[label] = prob return result def predict_text(text): sequences = tukinazor.texts_to_sequences([text]) padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(sequences, maxlen=200, padding='post', truncating='post') predictions = rnn_model.predict(padded_sequences) predicted_class = tf.argmax(predictions, axis=-1).numpy()[0] return predicted_class def page_reviews_classification(): st.title("Модель классификации отзывов") # Ввод текста user_input = st.text_area("Введите текст отзыва:") if st.button("Классифицировать"): start_time = time.time() user_input_vec = vectorizer.transform([user_input]) sentence_vector_scaled = scaler.transform(user_input_vec) prediction = clf.predict( sentence_vector_scaled) elapsed_time = time.time() - start_time st.write(f"Прогнозируемый класс: {prediction[0]}") st.write(f"Время вычисления: {elapsed_time:.2f} сек.") user_input_rnn = st.text_area("Введите текст отзыва для RNN модели:") if st.button("Классифицировать с RNN"): start_time = time.time() prediction_rnn = predict_text(user_input_rnn) elapsed_time = time.time() - start_time st.write(f"Прогнозируемый класс с RNN: {prediction_rnn}") st.write(f"Время вычисления: {elapsed_time:.2f} сек.") user_input_bert = st.text_area("Введите текст отзыва для BERT:") if st.button("Классифицировать (BERT)"): start_time = time.time() encoding = tokenizer.encode_plus( user_input_bert, add_special_tokens=True, max_length=200, return_token_type_ids=False, padding='max_length', truncation=True, return_attention_mask=True, return_tensors='pt' ) input_ids = encoding['input_ids'].to(device) attention_mask = encoding['attention_mask'].to(device) with torch.no_grad(): outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask) predictions = torch.argmax(outputs.logits, dim=1) elapsed_time = time.time() - start_time st.write(f"Прогнозируемый класс (BERT): {predictions.item() + 1}") st.write(f"Время вычисления: {elapsed_time:.2f} сек.") def page_toxicity_analysis(): # Код для анализа токсичности текста с использованием модели cointegrated/rubert-tiny-toxicity user_input_toxicity = st.text_area("Введите текст для оценки токсичности:") if st.button("Оценить токсичность"): start_time = time.time() probs = text2toxicity(user_input_toxicity, aggregate=False) elapsed_time = time.time() - start_time for label, prob in probs.items(): st.write(f"Вероятность того что комментарий {label}: {prob:.4f}") def main(): page_selection = st.sidebar.selectbox("Выберите страницу:", ["Классификация отзывов", "Анализ токсичности"]) if page_selection == "Классификация отзывов": page_reviews_classification() elif page_selection == "Анализ токсичности": page_toxicity_analysis() if __name__ == "__main__": main()