Upload app.py
app.py CHANGED
````diff
@@ -9,8 +9,11 @@ from transformers import (
     Trainer,
     default_data_collator,
 )
+
+# Load custom functions
 from lib.utils import preprocess_examples, make_predictions, get_examples
 
+# Set mps or cuda device if available
 if torch.backends.mps.is_available():
     device = "mps"
 elif torch.cuda.is_available():
@@ -18,12 +21,6 @@ elif torch.cuda.is_available():
 else:
     device = "cpu"
 
-# TO DO:
-# - make it pretty
-# - add support for multiple questions corresponding to same context
-# - add examples
-# What else??
-
 # Initialize session state variables
 if 'response' not in st.session_state:
     st.session_state['response'] = ''
@@ -35,26 +32,44 @@ if 'question' not in st.session_state:
 # Build trainer using model and tokenizer from Hugging Face repo
 @st.cache_resource(show_spinner=False)
 def get_model():
+    """
+    Load model and tokenizer from 🤗 repo
+    Parameters: None
+    -----------
+    Returns:
+    --------
+    model : transformers.AutoModelForQuestionAnswering
+        The fine-tuned Q&A model
+    tokenizer : transformers.AutoTokenizer
+        The model's pre-trained tokenizer
+    """
     repo_id = 'etweedy/roberta-base-squad-v2'
     model = AutoModelForQuestionAnswering.from_pretrained(repo_id)
     tokenizer = AutoTokenizer.from_pretrained(repo_id)
     return model, tokenizer
 
 def fill_in_example(i):
+    """
+    Function for context-question example button click
+    """
     st.session_state['response'] = ''
     st.session_state['question'] = ex_q[i]
     st.session_state['context'] = ex_c[i]
 
 def clear_boxes():
+    """
+    Function for field clear button click
+    """
     st.session_state['response'] = ''
     st.session_state['question'] = ''
     st.session_state['context'] = ''
 
+# Retrieve stored model
 with st.spinner('Loading the model...'):
     model, tokenizer = get_model()
 
+# Intro text
 st.header('RoBERTa Q&A model')
-
 st.markdown('''
 This app demonstrates the answer-retrieval capabilities of a fine-tuned RoBERTa (Robustly optimized Bidirectional Encoder Representations from Transformers) model.
 ''')
@@ -90,18 +105,21 @@ with st.expander('Click to read more about the model...'):
     }
     ```
     ''')
-
 st.markdown('''
 Please type or paste a context paragraph and question you'd like to ask about it. The model will attempt to answer the question, or otherwise will report that it cannot. Your results will appear below the question field when the model is finished running.
 
 Alternatively, you can try an example by clicking one of the buttons below:
 ''')
 
+# Grab example question-context pairs from csv file
 ex_q, ex_c = get_examples()
+
+# Generate containers in order
 example_container = st.container()
 input_container = st.container()
 response_container = st.container()
 
+# Populate example button container
 with example_container:
     ex_cols = st.columns(len(ex_q)+1)
     for i in range(len(ex_q)):
@@ -119,9 +137,10 @@ with example_container:
             on_click = clear_boxes,
         )
 
-#
+# Populate user input container
 with input_container:
     with st.form(key='input_form',clear_on_submit=False):
+        # Context input field
         context = st.text_area(
             label='Context',
             value=st.session_state['context'],
@@ -130,6 +149,7 @@ with input_container:
             placeholder='Enter your context paragraph here.',
             height=300,
         )
+        # Question input field
         question = st.text_input(
             label='Question',
             value=st.session_state['question'],
@@ -137,11 +157,14 @@ with input_container:
             label_visibility='hidden',
             placeholder='Enter your question here.',
         )
+        # Form submit button
         query_submitted = st.form_submit_button("Submit")
         if query_submitted:
+            # update question, context in session state
             st.session_state['question'] = question
             st.session_state['context'] = context
             with st.spinner('Generating response...'):
+                # Generate dataset from input example
                 data_raw = Dataset.from_dict(
                     {
                         'id':[0],
@@ -149,6 +172,7 @@ with input_container:
                         'question':[st.session_state['question']],
                     }
                 )
+                # Tokenize and preprocess dataset
                 data_proc = data_raw.map(
                     preprocess_examples,
                     remove_columns = data_raw.column_names,
@@ -157,15 +181,18 @@ with input_container:
                         'tokenizer':tokenizer,
                     }
                 )
+                # Make answer prediction with model
                 predicted_answers = make_predictions(model, tokenizer,
                                                      data_proc, data_raw,
                                                      n_best = 20)
                 answer = predicted_answers[0]['prediction_text']
                 confidence = predicted_answers[0]['confidence']
+                # Update response in session state
                 st.session_state['response'] = f"""
                     Answer: {answer}\n
                     Confidence: {confidence:.2%}
                 """
+# Display response
 with response_container:
     st.write(st.session_state['response'])
 
````
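For a quick sanity check of the checkpoint this commit wires up, you can bypass the Streamlit UI and the repo's `lib.utils` helpers entirely and query the same model through the 🤗 `pipeline` API. A minimal sketch (the question and context strings here are made up, and the pipeline's `score` is its own span probability, not necessarily the `confidence` value computed by `make_predictions`):

```python
# Minimal sketch: query the same checkpoint without the Streamlit app.
from transformers import pipeline

qa = pipeline('question-answering', model='etweedy/roberta-base-squad-v2')

result = qa(
    question='What does RoBERTa stand for?',
    context=(
        'RoBERTa (Robustly optimized Bidirectional Encoder Representations '
        'from Transformers) is a retrained variant of BERT.'
    ),
)
print(result['answer'], f"({result['score']:.2%})")
```

Since the checkpoint was fine-tuned on SQuAD v2, it can also decline to answer; passing `handle_impossible_answer=True` to the pipeline call lets it return an empty answer when the context contains none.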
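The diff leaves `lib.utils` opaque, but extractive QA with this model reduces to scoring start/end token positions over the context. The sketch below is an assumption about the general technique, not the repo's actual helper: it uses greedy argmax rather than the app's ranking of the top `n_best = 20` start/end pairs, and only standard transformers/torch calls. The question and context strings are hypothetical.

```python
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# Same checkpoint the app loads in get_model()
repo_id = 'etweedy/roberta-base-squad-v2'
model = AutoModelForQuestionAnswering.from_pretrained(repo_id)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Hypothetical inputs, for illustration only
question = 'Who developed RoBERTa?'
context = 'RoBERTa was developed by researchers at Facebook AI.'

inputs = tokenizer(question, context, return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs)

# Greedy decoding: take the single best start and end position.
# make_predictions presumably ranks many candidate spans instead
# and derives the confidence value shown in the app.
start = outputs.start_logits.argmax()
end = outputs.end_logits.argmax()
answer = tokenizer.decode(inputs['input_ids'][0, start : end + 1])
print(answer.strip())
```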