File size: 2,708 Bytes
49e7ba4
053e03d
 
 
 
0ac7ea7
053e03d
a8c7602
44cac7c
053e03d
 
 
 
44cac7c
 
053e03d
44cac7c
 
a8c7602
44cac7c
ce7fad9
 
 
 
b7882eb
 
 
 
 
053e03d
a8c7602
9b52475
053e03d
 
 
 
 
 
 
 
 
 
 
 
 
 
e1cbbad
053e03d
 
a8c7602
 
 
053e03d
 
a8c7602
053e03d
e1cbbad
 
 
 
a8c7602
 
 
e1cbbad
a8c7602
 
053e03d
44cac7c
 
eeef851
053e03d
4af482d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import transformers
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
from PIL import Image
import torch

st.title("""
 History Mistery
 """)
# image = Image.open('data-scins.jpeg')

# st.image(image, caption='Current mood')
# Добавление слайдера
temperature = st.slider("Градус дичи", 1.0, 20.0, 1.0)
# Загрузка модели и токенизатора
# model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# #Задаем класс модели (уже в streamlit/tg_bot)
model = GPT2LMHeadModel.from_pretrained(
     'sberbank-ai/rugpt3small_based_on_gpt2',
     output_attentions = False,
     output_hidden_states = False,
 )
tokenizer = GPT2Tokenizer.from_pretrained(
     'sberbank-ai/rugpt3small_based_on_gpt2',
     output_attentions = False,
     output_hidden_states = False,
 )

# # Вешаем сохраненные веса на нашу модель
model.load_state_dict(torch.load('model_history.pt',map_location=torch.device('cpu')))
# Функция для генерации текста
def generate_text(prompt):
    # Преобразование входной строки в токены
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    # Генерация текста
    output = model.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True,
                            temperature=1.0, top_k=50, top_p=0.6, no_repeat_ngram_size=3,
                            num_return_sequences=3)

    # Декодирование сгенерированного текста
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

    return generated_text

# Streamlit приложение
def main():
    st.write("""
    # GPT-3 генерация текста
    """)

    # Ввод строки пользователем
    prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века на руси")

    # # Генерация текста по введенной строке
    # generated_text = generate_text(prompt)
    # Создание кнопки "Сгенерировать"
    generate_button = st.button("За работу!")
    # Обработка события нажатия кнопки
    if generate_button:
    # Вывод сгенерированного текста
        generated_text = generate_text(prompt)
        st.subheader("Продолжение:")
        st.write(generated_text)



if __name__ == "__main__":
    main()