AntNikYab commited on
Commit
f7e3be7
·
1 Parent(s): 1b8e753

Upload TheBroCode.py

Browse files
Files changed (1) hide show
  1. pages/TheBroCode.py +64 -0
pages/TheBroCode.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import textwrap
3
+ import torch
4
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
5
+
6
+ DEVICE = torch.device("cpu")
7
+ # Load GPT-2 model and tokenizer
8
+ tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
9
+ model_finetuned = GPT2LMHeadModel.from_pretrained(
10
+ 'sberbank-ai/rugpt3small_based_on_gpt2',
11
+ output_attentions = False,
12
+ output_hidden_states = False,
13
+ )
14
+ if torch.cuda.is_available():
15
+ model_finetuned.load_state_dict(torch.load('models/brat.pt'))
16
+ else:
17
+ model_finetuned.load_state_dict(torch.load('models/brat.pt', map_location=torch.device('cpu')))
18
+ model_finetuned.eval()
19
+
20
+ # Function to generate text
21
+ def generate_text(prompt, temperature, top_p, max_length, top_k):
22
+ input_ids = tokenizer.encode(prompt, return_tensors="pt")
23
+
24
+ with torch.no_grad():
25
+ out = model_finetuned.generate(
26
+ input_ids,
27
+ do_sample=True,
28
+ num_beams=5,
29
+ temperature=temperature,
30
+ top_p=top_p,
31
+ max_length=max_length,
32
+ top_k=top_k,
33
+ no_repeat_ngram_size=3,
34
+ num_return_sequences=1,
35
+ )
36
+
37
+ generated_text = list(map(tokenizer.decode, out))
38
+ return generated_text
39
+
40
+ # Streamlit app
41
+ def main():
42
+ st.title("Генерация текста 'Кодекс Братана'")
43
+
44
+ # User inputs
45
+ prompt = st.text_area("Введите начало текста")
46
+ temperature = st.slider("Temperature", min_value=0.2, max_value=2.5, value=1.8, step=0.1)
47
+ top_p = st.slider("Top-p", min_value=0.1, max_value=1.0, value=0.9, step=0.1)
48
+ max_length = st.slider("Max Length", min_value=10, max_value=300, value=100, step=10)
49
+ top_k = st.slider("Top-k", min_value=1, max_value=500, value=500, step=10)
50
+ num_return_sequences = st.slider("Number of Sequences", min_value=1, max_value=5, value=1, step=1)
51
+
52
+ if st.button("Generate Text"):
53
+ st.subheader("Generated Text:")
54
+ for i in range(num_return_sequences):
55
+ generated_text = generate_text(prompt, temperature, top_p, max_length, top_k)
56
+ st.write(f"Generated Text {i + 1}:")
57
+ wrapped_text = textwrap.fill(generated_text[0], width=80)
58
+ st.write(wrapped_text)
59
+ st.write("------------------")
60
+
61
+ st.sidebar.image('images/theBROcode.jpeg', use_column_width=True)
62
+
63
+ if __name__ == "__main__":
64
+ main()