File size: 21,088 Bytes
0df5fcd
 
 
 
 
 
 
 
 
 
ddadb1a
 
 
 
 
 
0df5fcd
ddadb1a
 
 
0df5fcd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddadb1a
 
 
 
 
 
 
 
 
 
 
0df5fcd
 
 
 
ddadb1a
 
 
 
 
 
 
 
 
0df5fcd
ddadb1a
 
 
 
 
 
0df5fcd
ddadb1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0df5fcd
ddadb1a
 
 
0df5fcd
ddadb1a
 
2ea6ff9
0df5fcd
ddadb1a
 
 
 
0df5fcd
ddadb1a
0df5fcd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddadb1a
 
6a56215
ddadb1a
 
 
 
0df5fcd
ddadb1a
 
 
0df5fcd
ddadb1a
0df5fcd
ddadb1a
 
 
 
 
0df5fcd
ddadb1a
0df5fcd
ddadb1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0df5fcd
 
 
ddadb1a
 
 
0df5fcd
 
 
 
ddadb1a
0df5fcd
ddadb1a
 
0df5fcd
ddadb1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0df5fcd
ddadb1a
 
 
 
 
 
 
0df5fcd
ddadb1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0df5fcd
ddadb1a
 
0df5fcd
 
ddadb1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
import torch
import streamlit as st
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    pipeline,
)
import spacy
from lib.utils import (
    ContextRetriever,
    get_examples,
    generate_query,
    generate_answer,
)

#################################
### Model retrieval functions ###
#################################

@st.cache_resource(show_spinner=False)
def get_pipeline():
    """
    Build the extractive question-answering pipeline.

    Downloads the fine-tuned RoBERTa model and its tokenizer from
    the Hugging Face hub and wraps them in a transformers pipeline.
    The result is cached across Streamlit reruns by st.cache_resource.

    Parameters: None
    -----------
    Returns:
    --------
    qa_pipeline : transformers.QuestionAnsweringPipeline
        The question answering pipeline object; configured with
        handle_impossible_answer=True so it can report "no answer".
    """
    model_repo = 'etweedy/roberta-base-squad-v2'
    return pipeline(
        task='question-answering',
        model=model_repo,
        tokenizer=model_repo,
        handle_impossible_answer=True,
    )

@st.cache_resource(show_spinner=False)
def get_spacy():
    """
    Load the spaCy pipeline used to condense questions into queries.

    Only the tokenizer and part-of-speech tagger of 'en_core_web_sm'
    are kept; NER, dependency parser, and text categorizer are
    disabled to cut load time and memory.  Cached across Streamlit
    reruns by st.cache_resource.

    Parameters: None
    -----------
    Returns:
    --------
    nlp : spaCy.Pipe
        Portion of 'en_core_web_sm' model pipeline only containing
        tokenizer and part-of-speech tagger
    """
    unused_components = ['ner', 'parser', 'textcat']
    return spacy.load('en_core_web_sm', disable=unused_components)

#############
### Setup ###
#############
    
# Set mps or cuda device if available
# NOTE(review): `device` is never referenced again in this file — the
# pipeline is built without a device argument, so inference presumably
# runs on CPU.  Either pass `device` to the pipeline or drop this block;
# confirm intent before removing.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Initialize session state variables
# Each tab ('basic', 'semi', 'auto') gets its own dict of text fields,
# initialized to empty strings so widgets can read them before any input.
for tab in ['basic','semi','auto']:
    if tab not in st.session_state:
        st.session_state[tab] = {}
    for field in ['question','context','query','response']:
        if field not in st.session_state[tab]:
            st.session_state[tab][field] = ''
# The guided-Wiki tab additionally tracks Wikipedia search results and
# the user's page selections (lists, empty until a query is submitted).
for field in ['page_options','selected_pages']:
    if field not in st.session_state['semi']:
        st.session_state['semi'][field] = []
        
# Retrieve models (both functions are cached, so this is slow only once)
with st.spinner('Loading the model...'):
    qa_pipeline = get_pipeline()
    nlp = get_spacy()

# Retrieve example questions and contexts
# `examples` appears to be a sequence whose [0] element indexes example
# queries and [1] holds per-query example questions — TODO confirm
# against lib.utils.get_examples.
examples = get_examples()
# ex_queries, ex_questions, ex_contexts = get_examples()
if 'ex_questions' not in st.session_state['semi']:
    st.session_state['semi']['ex_questions'] = len(examples[1][0])*['']
    
################################
### Initialize App Structure ###
################################

# Top-level navigation: one tab per tool.  Containers are created up
# front (empty) and populated further down the file; this keeps the
# page layout declaration in one place.
tabs = st.tabs([
    'RoBERTa Q&A model',
    'Basic extractive Q&A',
    'User-guided Wiki Q&A',
    'Automated Wiki Q&A',
])

# Tab 0: introduction / model description
with tabs[0]:
    intro_container = st.container()
# Tab 1: basic Q&A — user supplies both context and question
with tabs[1]:
    basic_title_container = st.container()
    basic_example_container = st.container()
    basic_input_container = st.container()
    basic_response_container = st.container()
# Tab 2: user-guided Wiki Q&A — user picks the Wikipedia pages
with tabs[2]:
    semi_title_container = st.container()
    semi_query_container = st.container()
    semi_page_container = st.container()
    semi_input_container = st.container()
    semi_response_container = st.container()
# Tab 3: automated Wiki Q&A — pages retrieved from the question itself
with tabs[3]:
    auto_title_container = st.container()
    auto_example_container = st.container()
    auto_input_container = st.container()
    auto_response_container = st.container()

##############################
### Populate tab - Welcome ###
##############################

# Static introduction tab: header, expandable model background
# (including the RoBERTa citation), and a guide to the other tabs.
# All content is markdown; no widgets or state here.
with intro_container:
    # Intro text
    st.header('RoBERTa Q&A with Wiki tools')
    st.markdown('''
    This app demonstrates the answer-retrieval capabilities of a fine-tuned RoBERTa (Robustly optimized Bidirectional Encoder Representations from Transformers) model.
    ''')
    # Collapsed-by-default background on the model and its training data
    with st.expander('Click to read more about the model...'):
        st.markdown('''
* [Click here](https://huggingface.co/etweedy/roberta-base-squad-v2) to visit the Hugging Face model card for this fine-tuned model.
* To create this model, I fine-tuned the [RoBERTa base model](https://huggingface.co/roberta-base) Version 2 of [SQuAD (Stanford Question Answering Dataset)](https://huggingface.co/datasets/squad_v2), a dataset of context-question-answer triples.
* The objective of the model is "extractive question answering", the task of retrieving the answer to the question from a given context text corpus.
* SQuAD Version 2 incorporates the 100,000 samples from Version 1.1, along with 50,000 'unanswerable' questions, i.e. samples in the question cannot be answered using the context given.
* The original base RoBERTa model was introduced in [this paper](https://arxiv.org/abs/1907.11692) and [this repository](https://github.com/facebookresearch/fairseq/tree/main/examples/roberta).  Here's a citation for that base model:
```bibtex
@article{DBLP:journals/corr/abs-1907-11692,
  author    = {Yinhan Liu and
               Myle Ott and
               Naman Goyal and
               Jingfei Du and
               Mandar Joshi and
               Danqi Chen and
               Omer Levy and
               Mike Lewis and
               Luke Zettlemoyer and
               Veselin Stoyanov},
  title     = {RoBERTa: {A} Robustly Optimized {BERT} Pretraining Approach},
  journal   = {CoRR},
  volume    = {abs/1907.11692},
  year      = {2019},
  url       = {http://arxiv.org/abs/1907.11692},
  archivePrefix = {arXiv},
  eprint    = {1907.11692},
  timestamp = {Thu, 01 Aug 2019 08:59:33 +0200},
  biburl    = {https://dblp.org/rec/journals/corr/abs-1907-11692.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
    ''')
    # Overview of the three tool tabs
    st.markdown('''
Use the menu at the top to navigate between tabs containing several tools:
1. A basic Q&A tool which allows the user to ask the model to search a user-provided context paragraph for the answer to a user-provided question.
2. A user-guided Wiki Q&A tool which allows the user to search for one or more Wikipedia pages and ask the model to search those pages for the answer to a user-provided question.
3. An automated Wiki Q&A tool which asks the model to perform retrieve its own Wikipedia pages in order to answer the user-provided question.
    ''')

################################
### Populate tab - basic Q&A ###
################################

# Callbacks for the basic tab's example/clear buttons live in lib.utils.
from lib.utils import basic_clear_boxes, basic_ex_click

# Static header and instructions for the basic extractive Q&A tab.
with basic_title_container:
    ### Intro text ###
    st.header('Basic extractive Q&A')
    st.markdown('''
The basic functionality of a RoBERTa model for extractive question-answering is to attempt to extract the answer to a user-provided question from a piece of user-provided context text.  The model is also trained to recognize when the context doesn't provide the answer.

Please type or paste a context paragraph and question you'd like to ask about it.  The model will attempt to answer the question based on the context you provided, or report that it cannot find the answer in the context.  Your results will appear below the question field when the model is finished running.

Alternatively, you can try an example by clicking one of the buttons below:
    ''')
    
### Populate example button container ###
# One column per example button, plus a final column for "Clear all".
with basic_example_container:
    n_examples = len(examples[0])
    basic_ex_cols = st.columns(n_examples + 1)
    # Example buttons: clicking one fills the context/question fields
    # via the basic_ex_click callback.
    for idx, col in enumerate(basic_ex_cols[:n_examples]):
        with col:
            st.button(
                label=f'example {idx+1}',
                key=f'basic_ex_button_{idx+1}',
                on_click=basic_ex_click,
                args=(examples, idx),
            )
    # Last column: reset every field on the tab.
    with basic_ex_cols[-1]:
        st.button(
            label="Clear all fields",
            key="basic_clear_button",
            on_click=basic_clear_boxes,
        )
### Populate user input container ###
# Form with a context text area and a question field; on submit the
# pair is fed straight to the QA pipeline and the formatted result is
# stored in session state for the response container below.
with basic_input_container:
    with st.form(key='basic_input_form',clear_on_submit=False):
        # Context input field (pre-filled from session state so example
        # buttons and reruns preserve the text)
        context = st.text_area(
            label='Context',
            value=st.session_state['basic']['context'],
            key='basic_context_field',
            label_visibility='collapsed',
            placeholder='Enter your context paragraph here.',
            height=300,
        )
        # Question input field
        question = st.text_input(
            label='Question',
            value=st.session_state['basic']['question'],
            key='basic_question_field',
            label_visibility='collapsed',
            placeholder='Enter your question here.',
        )
        # Form submit button; ignore submissions with an empty question
        query_submitted = st.form_submit_button("Submit")
        if query_submitted and question!= '':
            # update question, context in session state
            st.session_state['basic']['question'] = question
            st.session_state['basic']['context'] = context
            with st.spinner('Generating response...'):
                # Generate dictionary from inputs
                query = {
                    'context':st.session_state['basic']['context'],
                    'question':st.session_state['basic']['question'],
                }
                # Pass to QA pipeline
                response = qa_pipeline(**query)
                answer = response['answer']
                confidence = response['score']
                # Reformat empty answer to message
                # (handle_impossible_answer=True makes the pipeline
                # return '' when it judges the question unanswerable)
                if answer == '':
                    answer = "I don't have an answer based on the context provided."
                # Update response in session state
                st.session_state['basic']['response'] = f"""
                    Answer: {answer}\n
                    Confidence: {confidence:.2%}
                """
### Populate response container ###
# Renders the latest stored response (empty string before first run).
with basic_response_container:
    st.write(st.session_state['basic']['response'])
            
#################################
### Populate tab - guided Q&A ###
#################################

# Callbacks for the guided tab's example/clear buttons live in lib.utils.
from lib.utils import (
    semi_ex_query_click,
    semi_ex_question_click,
    semi_clear_query,
    semi_clear_question,
)
    
### Intro text ###
# Static header plus an expander describing the retrieval pipeline
# (Wiki search -> paragraph split -> BM25 ranking -> RoBERTa).
with semi_title_container:
    st.header('User-guided Wiki Q&A')
    st.markdown('''
This component allows you to perform a Wikipedia search for source material to feed as contexts to the RoBERTa question-answering model.
    ''')
    with st.expander("Click here to find out what's happening behind the scenes..."):
        st.markdown('''
    1. A Wikipedia search is performed using your query, resulting in a list of pages which then populate the drop-down menu.
    2. The pages you select are retrieved and broken up into paragraphs.  Wikipedia queries and page collection use the [wikipedia library](https://pypi.org/project/wikipedia/), a wrapper for the [MediaWiki API](https://www.mediawiki.org/wiki/API).
    3. The paragraphs are ranked in descending order of relevance to your question, using the [Okapi BM25 score](https://en.wikipedia.org/wiki/Okapi_BM25) as implemented in the [rank_bm25 library](https://github.com/dorianbrown/rank_bm25).
    4. Among these ranked paragraphs, approximately the top 25% are fed as context to the RoBERTa model, from which it will attempt to extract the answer to your question.  The 'hit' having the highest confidence (prediction probability) from the model is reported as the answer.
        ''')

### Populate query container ###
# Step 1 of the guided flow: collect a Wikipedia search query (typed or
# from an example button) and store the resulting page ids in session
# state so the page-selection form below can offer them.
with semi_query_container:
    st.markdown('First submit a search query, or choose one of the examples.')
    semi_query_cols = st.columns(len(examples[0])+1)
    # Buttons for query examples
    for i in range(len(examples[0])):
        with semi_query_cols[i]:
            st.button(
                label = f'query {i+1}',
                key = f'semi_query_button_{i+1}',
                on_click = semi_ex_query_click,
                args=(examples,i,),
            )
    # Button for clearing query field
    with semi_query_cols[-1]:
        st.button(
            label = "Clear query",
            key = "semi_clear_query",
            on_click = semi_clear_query,
        )
    # Search query input form
    with st.form(key='semi_query_form',clear_on_submit=False):
        query = st.text_input(
            label='Search query',
            value=st.session_state['semi']['query'],
            key='semi_query_field',
            label_visibility='collapsed',
            placeholder='Enter your Wikipedia search query here.',
        )
        query_submitted = st.form_submit_button("Submit")

        if query_submitted and query != '':
            st.session_state['semi']['query'] = query
            # Retrieve Wikipedia page list from
            # search results and store in session state
            # (a new query also resets any previous page selections)
            with st.spinner('Retrieving Wiki pages...'):
                retriever = ContextRetriever()
                retriever.get_pageids(query)
                st.session_state['semi']['page_options'] = retriever.pageids
                st.session_state['semi']['selected_pages'] = []
    
### Populate page selection container ###
# Step 2: let the user choose which of the retrieved pages to feed to
# the model.  The submitted selection is persisted in session state.
with semi_page_container:
    st.markdown('Next select any number of Wikipedia pages to provide to RoBERTa:')
    # Page title selection form
    with st.form(key='semi_page_form',clear_on_submit=False):
        # Options appear to be (pageid, title) pairs — format_func shows
        # the title (x[1]) while the full pair is kept as the value.
        # TODO confirm against ContextRetriever.get_pageids.
        selected_pages = st.multiselect(
                label = "Choose Wiki pages for Q&A model:",
                options = st.session_state['semi']['page_options'],
                default = st.session_state['semi']['selected_pages'],
                label_visibility = 'collapsed',
                key = "semi_page_selectbox",
                format_func = lambda x:x[1],
            )
        pages_submitted = st.form_submit_button("Submit")
        if pages_submitted:
            st.session_state['semi']['selected_pages'] = selected_pages

### Populate question input container ###
# Step 3: collect the question, retrieve the selected pages, rank their
# paragraphs against the question, and run the QA pipeline.
with semi_input_container:
    st.markdown('Finally submit a question for RoBERTa to answer based on the above pages or choose one of the examples.')
    # Question example buttons
    semi_ques_cols = st.columns(len(examples[0])+1)
    for i in range(len(examples[0])):
        with semi_ques_cols[i]:
            st.button(
                label = f'question {i+1}',
                key = f'semi_ques_button_{i+1}',
                on_click = semi_ex_question_click,
                args=(i,),
            )
    # Question field clear button
    with semi_ques_cols[-1]:
        st.button(
            label = "Clear question",
            key = "semi_clear_question",
            on_click = semi_clear_question,
        )
    # Question submission form
    with st.form(key = "semi_question_form",clear_on_submit=False):
        question = st.text_input(
            label='Question',
            value=st.session_state['semi']['question'],
            key='semi_question_field',
            label_visibility='collapsed',
            placeholder='Enter your question here.',
        )
        question_submitted = st.form_submit_button("Submit")
        # Require both a non-empty question and at least one page that
        # was actually submitted through the page-selection form.
        if question_submitted and len(question)>0 and len(st.session_state['semi']['selected_pages'])>0:
            st.session_state['semi']['response'] = ''
            st.session_state['semi']['question'] = question
            # Retrieve pages corresponding to user selections,
            # extract paragraphs, and retrieve top paragraphs,
            # ranked by relevance to user question.
            # FIX: read the page list from session state instead of the
            # `selected_pages` widget variable from the other form above;
            # the guard checks session state, so using the widget value
            # could disagree with it (e.g. pages picked but never
            # submitted would be used despite not being saved).
            with st.spinner("Retrieving documentation..."):
                retriever = ContextRetriever()
                pages = retriever.ids_to_pages(st.session_state['semi']['selected_pages'])
                paragraphs = retriever.pages_to_paragraphs(pages)
                best_paragraphs = retriever.rank_paragraphs(
                    paragraphs, question,
                )
            with st.spinner("Generating response..."):
                # Generate a response and update the session state
                response = generate_answer(
                    pipeline = qa_pipeline,
                    paragraphs = best_paragraphs,
                    question = st.session_state['semi']['question'],
                )
                st.session_state['semi']['response'] = response

### Populate response container ###
# Renders the latest stored response (empty string before first run).
with semi_response_container:
    st.write(st.session_state['semi']['response'])

####################################
### Populate tab - automated Q&A ###
####################################

# Callbacks for the automated tab's example/clear buttons live in lib.utils.
from lib.utils import auto_ex_click, auto_clear_boxes

### Intro text ###
# Static header plus an expander describing the automated pipeline
# (question -> spaCy query -> Wiki search -> BM25 ranking -> RoBERTa).
with auto_title_container:
    st.header('Automated Wiki Q&A')
    st.markdown('''
This component attempts to automate the Wiki-assisted extractive question-answering task.  A Wikipedia search will be performed based on your question, and a list of relevant paragraphs will be passed to the RoBERTa model so it can attempt to find an answer.
    ''')
    with st.expander("Click here to find out what's happening behind the scenes..."):
        # FIX: the list below was numbered 1, 2, 4, 5 — renumbered to a
        # consecutive 1–4 (no step was missing; only the labels skipped 3).
        st.markdown('''
    When you submit a question, the following steps are performed:
    1. Your question is condensed into a search query which just retains nouns, verbs, numerals, and adjectives, where part-of-speech tagging is done using the [en_core_web_sm](https://spacy.io/models/en#en_core_web_sm) pipeline in the [spaCy library](https://spacy.io/).
    2. A Wikipedia search is performed using this query, resulting in several articles.  The articles from the top 3 search results are collected and split into paragraphs.  Wikipedia queries and article collection use the [wikipedia library](https://pypi.org/project/wikipedia/), a wrapper for the [MediaWiki API](https://www.mediawiki.org/wiki/API).
    3. The paragraphs are ranked in descending order of relevance to the query, using the [Okapi BM25 score](https://en.wikipedia.org/wiki/Okapi_BM25) as implemented in the [rank_bm25 library](https://github.com/dorianbrown/rank_bm25).
    4. The ten most relevant paragraphs are fed as context to the RoBERTa model, from which it will attempt to extract the answer to your question.  The 'hit' having the highest confidence (prediction probability) from the model is reported as the answer.
        ''')

    st.markdown('''
Please provide a question you'd like the model to try to answer.  The model will report back its answer, as well as an excerpt of text from Wikipedia in which it found its answer.  Your result will appear below the question field when the model is finished running.

Alternatively, you can try an example by clicking one of the buttons below:
    ''')

### Populate example container ###
# One column per example button, plus a final column for "Clear all".
with auto_example_container:
    n_examples = len(examples[0])
    auto_ex_cols = st.columns(n_examples + 1)
    # Example buttons: clicking one fills the question field via the
    # auto_ex_click callback.
    for idx, col in enumerate(auto_ex_cols[:n_examples]):
        with col:
            st.button(
                label=f'example {idx+1}',
                key=f'auto_ex_button_{idx+1}',
                on_click=auto_ex_click,
                args=(examples, idx),
            )
    # Last column: clear the question field and any stored response.
    with auto_ex_cols[-1]:
        st.button(
            label="Clear all fields",
            key="auto_clear_button",
            on_click=auto_clear_boxes,
        )

### Populate user input container ###
# Single-field form: the question is condensed into a search query by
# spaCy, pages are retrieved automatically, and the QA pipeline runs on
# the top-ranked paragraphs.
with auto_input_container:
    with st.form(key='auto_input_form',clear_on_submit=False):
        # Question input field
        question = st.text_input(
            label='Question',
            value=st.session_state['auto']['question'],
            key='auto_question_field',
            label_visibility='collapsed',
            placeholder='Enter your question here.',
        )
        # Form submit button
        question_submitted = st.form_submit_button("Submit")
        if question_submitted:
            # update question, context in session state
            st.session_state['auto']['question'] = question
            # Strip the question down to content words (nouns, verbs,
            # numerals, adjectives) for the Wikipedia search.
            query = generate_query(nlp,question)
            # query == '' will throw error in document retrieval
            if len(query)==0:
                st.session_state['auto']['response'] = 'Please include some nouns, verbs, and/or adjectives in your question.'
            elif len(question)>0:
                with st.spinner('Retrieving documentation...'):
                    # Retrieve ids from top 3 results
                    retriever = ContextRetriever()
                    retriever.get_pageids(query,topn=3)
                    # Retrieve pages then paragraphs
                    retriever.get_all_pages()
                    retriever.get_all_paragraphs()
                    # Get top paragraphs, ranked by relevance to query
                    # NOTE(review): ranking here uses `query`, while the
                    # guided tab ranks by the full question — confirm
                    # this asymmetry is intentional.
                    best_paragraphs = retriever.rank_paragraphs(retriever.paragraphs, query)
                with st.spinner('Generating response...'):
                    # Generate a response and update the session state
                    response = generate_answer(
                        pipeline = qa_pipeline,
                        paragraphs = best_paragraphs,
                        question = st.session_state['auto']['question'],
                    )
                    st.session_state['auto']['response'] = response
### Populate response container ###
# Renders the latest stored response (empty string before first run).
with auto_response_container:
    st.write(st.session_state['auto']['response'])