etweedy committed on
Commit
ac4ecec
·
1 Parent(s): 3fe1714

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -9
app.py CHANGED
@@ -9,8 +9,11 @@ from transformers import (
9
  Trainer,
10
  default_data_collator,
11
  )
 
 
12
  from lib.utils import preprocess_examples, make_predictions, get_examples
13
 
 
14
  if torch.backends.mps.is_available():
15
  device = "mps"
16
  elif torch.cuda.is_available():
@@ -18,12 +21,6 @@ elif torch.cuda.is_available():
18
  else:
19
  device = "cpu"
20
 
21
- # TO DO:
22
- # - make it pretty
23
- # - add support for multiple questions corresponding to same context
24
- # - add examples
25
- # What else??
26
-
27
  # Initialize session state variables
28
  if 'response' not in st.session_state:
29
  st.session_state['response'] = ''
@@ -35,26 +32,44 @@ if 'question' not in st.session_state:
35
  # Build trainer using model and tokenizer from Hugging Face repo
36
  @st.cache_resource(show_spinner=False)
37
  def get_model():
 
 
 
 
 
 
 
 
 
 
 
38
  repo_id = 'etweedy/roberta-base-squad-v2'
39
  model = AutoModelForQuestionAnswering.from_pretrained(repo_id)
40
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
41
  return model, tokenizer
42
 
43
  def fill_in_example(i):
 
 
 
44
  st.session_state['response'] = ''
45
  st.session_state['question'] = ex_q[i]
46
  st.session_state['context'] = ex_c[i]
47
 
48
  def clear_boxes():
 
 
 
49
  st.session_state['response'] = ''
50
  st.session_state['question'] = ''
51
  st.session_state['context'] = ''
52
 
 
53
  with st.spinner('Loading the model...'):
54
  model, tokenizer = get_model()
55
 
 
56
  st.header('RoBERTa Q&A model')
57
-
58
  st.markdown('''
59
  This app demonstrates the answer-retrieval capabilities of a fine-tuned RoBERTa (Robustly optimized Bidirectional Encoder Representations from Transformers) model.
60
  ''')
@@ -90,18 +105,21 @@ with st.expander('Click to read more about the model...'):
90
  }
91
  ```
92
  ''')
93
-
94
  st.markdown('''
95
  Please type or paste a context paragraph and question you'd like to ask about it. The model will attempt to answer the question, or otherwise will report that it cannot. Your results will appear below the question field when the model is finished running.
96
 
97
  Alternatively, you can try an example by clicking one of the buttons below:
98
  ''')
99
 
 
100
  ex_q, ex_c = get_examples()
 
 
101
  example_container = st.container()
102
  input_container = st.container()
103
  response_container = st.container()
104
 
 
105
  with example_container:
106
  ex_cols = st.columns(len(ex_q)+1)
107
  for i in range(len(ex_q)):
@@ -119,9 +137,10 @@ with example_container:
119
  on_click = clear_boxes,
120
  )
121
 
122
- # Form for user inputs
123
  with input_container:
124
  with st.form(key='input_form',clear_on_submit=False):
 
125
  context = st.text_area(
126
  label='Context',
127
  value=st.session_state['context'],
@@ -130,6 +149,7 @@ with input_container:
130
  placeholder='Enter your context paragraph here.',
131
  height=300,
132
  )
 
133
  question = st.text_input(
134
  label='Question',
135
  value=st.session_state['question'],
@@ -137,11 +157,14 @@ with input_container:
137
  label_visibility='hidden',
138
  placeholder='Enter your question here.',
139
  )
 
140
  query_submitted = st.form_submit_button("Submit")
141
  if query_submitted:
 
142
  st.session_state['question'] = question
143
  st.session_state['context'] = context
144
  with st.spinner('Generating response...'):
 
145
  data_raw = Dataset.from_dict(
146
  {
147
  'id':[0],
@@ -149,6 +172,7 @@ with input_container:
149
  'question':[st.session_state['question']],
150
  }
151
  )
 
152
  data_proc = data_raw.map(
153
  preprocess_examples,
154
  remove_columns = data_raw.column_names,
@@ -157,15 +181,18 @@ with input_container:
157
  'tokenizer':tokenizer,
158
  }
159
  )
 
160
  predicted_answers = make_predictions(model, tokenizer,
161
  data_proc, data_raw,
162
  n_best = 20)
163
  answer = predicted_answers[0]['prediction_text']
164
  confidence = predicted_answers[0]['confidence']
 
165
  st.session_state['response'] = f"""
166
  Answer: {answer}\n
167
  Confidence: {confidence:.2%}
168
  """
 
169
  with response_container:
170
  st.write(st.session_state['response'])
171
 
 
9
  Trainer,
10
  default_data_collator,
11
  )
12
+
13
+ # Load custom functions
14
  from lib.utils import preprocess_examples, make_predictions, get_examples
15
 
16
+ # Set mps or cuda device if available
17
  if torch.backends.mps.is_available():
18
  device = "mps"
19
  elif torch.cuda.is_available():
 
21
  else:
22
  device = "cpu"
23
 
 
 
 
 
 
 
24
  # Initialize session state variables
25
  if 'response' not in st.session_state:
26
  st.session_state['response'] = ''
 
32
  # Build trainer using model and tokenizer from Hugging Face repo
33
  @st.cache_resource(show_spinner=False)
34
  def get_model():
35
+ """
36
+ Load model and tokenizer from 🤗 repo
37
+ Parameters: None
38
+ -----------
39
+ Returns:
40
+ --------
41
+ model : transformers.AutoModelForQuestionAnswering
42
+ The fine-tuned Q&A model
43
+ tokenizer : transformers.AutoTokenizer
44
+ The model's pre-trained tokenizer
45
+ """
46
  repo_id = 'etweedy/roberta-base-squad-v2'
47
  model = AutoModelForQuestionAnswering.from_pretrained(repo_id)
48
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
49
  return model, tokenizer
50
 
51
  def fill_in_example(i):
52
+ """
53
+ Function for context-question example button click
54
+ """
55
  st.session_state['response'] = ''
56
  st.session_state['question'] = ex_q[i]
57
  st.session_state['context'] = ex_c[i]
58
 
59
  def clear_boxes():
60
+ """
61
+ Function for field clear button click
62
+ """
63
  st.session_state['response'] = ''
64
  st.session_state['question'] = ''
65
  st.session_state['context'] = ''
66
 
67
+ # Retrieve stored model
68
  with st.spinner('Loading the model...'):
69
  model, tokenizer = get_model()
70
 
71
+ # Intro text
72
  st.header('RoBERTa Q&A model')
 
73
  st.markdown('''
74
  This app demonstrates the answer-retrieval capabilities of a fine-tuned RoBERTa (Robustly optimized Bidirectional Encoder Representations from Transformers) model.
75
  ''')
 
105
  }
106
  ```
107
  ''')
 
108
  st.markdown('''
109
  Please type or paste a context paragraph and question you'd like to ask about it. The model will attempt to answer the question, or otherwise will report that it cannot. Your results will appear below the question field when the model is finished running.
110
 
111
  Alternatively, you can try an example by clicking one of the buttons below:
112
  ''')
113
 
114
+ # Grab example question-context pairs from csv file
115
  ex_q, ex_c = get_examples()
116
+
117
+ # Generate containers in order
118
  example_container = st.container()
119
  input_container = st.container()
120
  response_container = st.container()
121
 
122
+ # Populate example button container
123
  with example_container:
124
  ex_cols = st.columns(len(ex_q)+1)
125
  for i in range(len(ex_q)):
 
137
  on_click = clear_boxes,
138
  )
139
 
140
+ # Populate user input container
141
  with input_container:
142
  with st.form(key='input_form',clear_on_submit=False):
143
+ # Context input field
144
  context = st.text_area(
145
  label='Context',
146
  value=st.session_state['context'],
 
149
  placeholder='Enter your context paragraph here.',
150
  height=300,
151
  )
152
+ # Question input field
153
  question = st.text_input(
154
  label='Question',
155
  value=st.session_state['question'],
 
157
  label_visibility='hidden',
158
  placeholder='Enter your question here.',
159
  )
160
+ # Form submit button
161
  query_submitted = st.form_submit_button("Submit")
162
  if query_submitted:
163
+ # update question, context in session state
164
  st.session_state['question'] = question
165
  st.session_state['context'] = context
166
  with st.spinner('Generating response...'):
167
+ # Generate dataset from input example
168
  data_raw = Dataset.from_dict(
169
  {
170
  'id':[0],
 
172
  'question':[st.session_state['question']],
173
  }
174
  )
175
+ # Tokenize and preprocess dataset
176
  data_proc = data_raw.map(
177
  preprocess_examples,
178
  remove_columns = data_raw.column_names,
 
181
  'tokenizer':tokenizer,
182
  }
183
  )
184
+ # Make answer prediction with model
185
  predicted_answers = make_predictions(model, tokenizer,
186
  data_proc, data_raw,
187
  n_best = 20)
188
  answer = predicted_answers[0]['prediction_text']
189
  confidence = predicted_answers[0]['confidence']
190
+ # Update response in session state
191
  st.session_state['response'] = f"""
192
  Answer: {answer}\n
193
  Confidence: {confidence:.2%}
194
  """
195
+ # Display response
196
  with response_container:
197
  st.write(st.session_state['response'])
198