Vladislawoo committed on
Commit
683cfcc
·
1 Parent(s): 41b9558

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -11
app.py CHANGED
@@ -9,6 +9,7 @@ from tensorflow.keras.preprocessing.sequence import pad_sequences
9
  import time
10
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
11
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
 
12
 
13
  tok = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
14
  model_checkpoint = 'cointegrated/rubert-tiny-toxicity'
@@ -52,17 +53,23 @@ def predict_text(text):
52
  return predicted_class
53
 
54
 
55
- def generate_text(model, prompt, max_length=150, temperature=1.0):
56
- input_ids = tok.encode(prompt, return_tensors='pt')
57
- output = model_finetuned.generate(
58
- input_ids=input_ids,
59
- max_length=max_length + len(input_ids[0]),
60
- temperature=temperature,
61
- num_return_sequences=1,
62
- pad_token_id=tokenizer.eos_token_id
63
- )
64
- generated_text = textwrap.fill(tok.decode(out_), 60)
65
- return generated_text
 
 
 
 
 
 
66
 
67
 
68
  def page_reviews_classification():
 
9
  import time
10
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
11
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
12
+ import textwrap
13
 
14
  tok = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
15
  model_checkpoint = 'cointegrated/rubert-tiny-toxicity'
 
53
  return predicted_class
54
 
55
 
56
def generate_text(model, prompt, max_length=150, temperature=1.0, num_beams=10, top_k=600, top_p=0.75, no_repeat_ngram_size=1, num_return_sequences=1):
    """Generate one or more continuations of ``prompt`` with a GPT-2 model.

    Args:
        model: a GPT-2 ``transformers`` language model exposing ``.generate``.
        prompt: seed text to continue.
        max_length: total output length in tokens (prompt included).
        temperature: sampling temperature passed to ``generate``.
        num_beams: beam-search width.
        top_k: top-k sampling cutoff.
        top_p: nucleus-sampling cutoff.
        no_repeat_ngram_size: forbid repeating n-grams of this size.
        num_return_sequences: how many sequences to return.

    Returns:
        The generated sequences, each wrapped to 60 columns with
        ``textwrap.fill``, joined by a dashed divider line.
    """
    # BUG FIX: encode/decode with the module-level GPT-2 tokenizer `tok`
    # (see `tok = GPT2Tokenizer.from_pretrained(...)` at the top of the file).
    # `tokenizer` is the rubert-tiny-toxicity classifier's tokenizer, whose
    # vocabulary does not match GPT-2 token ids, so using it here would feed
    # the model wrong ids and decode its output into garbage.
    # NOTE(review): `device` is assumed to be a module-level torch device —
    # confirm it is defined before this function is called.
    input_ids = tok.encode(prompt, return_tensors='pt').to(device)

    # inference_mode disables autograd bookkeeping for faster generation.
    with torch.inference_mode():
        output = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            num_beams=num_beams,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            no_repeat_ngram_size=no_repeat_ngram_size,
            num_return_sequences=num_return_sequences,
        )

    texts = [textwrap.fill(tok.decode(out), 60) for out in output]
    return "\n------------------\n".join(texts)
73
 
74
 
75
  def page_reviews_classification():