from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import streamlit as st
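
# A minimal Streamlit app that classifies papers with the
# oracat/bert-paper-classifier-arxiv model. To try it locally
# (assuming this file is saved as app.py): streamlit run app.py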


@st.cache_resource  # cache_resource (not cache_data) is recommended for ML models
def prepare_model():
    """
    Prepare the tokenizer and the model for classification.
    """
    tokenizer = AutoTokenizer.from_pretrained("oracat/bert-paper-classifier-arxiv")
    model = AutoModelForSequenceClassification.from_pretrained(
        "oracat/bert-paper-classifier-arxiv"
    )
    return (tokenizer, model)


def top_pct(preds, threshold=0.95):
    """
    Return the highest-scoring predictions whose cumulative score
    reaches `threshold`, in descending order of score.
    """
    if not preds:
        return preds

    preds = sorted(preds, key=lambda x: -x["score"])

    cum_score = 0
    for i, item in enumerate(preds):
        cum_score += item["score"]
        if cum_score >= threshold:
            break

    return preds[: i + 1]
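
# Illustrative example (hypothetical labels and scores): with the
# default threshold of 0.95,
#   top_pct([{"label": "cs.CL", "score": 0.90},
#            {"label": "cs.LG", "score": 0.07},
#            {"label": "stat.ML", "score": 0.03}])
# keeps only the first two entries, since 0.90 + 0.07 >= 0.95.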


def format_predictions(preds) -> str:
    """
    Format predictions and their scores as a numbered Markdown list.
    """
    out = ""
    for i, item in enumerate(preds):
        out += f"{i + 1}. **{item['label']}** *(score {item['score']:.2f})*\n"
    return out
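
# Illustrative example (hypothetical prediction):
#   format_predictions([{"label": "cs.CL", "score": 0.97}])
# returns the Markdown line "1. **cs.CL** *(score 0.97)*\n".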


def process(text):
    """
    Tokenize the incoming text and classify it.
    """
    # Uses the module-level `model` and `tokenizer` loaded below.
    # top_k=None makes the pipeline return scores for every label;
    # pipe(text)[0] selects the predictions for the single input string.
    pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)
    result = pipe(text)[0]
    return format_predictions(top_pct(result))


tokenizer, model = prepare_model()


# State management
#
# The app state consists of the paper title, the abstract, and the
# classifier output. Session state is used so that the demo buttons
# can pre-fill the input fields with example values.

if "title" not in st.session_state:
    st.session_state["title"] = ""

if "abstract" not in st.session_state:
    st.session_state["abstract"] = ""

if "output" not in st.session_state:
    st.session_state["output"] = ""


# Simple Streamlit interface

st.markdown("### Hello, paper classifier!")


## Demo buttons and their callbacks


def demo_cl_callback():
    """
    Pre-fill the fields with the LLaMA announcement:
    https://ai.facebook.com/blog/large-language-model-llama-meta-ai/
    """
    paper_title = (
        "Introducing LLaMA: A foundational, 65-billion-parameter large language model"
    )
    paper_abstract = "Over the last year, large language models — natural language processing (NLP) systems with billions of parameters — have shown new capabilities to generate creative text, solve mathematical theorems, predict protein structures, answer reading comprehension questions, and more. They are one of the clearest cases of the substantial potential benefits AI can offer at scale to billions of people. Smaller models trained on more tokens — which are pieces of words — are easier to retrain and fine-tune for specific potential product use cases. We trained LLaMA 65B and LLaMA 33B on 1.4 trillion tokens. Our smallest model, LLaMA 7B, is trained on one trillion tokens. Like other large language models, LLaMA works by taking a sequence of words as an input and predicts a next word to recursively generate text. To train our model, we chose text from the 20 languages with the most speakers, focusing on those with Latin and Cyrillic alphabets."
    st.session_state["title"] = paper_title
    st.session_state["abstract"] = paper_abstract


def demo_cv_callback():
    """
    Pre-fill the fields with the ViT paper:
    https://arxiv.org/abs/2010.11929
    """
    paper_title = (
        "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"
    )
    paper_abstract = "While the Transformer architecture has become the de-facto standard for natural language processing tasks, its applications to computer vision remain limited. In vision, attention is either applied in conjunction with convolutional networks, or used to replace certain components of convolutional networks while keeping their overall structure in place. We show that this reliance on CNNs is not necessary and a pure transformer applied directly to sequences of image patches can perform very well on image classification tasks. When pre-trained on large amounts of data and transferred to multiple mid-sized or small image recognition benchmarks (ImageNet, CIFAR-100, VTAB, etc.), Vision Transformer (ViT) attains excellent results compared to state-of-the-art convolutional networks while requiring substantially fewer computational resources to train."
    st.session_state["title"] = paper_title
    st.session_state["abstract"] = paper_abstract


def clear_callback():
    """
    Clear the input fields and the stored output.
    """
    st.session_state["title"] = ""
    st.session_state["abstract"] = ""
    st.session_state["output"] = ""


col1, col2, col3 = st.columns([1, 1, 1])
with col1:
    st.button("Demo: LLaMA paper", on_click=demo_cl_callback)
with col2:
    st.button("Demo: ViT paper", on_click=demo_cv_callback)
with col3:
    st.button("Clear fields", on_click=clear_callback)

## Input fields


title = st.text_input("Enter the title:", key="title")
abstract = st.text_area(
    "... and maybe the abstract of the paper you want to classify:", key="abstract"
)

text = "\n".join([title, abstract])

## Output

if text.strip():
    st.markdown(process(text))