stal76 commited on
Commit
38083c7
·
1 Parent(s): 2dc13a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -9
app.py CHANGED
@@ -18,11 +18,14 @@ class Net(nn.Module):
18
 
19
  def forward(self,x):
20
  return self.layer(x)
21
-
22
- model = Net()
23
- model.load_state_dict(torch.load('model.dat', map_location=torch.device('cpu')))
24
- tokenizer = AutoTokenizer.from_pretrained("Callidior/bert2bert-base-arxiv-titlegen")
25
- model_emb = AutoModelForSeq2SeqLM.from_pretrained("Callidior/bert2bert-base-arxiv-titlegen")
 
 
 
26
 
27
  def BuildAnswer(txt):
28
  def get_hidden_states(encoded, model):
@@ -49,6 +52,8 @@ def BuildAnswer(txt):
49
  7: "Quantitative Finance",
50
  8: "Statistics"
51
  }
 
 
52
 
53
  embed = get_word_vector(txt, tokenizer, model_emb)
54
  logits = torch.nn.functional.softmax(model(embed), dim=0)
@@ -71,10 +76,8 @@ st.markdown("### Hello, world!")
71
  st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
72
  # ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter
73
 
74
- #st.markdown("#### Title")
75
- title = st.text_area("Title")
76
- #st.markdown("#### Abstract")
77
- abstract = st.text_area("Abstract")
78
 
79
  #from transformers import pipeline
80
  #pipe = pipeline("ner", "Davlan/distilbert-base-multilingual-cased-ner-hrl")
 
18
 
19
  def forward(self,x):
20
  return self.layer(x)
21
+
22
+ @st.cache
23
+ def GetModelAndTokenizer():
24
+ model = Net()
25
+ model.load_state_dict(torch.load('model.dat', map_location=torch.device('cpu')))
26
+ tokenizer = AutoTokenizer.from_pretrained("Callidior/bert2bert-base-arxiv-titlegen")
27
+ model_emb = AutoModelForSeq2SeqLM.from_pretrained("Callidior/bert2bert-base-arxiv-titlegen")
28
+ return model, tokenizer, model_emb
29
 
30
  def BuildAnswer(txt):
31
  def get_hidden_states(encoded, model):
 
52
  7: "Quantitative Finance",
53
  8: "Statistics"
54
  }
55
+
56
+ model, tokenizer, model_emb = GetModelAndTokenizer()
57
 
58
  embed = get_word_vector(txt, tokenizer, model_emb)
59
  logits = torch.nn.functional.softmax(model(embed), dim=0)
 
76
  st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
77
  # ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter
78
 
79
+ title = st.text_area("Title:")
80
+ abstract = st.text_area("Abstract:", height=10)
 
 
81
 
82
  #from transformers import pipeline
83
  #pipe = pipeline("ner", "Davlan/distilbert-base-multilingual-cased-ner-hrl")