|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModel |
|
import torch |
|
from bert_file import BERTClassifier |
|
import numpy as np |
|
|
|
tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2") |
|
model = BERTClassifier() |
|
device = 'cpu' |
|
|
|
model.load_state_dict(torch.load('BERTmodel_weights2.pth',map_location=torch.device('cpu'))) |
|
model.eval() |
|
|
|
@st.cache_data |
|
def predict_sentiment(text): |
|
MAX_LEN = 100 |
|
encoded_review = tokenizer.encode_plus( |
|
text, |
|
max_length=MAX_LEN, |
|
add_special_tokens=True, |
|
return_token_type_ids=False, |
|
pad_to_max_length=True, |
|
return_attention_mask=True, |
|
return_tensors='pt', |
|
) |
|
input_ids = encoded_review['input_ids'].to(device) |
|
attention_mask = encoded_review['attention_mask'].to(device) |
|
|
|
with torch.no_grad(): |
|
output = model(input_ids, attention_mask) |
|
prediction = torch.round(output).cpu().numpy()[0][0] |
|
if prediction == 1: |
|
return "Позитивный отзыв 😀" |
|
else: |
|
return "Негативный отзыв 😟" |
|
|
|
def bert_model_page(): |
|
st.title("Классификация отзывов") |
|
user_input = st.text_area("Введите отзыв:") |
|
if st.button("Классифицировать"): |
|
if user_input: |
|
prediction = predict_sentiment(user_input) |
|
st.write(prediction) |
|
else: |
|
st.write("Пожалуйста, введите отзыв для классификации.") |
|
|