bankholdup commited on
Commit
b36e2b2
·
1 Parent(s): 909ce17

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -251
app.py DELETED
@@ -1,251 +0,0 @@
1
- import pystuck
2
- pystuck.run_server()
3
-
4
- import os
5
-
6
- import argparse
7
- import logging
8
-
9
- import numpy as np
10
- import torch
11
- import datetime
12
- import gradio as gr
13
-
14
- from transformers import (
15
- CTRLLMHeadModel,
16
- CTRLTokenizer,
17
- GPT2LMHeadModel,
18
- GPT2Tokenizer,
19
- OpenAIGPTLMHeadModel,
20
- OpenAIGPTTokenizer,
21
- TransfoXLLMHeadModel,
22
- TransfoXLTokenizer,
23
- XLMTokenizer,
24
- XLMWithLMHeadModel,
25
- XLNetLMHeadModel,
26
- XLNetTokenizer,
27
- )
28
-
29
-
30
- logging.basicConfig(
31
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,
32
- )
33
- logger = logging.getLogger(__name__)
34
-
35
- MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
36
-
37
- MODEL_CLASSES = {
38
- "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
39
- "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
40
- "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
41
- "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
42
- "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
43
- "xlm": (XLMWithLMHeadModel, XLMTokenizer),
44
- }
45
-
46
- def set_seed(args):
47
- rd = np.random.randint(100000)
48
- print('seed =', rd)
49
- np.random.seed(rd)
50
- torch.manual_seed(rd)
51
- if args.n_gpu > 0:
52
- torch.cuda.manual_seed_all(rd)
53
-
54
- #
55
- # Functions to prepare models' input
56
- #
57
-
58
-
59
- def prepare_ctrl_input(args, _, tokenizer, prompt_text):
60
- if args.temperature > 0.7:
61
- logger.info("CTRL typically works better with lower temperatures (and lower top_k).")
62
-
63
- encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
64
- if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
65
- logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
66
- return prompt_text
67
-
68
-
69
- def prepare_xlm_input(args, model, tokenizer, prompt_text):
70
- # kwargs = {"language": None, "mask_token_id": None}
71
-
72
- # Set the language
73
- use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
74
- if hasattr(model.config, "lang2id") and use_lang_emb:
75
- available_languages = model.config.lang2id.keys()
76
- if args.xlm_language in available_languages:
77
- language = args.xlm_language
78
- else:
79
- language = None
80
- while language not in available_languages:
81
- language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")
82
-
83
- model.config.lang_id = model.config.lang2id[language]
84
- # kwargs["language"] = tokenizer.lang2id[language]
85
-
86
- # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
87
- # XLM masked-language modeling (MLM) models need masked token
88
- # is_xlm_mlm = "mlm" in args.model_name_or_path
89
- # if is_xlm_mlm:
90
- # kwargs["mask_token_id"] = tokenizer.mask_token_id
91
-
92
- return prompt_text
93
-
94
-
95
- def prepare_xlnet_input(args, _, tokenizer, prompt_text):
96
- prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
97
- return prompt_text
98
-
99
-
100
- def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
101
- prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
102
- return prompt_text
103
-
104
-
105
- PREPROCESSING_FUNCTIONS = {
106
- "ctrl": prepare_ctrl_input,
107
- "xlm": prepare_xlm_input,
108
- "xlnet": prepare_xlnet_input,
109
- "transfo-xl": prepare_transfoxl_input,
110
- }
111
-
112
-
113
- def adjust_length_to_model(length, max_sequence_length):
114
- if length < 0 and max_sequence_length > 0:
115
- length = max_sequence_length
116
- elif 0 < max_sequence_length < length:
117
- length = max_sequence_length # No generation bigger than model size
118
- elif length < 0:
119
- length = MAX_LENGTH # avoid infinite loop
120
- return length
121
-
122
-
123
- def main():
124
- parser = argparse.ArgumentParser()
125
- parser.add_argument(
126
- "--model_type",
127
- default=None,
128
- type=str,
129
- required=True,
130
- help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
131
- )
132
- parser.add_argument(
133
- "--model_name_or_path",
134
- default=None,
135
- type=str,
136
- required=True,
137
- help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
138
- )
139
-
140
- parser.add_argument("--prompt", type=str, default="")
141
- parser.add_argument("--length", type=int, default=20)
142
- parser.add_argument("--stop_token", type=str, default="</s>", help="Token at which lyrics generation is stopped")
143
-
144
- parser.add_argument(
145
- "--temperature",
146
- type=float,
147
- default=1.0,
148
- help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
149
- )
150
- parser.add_argument(
151
- "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
152
- )
153
- parser.add_argument("--k", type=int, default=0)
154
- parser.add_argument("--p", type=float, default=0.9)
155
-
156
- parser.add_argument("--padding_text", type=str, default="", help="Padding lyrics for Transfo-XL and XLNet.")
157
- parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
158
-
159
- parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
160
- parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
161
- parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
162
- args = parser.parse_args()
163
-
164
- args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
165
- args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
166
-
167
- # Initialize the model and tokenizer
168
- try:
169
- args.model_type = args.model_type.lower()
170
- model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
171
- except KeyError:
172
- raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")
173
-
174
- tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
175
- model = model_class.from_pretrained(args.model_name_or_path)
176
- model.to(args.device)
177
-
178
- args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings)
179
- logger.info(args)
180
- generated_sequences = []
181
- prompt_text = ""
182
- while prompt_text != "stop":
183
- set_seed(args)
184
- while not len(prompt_text):
185
- prompt_text = args.prompt if args.prompt else input("Context >>> ")
186
-
187
- # Different models need different input formatting and/or extra arguments
188
- requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
189
- if requires_preprocessing:
190
- prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
191
- preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
192
- encoded_prompt = tokenizer.encode(
193
- preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", add_space_before_punct_symbol=True
194
- )
195
- else:
196
- encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
197
- encoded_prompt = encoded_prompt.to(args.device)
198
-
199
- output_sequences = model.generate(
200
- input_ids=encoded_prompt,
201
- max_length=args.length + len(encoded_prompt[0]),
202
- temperature=args.temperature,
203
- top_k=args.k,
204
- top_p=args.p,
205
- repetition_penalty=args.repetition_penalty,
206
- do_sample=True,
207
- num_return_sequences=args.num_return_sequences,
208
- )
209
-
210
- # Remove the batch dimension when returning multiple sequences
211
- if len(output_sequences.shape) > 2:
212
- output_sequences.squeeze_()
213
-
214
- now = datetime.datetime.now()
215
- date_time = now.strftime('%Y%m%d_%H%M%S%f')
216
-
217
- for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
218
- print("ruGPT:".format(generated_sequence_idx + 1))
219
- generated_sequence = generated_sequence.tolist()
220
-
221
- # Decode lyrics
222
- text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
223
-
224
- # Remove all lyrics after the stop token
225
- text = text[: text.find(args.stop_token) if args.stop_token else None]
226
-
227
- # Add the prompt at the beginning of the sequence. Remove the excess lyrics that was used for pre-processing
228
- total_sequence = (
229
- prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
230
- )
231
-
232
- generated_sequences.append(total_sequence)
233
- # os.system('clear')
234
- print(total_sequence)
235
-
236
- prompt_text = ""
237
- if args.prompt:
238
- break
239
-
240
- return generated_sequences
241
-
242
- title = "ruGPT3 Song Writer"
243
- description = "Generate russian songs via fine-tuned ruGPT3"
244
-
245
- gr.Interface(
246
- process,
247
- gr.inputs.Textbox(lines=1, label="Input text", examples="Как дела? Как дела? Это новый кадиллак"),
248
- gr.outputs.Textbox(lines=20, label="Output text"),
249
- title=title,
250
- description=description,
251
- ).launch(enable_queue=True,cache_examples=True)