Spaces:
Sleeping
Sleeping
File size: 2,708 Bytes
49e7ba4 053e03d 0ac7ea7 053e03d a8c7602 44cac7c 053e03d 44cac7c 053e03d 44cac7c a8c7602 44cac7c ce7fad9 b7882eb 053e03d a8c7602 9b52475 053e03d e1cbbad 053e03d a8c7602 053e03d a8c7602 053e03d e1cbbad a8c7602 e1cbbad a8c7602 053e03d 44cac7c eeef851 053e03d 4af482d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import transformers
import streamlit as st
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
from PIL import Image
import torch
st.title("""
History Mistery
""")
# image = Image.open('data-scins.jpeg')
# st.image(image, caption='Current mood')
# Добавление слайдера
temperature = st.slider("Градус дичи", 1.0, 20.0, 1.0)
# Загрузка модели и токенизатора
# model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
# #Задаем класс модели (уже в streamlit/tg_bot)
model = GPT2LMHeadModel.from_pretrained(
'sberbank-ai/rugpt3small_based_on_gpt2',
output_attentions = False,
output_hidden_states = False,
)
tokenizer = GPT2Tokenizer.from_pretrained(
'sberbank-ai/rugpt3small_based_on_gpt2',
output_attentions = False,
output_hidden_states = False,
)
# # Вешаем сохраненные веса на нашу модель
model.load_state_dict(torch.load('model_history.pt',map_location=torch.device('cpu')))
# Функция для генерации текста
def generate_text(prompt):
# Преобразование входной строки в токены
input_ids = tokenizer.encode(prompt, return_tensors='pt')
# Генерация текста
output = model.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True,
temperature=1.0, top_k=50, top_p=0.6, no_repeat_ngram_size=3,
num_return_sequences=3)
# Декодирование сгенерированного текста
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
return generated_text
# Streamlit приложение
def main():
st.write("""
# GPT-3 генерация текста
""")
# Ввод строки пользователем
prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века на руси")
# # Генерация текста по введенной строке
# generated_text = generate_text(prompt)
# Создание кнопки "Сгенерировать"
generate_button = st.button("За работу!")
# Обработка события нажатия кнопки
if generate_button:
# Вывод сгенерированного текста
generated_text = generate_text(prompt)
st.subheader("Продолжение:")
st.write(generated_text)
if __name__ == "__main__":
main()
|