Vladislawoo commited on
Commit
11ad39a
·
1 Parent(s): b7d25c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -3
app.py CHANGED
@@ -1,9 +1,31 @@
1
  import streamlit as st
2
  from joblib import load
 
 
 
 
 
 
 
 
3
  clf = load('my_model_filename.pkl')
4
  vectorizer = load('tfidf_vectorizer.pkl')
5
  scaler = load('scaler.joblib')
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # Запуск приложения
9
  def main():
@@ -13,14 +35,47 @@ def main():
13
  user_input = st.text_area("Введите текст отзыва:")
14
 
15
  if st.button("Классифицировать"):
16
- # Векторизация текста (если вы использовали TF-IDF или другой векторизатор)
17
  user_input_vec = vectorizer.transform([user_input])
18
  sentence_vector_scaled = scaler.transform(user_input_vec)
19
- # Прогноз модели
20
  prediction = clf.predict(
21
- sentence_vector_scaled) # Используйте user_input_vec вместо user_input, если текст нужно векторизировать
 
22
  st.write(f"Прогнозируемый класс: {prediction[0]}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
 
 
 
 
 
 
24
 
25
  if __name__ == "__main__":
26
  main()
 
1
  import streamlit as st
2
  from joblib import load
3
+ from transformers import BertTokenizer, BertForSequenceClassification
4
+ import torch
5
+ from tensorflow.keras.models import load_model
6
+ import tensorflow as tf
7
+ from tensorflow.keras.preprocessing.text import Tokenizer
8
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
9
+ import time
10
+
11
  clf = load('my_model_filename.pkl')
12
  vectorizer = load('tfidf_vectorizer.pkl')
13
  scaler = load('scaler.joblib')
14
 
15
+ tukinazor = load('tokenizer.pkl')
16
+ rnn_model = load_model('path_to_my_model.h5')
17
+ bert_model = BertForSequenceClassification.from_pretrained('my_bert_model')
18
+ tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ bert_model = bert_model.to(device)
21
+
22
+ def predict_text(text):
23
+ sequences = tukinazor.texts_to_sequences([text])
24
+ padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(sequences, maxlen=200, padding='post', truncating='post')
25
+ predictions = rnn_model.predict(padded_sequences)
26
+ predicted_class = tf.argmax(predictions, axis=-1).numpy()[0]
27
+ return predicted_class
28
+
29
 
30
  # Запуск приложения
31
  def main():
 
35
  user_input = st.text_area("Введите текст отзыва:")
36
 
37
  if st.button("Классифицировать"):
38
+ start_time = time.time()
39
  user_input_vec = vectorizer.transform([user_input])
40
  sentence_vector_scaled = scaler.transform(user_input_vec)
 
41
  prediction = clf.predict(
42
+ sentence_vector_scaled)
43
+ elapsed_time = time.time() - start_time
44
  st.write(f"Прогнозируемый класс: {prediction[0]}")
45
+ st.write(f"Время вычисления: {elapsed_time:.2f} сек.")
46
+
47
+ user_input_rnn = st.text_area("Введите текст отзыва для Keras RNN модели:")
48
+
49
+ if st.button("Классифицировать с RNN"):
50
+ start_time = time.time()
51
+ prediction_rnn = predict_text(user_input_rnn)
52
+ elapsed_time = time.time() - start_time
53
+ st.write(f"Прогнозируемый класс с RNN: {prediction_rnn}")
54
+ st.write(f"Время вычисления: {elapsed_time:.2f} сек.")
55
+
56
+ user_input_bert = st.text_area("Введите текст отзыва для BERT:")
57
+
58
+ if st.button("Классифицировать (BERT)"):
59
+ start_time = time.time()
60
+ encoding = tokenizer.encode_plus(
61
+ user_input_bert,
62
+ add_special_tokens=True,
63
+ max_length=200,
64
+ return_token_type_ids=False,
65
+ padding='max_length',
66
+ truncation=True,
67
+ return_attention_mask=True,
68
+ return_tensors='pt'
69
+ )
70
+ input_ids = encoding['input_ids'].to(device)
71
+ attention_mask = encoding['attention_mask'].to(device)
72
 
73
+ with torch.no_grad():
74
+ outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
75
+ predictions = torch.argmax(outputs.logits, dim=1)
76
+ elapsed_time = time.time() - start_time
77
+ st.write(f"Прогнозируемый класс (BERT): {predictions.item() + 1}")
78
+ st.write(f"Время вычисления: {elapsed_time:.2f} сек.")
79
 
80
  if __name__ == "__main__":
81
  main()