yhn112 commited on
Commit
9a72fe1
·
1 Parent(s): b5e8297

Add application file

Browse files
Files changed (3) hide show
  1. app.py +78 -0
  2. model.pt +3 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import pandas as pd
4
+ import streamlit as st
5
+ import torch.nn as nn
6
+ from transformers import RobertaTokenizer, RobertaModel, PretrainedConfig
7
+
8
+
9
+ @st.cache_resource
10
+ def init_model():
11
+ model = RobertaModel(config=PretrainedConfig().from_pretrained("roberta-large-mnli"))
12
+
13
+ model.pooler = nn.Sequential(
14
+ nn.Linear(1024, 256),
15
+ nn.LayerNorm(256),
16
+ nn.ReLU(),
17
+ nn.Linear(256, 8),
18
+ nn.Sigmoid()
19
+ )
20
+
21
+ model_path = "model.pt"
22
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
23
+ model.eval()
24
+ return model
25
+
26
+ cats = ["Computer Science", "Economics", "Electrical Engineering",
27
+ "Mathematics", "Physics", "Biology", "Finance", "Statistics"]
28
+
29
+ def predict(outputs):
30
+ top = 0
31
+ temp = 100000
32
+ apr_probs = torch.nn.functional.softmax(torch.tensor([39253., 84., 220., 2263., 1214., 909., 66., 10661.]) / temp, dim=0)
33
+ probs = nn.functional.softmax(outputs / apr_probs, dim=1).tolist()[0]
34
+
35
+ top_cats = []
36
+ top_probs = []
37
+
38
+ first = True
39
+ write_cs = False
40
+ for prob, cat in sorted(zip(probs, cats), reverse=True):
41
+ if first:
42
+ if cat == "Computer Science":
43
+ write_cs = True
44
+ first = False
45
+ if top < 95:
46
+ percent = prob * 100
47
+ top += percent
48
+ top_cats.append(cat)
49
+ top_probs.append(str(round(percent, 1)))
50
+ res = pd.DataFrame(top_probs, index=top_cats, columns=['Percent'])
51
+ st.write(res)
52
+ if write_cs:
53
+ st.write("Today everything is connected with Computer Science")
54
+
55
+ tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
56
+ model = init_model()
57
+
58
+ st.title("Article classifier")
59
+ st.markdown("### Title")
60
+
61
+ title = st.text_input("*Enter title (required)")
62
+
63
+ st.markdown("### Abstract")
64
+
65
+ abstract = st.text_area(" Enter abstract", height=200)
66
+
67
+ if not title:
68
+ st.warning("Please fill in required fields")
69
+ else:
70
+ try:
71
+ st.markdown("### Result")
72
+ encoded_input = tokenizer(title + ". " + abstract, return_tensors="pt", padding=True,
73
+ max_length=1024, truncation=True)
74
+ with torch.no_grad():
75
+ outputs = model(**encoded_input).pooler_output[:, 0, :]
76
+ predict(outputs)
77
+ except Exception:
78
+ st.error("Something went wrong. Try different text")
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9ce2a83d4d7f59e53ab917fb99ecaeb26f66a14c9f336b898f4924935af2140
3
+ size 1418460457
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ altair==4.0
2
+ pandas
3
+ torch
4
+ tokenizers
5
+ transformers