Spaces:
Sleeping
Sleeping
import streamlit as st | |
from joblib import load | |
from transformers import BertTokenizer, BertForSequenceClassification | |
import torch | |
from tensorflow.keras.models import load_model | |
import tensorflow as tf | |
from tensorflow.keras.preprocessing.text import Tokenizer | |
from tensorflow.keras.preprocessing.sequence import pad_sequences | |
import time | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
model_checkpoint = 'cointegrated/rubert-tiny-toxicity' | |
toxicity_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) | |
toxicity_model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint) | |
clf = load('my_model_filename.pkl') | |
vectorizer = load('tfidf_vectorizer.pkl') | |
scaler = load('scaler.joblib') | |
tukinazor = load('tokenizer.pkl') | |
rnn_model = load_model('path_to_my_model.h5') | |
bert_model = BertForSequenceClassification.from_pretrained('my_bert_model') | |
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased') | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
bert_model = bert_model.to(device) | |
labels = ["не токсичный", "оскорбляющий", "непристойный", "угрожающий", "опасный"] | |
def text2toxicity(text, aggregate=True): | |
""" Calculate toxicity of a text (if aggregate=True) or a vector of toxicity aspects (if aggregate=False)""" | |
with torch.no_grad(): | |
inputs = toxicity_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(toxicity_model.device) | |
proba = torch.sigmoid(toxicity_model(**inputs).logits).cpu().numpy() | |
if isinstance(text, str): | |
proba = proba[0] | |
if aggregate: | |
return 1 - proba.T[0] * (1 - proba.T[-1]) | |
else: | |
result = {} | |
for label, prob in zip(labels, proba): | |
result[label] = prob | |
return result | |
def predict_text(text): | |
sequences = tukinazor.texts_to_sequences([text]) | |
padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(sequences, maxlen=200, padding='post', truncating='post') | |
predictions = rnn_model.predict(padded_sequences) | |
predicted_class = tf.argmax(predictions, axis=-1).numpy()[0] | |
return predicted_class | |
def page_reviews_classification(): | |
st.title("Модель классификации отзывов") | |
# Ввод текста | |
user_input = st.text_area("Введите текст отзыва:") | |
if st.button("Классифицировать"): | |
start_time = time.time() | |
user_input_vec = vectorizer.transform([user_input]) | |
sentence_vector_scaled = scaler.transform(user_input_vec) | |
prediction = clf.predict( | |
sentence_vector_scaled) | |
elapsed_time = time.time() - start_time | |
st.write(f"Прогнозируемый класс: {prediction[0]}") | |
st.write(f"Время вычисления: {elapsed_time:.2f} сек.") | |
user_input_rnn = st.text_area("Введите текст отзыва для RNN модели:") | |
if st.button("Классифицировать с RNN"): | |
start_time = time.time() | |
prediction_rnn = predict_text(user_input_rnn) | |
elapsed_time = time.time() - start_time | |
st.write(f"Прогнозируемый класс с RNN: {prediction_rnn}") | |
st.write(f"Время вычисления: {elapsed_time:.2f} сек.") | |
user_input_bert = st.text_area("Введите текст отзыва для BERT:") | |
if st.button("Классифицировать (BERT)"): | |
start_time = time.time() | |
encoding = tokenizer.encode_plus( | |
user_input_bert, | |
add_special_tokens=True, | |
max_length=200, | |
return_token_type_ids=False, | |
padding='max_length', | |
truncation=True, | |
return_attention_mask=True, | |
return_tensors='pt' | |
) | |
input_ids = encoding['input_ids'].to(device) | |
attention_mask = encoding['attention_mask'].to(device) | |
with torch.no_grad(): | |
outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask) | |
predictions = torch.argmax(outputs.logits, dim=1) | |
elapsed_time = time.time() - start_time | |
st.write(f"Прогнозируемый класс (BERT): {predictions.item() + 1}") | |
st.write(f"Время вычисления: {elapsed_time:.2f} сек.") | |
def page_toxicity_analysis(): | |
# Код для анализа токсичности текста с использованием модели cointegrated/rubert-tiny-toxicity | |
user_input_toxicity = st.text_area("Введите текст для оценки токсичности:") | |
if st.button("Оценить токсичность"): | |
start_time = time.time() | |
probs = text2toxicity(user_input_toxicity, aggregate=False) | |
elapsed_time = time.time() - start_time | |
for label, prob in probs.items(): | |
st.write(f"Вероятность того что комментарий {label}: {prob:.4f}") | |
def main(): | |
page_selection = st.sidebar.selectbox("Выберите страницу:", ["Классификация отзывов", "Анализ токсичности"]) | |
if page_selection == "Классификация отзывов": | |
page_reviews_classification() | |
elif page_selection == "Анализ токсичности": | |
page_toxicity_analysis() | |
if __name__ == "__main__": | |
main() |