history_mistery / app.py
SaviAnna's picture
Update app.py
4af482d
raw
history blame
2.45 kB
import transformers
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
from PIL import Image
st.title("""
History Mistery
""")
# image = Image.open('data-scins.jpeg')
# st.image(image, caption='Current mood')
# Добавление слайдера
temperature = st.slider("Градус дичи", 1.0, 20.0, 1.0)
# Загрузка модели и токенизатора
# model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# #Задаем класс модели (уже в streamlit/tg_bot)
model = GPT2LMHeadModel.from_pretrained(
'sberbank-ai/rugpt3small_based_on_gpt2',
output_attentions = False,
output_hidden_states = False,
)
# # Вешаем сохраненные веса на нашу модель
model.load_state_dict(torch.load('model_history.pt'))
# Функция для генерации текста
def generate_text(prompt):
# Преобразование входной строки в токены
input_ids = tokenizer.encode(prompt, return_tensors='pt')
# Генерация текста
output = model.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True,
temperature=1.0, top_k=50, top_p=0.6, no_repeat_ngram_size=3,
num_return_sequences=3)
# Декодирование сгенерированного текста
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
return generated_text
# Создание кнопки "Сгенерировать"
generate_button = st.button("За работу!")
# Streamlit приложение
def main():
st.write("""
# GPT-3 генерация текста
""")
# Ввод строки пользователем
prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века на руси")
# Генерация текста по введенной строке
generated_text = generate_text(prompt)
# Обработка события нажатия кнопки
if generate_button:
# Вывод сгенерированного текста
st.subheader("Продолжение:")
st.write(generated_text)
if __name__ == "__main__":
main()