nlp-gpt-team / app.py
Vladislawoo's picture
Update app.py
6f95ca2
raw
history blame
5.5 kB
import streamlit as st
from joblib import load
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from tensorflow.keras.models import load_model
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import time
from transformers import AutoTokenizer, AutoModelForSequenceClassification
model_checkpoint = 'cointegrated/rubert-tiny-toxicity'
toxicity_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
toxicity_model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)
clf = load('my_model_filename.pkl')
vectorizer = load('tfidf_vectorizer.pkl')
scaler = load('scaler.joblib')
tukinazor = load('tokenizer.pkl')
rnn_model = load_model('path_to_my_model.h5')
bert_model = BertForSequenceClassification.from_pretrained('my_bert_model')
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bert_model = bert_model.to(device)
labels = ["не токсичный", "оскорбляющий", "непристойный", "угрожающий", "опасный"]
def text2toxicity(text, aggregate=True):
""" Calculate toxicity of a text (if aggregate=True) or a vector of toxicity aspects (if aggregate=False)"""
with torch.no_grad():
inputs = toxicity_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(toxicity_model.device)
proba = torch.sigmoid(toxicity_model(**inputs).logits).cpu().numpy()
if isinstance(text, str):
proba = proba[0]
if aggregate:
return 1 - proba.T[0] * (1 - proba.T[-1])
else:
result = {}
for label, prob in zip(labels, proba):
result[label] = prob
return result
def predict_text(text):
sequences = tukinazor.texts_to_sequences([text])
padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(sequences, maxlen=200, padding='post', truncating='post')
predictions = rnn_model.predict(padded_sequences)
predicted_class = tf.argmax(predictions, axis=-1).numpy()[0]
return predicted_class
def page_reviews_classification():
st.title("Модель классификации отзывов")
# Ввод текста
user_input = st.text_area("Введите текст отзыва:")
if st.button("Классифицировать"):
start_time = time.time()
user_input_vec = vectorizer.transform([user_input])
sentence_vector_scaled = scaler.transform(user_input_vec)
prediction = clf.predict(
sentence_vector_scaled)
elapsed_time = time.time() - start_time
st.write(f"Прогнозируемый класс: {prediction[0]}")
st.write(f"Время вычисления: {elapsed_time:.2f} сек.")
user_input_rnn = st.text_area("Введите текст отзыва для RNN модели:")
if st.button("Классифицировать с RNN"):
start_time = time.time()
prediction_rnn = predict_text(user_input_rnn)
elapsed_time = time.time() - start_time
st.write(f"Прогнозируемый класс с RNN: {prediction_rnn}")
st.write(f"Время вычисления: {elapsed_time:.2f} сек.")
user_input_bert = st.text_area("Введите текст отзыва для BERT:")
if st.button("Классифицировать (BERT)"):
start_time = time.time()
encoding = tokenizer.encode_plus(
user_input_bert,
add_special_tokens=True,
max_length=200,
return_token_type_ids=False,
padding='max_length',
truncation=True,
return_attention_mask=True,
return_tensors='pt'
)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
with torch.no_grad():
outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
predictions = torch.argmax(outputs.logits, dim=1)
elapsed_time = time.time() - start_time
st.write(f"Прогнозируемый класс (BERT): {predictions.item() + 1}")
st.write(f"Время вычисления: {elapsed_time:.2f} сек.")
def page_toxicity_analysis():
# Код для анализа токсичности текста с использованием модели cointegrated/rubert-tiny-toxicity
user_input_toxicity = st.text_area("Введите текст для оценки токсичности:")
if st.button("Оценить токсичность"):
start_time = time.time()
probs = text2toxicity(user_input_toxicity, aggregate=False)
elapsed_time = time.time() - start_time
for label, prob in probs.items():
st.write(f"Вероятность того что комментарий {label}: {prob:.4f}")
def main():
page_selection = st.sidebar.selectbox("Выберите страницу:", ["Классификация отзывов", "Анализ токсичности"])
if page_selection == "Классификация отзывов":
page_reviews_classification()
elif page_selection == "Анализ токсичности":
page_toxicity_analysis()
if __name__ == "__main__":
main()