Ramon Meffert
commited on
Commit
·
ab5dfc2
1
Parent(s):
fa8dc75
Add reader
Browse files- main.py +47 -18
- src/readers/dpr_reader.py +27 -0
- src/retrievers/{fais_retriever.py → faiss_retriever.py} +10 -9
- src/utils/preprocessing.py +35 -0
main.py
CHANGED
@@ -1,12 +1,21 @@
|
|
1 |
from datasets import DatasetDict, load_dataset
|
2 |
|
3 |
-
from src.
|
|
|
4 |
from src.utils.log import get_logger
|
5 |
-
from src.evaluation import evaluate
|
6 |
from typing import cast
|
7 |
|
8 |
-
|
|
|
|
|
|
|
|
|
9 |
|
|
|
|
|
|
|
|
|
10 |
|
11 |
if __name__ == '__main__':
|
12 |
dataset_name = "GroNLP/ik-nlp-22_slp"
|
@@ -15,24 +24,44 @@ if __name__ == '__main__':
|
|
15 |
|
16 |
questions_test = questions["test"]
|
17 |
|
18 |
-
logger.info(questions)
|
19 |
|
20 |
# Initialize retriever
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
-
|
28 |
-
f"Example q: {example_q} answer: {result['text'][0]}")
|
29 |
|
30 |
-
for i, score in enumerate(scores):
|
31 |
-
|
32 |
-
|
33 |
|
34 |
-
# Compute overall performance
|
35 |
-
exact_match, f1_score = evaluate(
|
36 |
-
|
37 |
-
|
38 |
-
f"F1-score: {f1_score:.02f}")
|
|
|
1 |
from datasets import DatasetDict, load_dataset
|
2 |
|
3 |
+
from src.readers.dpr_reader import DprReader
|
4 |
+
from src.retrievers.faiss_retriever import FaissRetriever
|
5 |
from src.utils.log import get_logger
|
6 |
+
# from src.evaluation import evaluate
|
7 |
from typing import cast
|
8 |
|
9 |
+
from src.utils.preprocessing import result_to_reader_input
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import transformers
|
13 |
+
import os
|
14 |
|
15 |
+
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
|
16 |
+
|
17 |
+
logger = get_logger()
|
18 |
+
transformers.logging.set_verbosity_error()
|
19 |
|
20 |
if __name__ == '__main__':
|
21 |
dataset_name = "GroNLP/ik-nlp-22_slp"
|
|
|
24 |
|
25 |
questions_test = questions["test"]
|
26 |
|
27 |
+
# logger.info(questions)
|
28 |
|
29 |
# Initialize retriever
|
30 |
+
retriever = FaissRetriever()
|
31 |
+
|
32 |
+
# Retrieve example
|
33 |
+
example_q = questions_test.shuffle()["question"][0]
|
34 |
+
scores, result = retriever.retrieve(example_q)
|
35 |
+
|
36 |
+
reader_input = result_to_reader_input(result)
|
37 |
+
|
38 |
+
# Initialize reader
|
39 |
+
reader = DprReader()
|
40 |
+
answers = reader.read(example_q, reader_input)
|
41 |
+
|
42 |
+
# Calculate softmaxed scores for readable output
|
43 |
+
sm = torch.nn.Softmax(dim=0)
|
44 |
+
document_scores = sm(torch.Tensor(
|
45 |
+
[pred.relevance_score for pred in answers]))
|
46 |
+
span_scores = sm(torch.Tensor(
|
47 |
+
[pred.span_score for pred in answers]))
|
48 |
|
49 |
+
print(example_q)
|
50 |
+
for answer_i, answer in enumerate(answers):
|
51 |
+
print(f"[{answer_i + 1}]: {answer.text}")
|
52 |
+
print(f"\tDocument {answer.doc_id}", end='')
|
53 |
+
print(f"\t(score {document_scores[answer_i] * 100:.02f})")
|
54 |
+
print(f"\tSpan {answer.start_index}-{answer.end_index}", end='')
|
55 |
+
print(f"\t(score {span_scores[answer_i] * 100:.02f})")
|
56 |
+
print() # Newline
|
57 |
|
58 |
+
# print(f"Example q: {example_q} answer: {result['text'][0]}")
|
|
|
59 |
|
60 |
+
# for i, score in enumerate(scores):
|
61 |
+
# print(f"Result {i+1} (score: {score:.02f}):")
|
62 |
+
# print(result['text'][i])
|
63 |
|
64 |
+
# # Compute overall performance
|
65 |
+
# exact_match, f1_score = evaluate(
|
66 |
+
# r, questions_test["question"], questions_test["answer"])
|
67 |
+
# print(f"Exact match: {exact_match:.02f}\n", f"F1-score: {f1_score:.02f}")
|
|
src/readers/dpr_reader.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import DPRReader, DPRReaderTokenizer
|
2 |
+
from typing import List, Dict, Tuple
|
3 |
+
|
4 |
+
|
5 |
+
class DprReader():
|
6 |
+
def __init__(self) -> None:
|
7 |
+
self._tokenizer = DPRReaderTokenizer.from_pretrained(
|
8 |
+
"facebook/dpr-reader-single-nq-base")
|
9 |
+
self._model = DPRReader.from_pretrained(
|
10 |
+
"facebook/dpr-reader-single-nq-base"
|
11 |
+
)
|
12 |
+
|
13 |
+
def read(self, query: str, context: Dict[str, List[str]]) -> List[Tuple]:
|
14 |
+
encoded_inputs = self._tokenizer(
|
15 |
+
questions=query,
|
16 |
+
titles=context['titles'],
|
17 |
+
texts=context['texts'],
|
18 |
+
return_tensors='pt',
|
19 |
+
truncation=True,
|
20 |
+
padding=True
|
21 |
+
)
|
22 |
+
outputs = self._model(**encoded_inputs)
|
23 |
+
|
24 |
+
predicted_spans = self._tokenizer.decode_best_spans(
|
25 |
+
encoded_inputs, outputs)
|
26 |
+
|
27 |
+
return predicted_spans
|
src/retrievers/{fais_retriever.py → faiss_retriever.py}
RENAMED
@@ -13,15 +13,15 @@ from transformers import (
|
|
13 |
from src.retrievers.base_retriever import Retriever
|
14 |
from src.utils.log import get_logger
|
15 |
|
16 |
-
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
|
17 |
# Hacky fix for FAISS error on macOS
|
18 |
# See https://stackoverflow.com/a/63374568/4545692
|
|
|
19 |
|
20 |
|
21 |
logger = get_logger()
|
22 |
|
23 |
|
24 |
-
class
|
25 |
"""A class used to retrieve relevant documents based on some query.
|
26 |
based on https://huggingface.co/docs/datasets/faiss_es#faiss.
|
27 |
"""
|
@@ -56,14 +56,16 @@ class FAISRetriever(Retriever):
|
|
56 |
self.dataset_name = dataset_name
|
57 |
self.dataset = self._init_dataset(dataset_name)
|
58 |
|
59 |
-
def _init_dataset(
|
60 |
-
|
61 |
-
|
|
|
|
|
62 |
"""Loads the dataset and adds FAISS embeddings.
|
63 |
|
64 |
Args:
|
65 |
dataset (str): A HuggingFace dataset name.
|
66 |
-
fname (str): The name to use to save the embeddings to disk for
|
67 |
faster loading after the first run.
|
68 |
|
69 |
Returns:
|
@@ -73,9 +75,8 @@ class FAISRetriever(Retriever):
|
|
73 |
# Load dataset
|
74 |
ds = load_dataset(dataset_name, name="paragraphs")[
|
75 |
"train"] # type: ignore
|
76 |
-
logger.info(ds)
|
77 |
|
78 |
-
if os.path.exists(embedding_path):
|
79 |
# If we already have FAISS embeddings, load them from disk
|
80 |
ds.load_faiss_index('embeddings', embedding_path) # type: ignore
|
81 |
return ds
|
@@ -95,7 +96,7 @@ class FAISRetriever(Retriever):
|
|
95 |
ds_with_embeddings.add_faiss_index(column="embeddings")
|
96 |
|
97 |
# save dataset w/ embeddings
|
98 |
-
os.makedirs("./models/", exist_ok=True)
|
99 |
ds_with_embeddings.save_faiss_index("embeddings", embedding_path)
|
100 |
|
101 |
return ds_with_embeddings
|
|
|
13 |
from src.retrievers.base_retriever import Retriever
|
14 |
from src.utils.log import get_logger
|
15 |
|
|
|
16 |
# Hacky fix for FAISS error on macOS
|
17 |
# See https://stackoverflow.com/a/63374568/4545692
|
18 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
|
19 |
|
20 |
|
21 |
logger = get_logger()
|
22 |
|
23 |
|
24 |
+
class FaissRetriever(Retriever):
|
25 |
"""A class used to retrieve relevant documents based on some query.
|
26 |
based on https://huggingface.co/docs/datasets/faiss_es#faiss.
|
27 |
"""
|
|
|
56 |
self.dataset_name = dataset_name
|
57 |
self.dataset = self._init_dataset(dataset_name)
|
58 |
|
59 |
+
def _init_dataset(
|
60 |
+
self,
|
61 |
+
dataset_name: str,
|
62 |
+
embedding_path: str = "./src/models/paragraphs_embedding.faiss",
|
63 |
+
force_new_embedding: bool = False):
|
64 |
"""Loads the dataset and adds FAISS embeddings.
|
65 |
|
66 |
Args:
|
67 |
dataset (str): A HuggingFace dataset name.
|
68 |
+
fname (str): The name to use to save the embeddings to disk for
|
69 |
faster loading after the first run.
|
70 |
|
71 |
Returns:
|
|
|
75 |
# Load dataset
|
76 |
ds = load_dataset(dataset_name, name="paragraphs")[
|
77 |
"train"] # type: ignore
|
|
|
78 |
|
79 |
+
if not force_new_embedding and os.path.exists(embedding_path):
|
80 |
# If we already have FAISS embeddings, load them from disk
|
81 |
ds.load_faiss_index('embeddings', embedding_path) # type: ignore
|
82 |
return ds
|
|
|
96 |
ds_with_embeddings.add_faiss_index(column="embeddings")
|
97 |
|
98 |
# save dataset w/ embeddings
|
99 |
+
os.makedirs("./src/models/", exist_ok=True)
|
100 |
ds_with_embeddings.save_faiss_index("embeddings", embedding_path)
|
101 |
|
102 |
return ds_with_embeddings
|
src/utils/preprocessing.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List
|
2 |
+
|
3 |
+
|
4 |
+
def result_to_reader_input(result: Dict[str, List[str]]) \
|
5 |
+
-> Dict[str, List[str]]:
|
6 |
+
"""Takes the output of the retriever and turns it into a format the reader
|
7 |
+
understands.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
result (Dict[str, List[str]]): The result from the retriever
|
11 |
+
"""
|
12 |
+
|
13 |
+
# Take the number of valeus of an arbitrary item as the number of entries
|
14 |
+
# (This should always be valid)
|
15 |
+
num_entries = len(result['n_chapter'])
|
16 |
+
|
17 |
+
# Prepare result
|
18 |
+
reader_result = {
|
19 |
+
'titles': [],
|
20 |
+
'texts': []
|
21 |
+
}
|
22 |
+
|
23 |
+
for n in range(num_entries):
|
24 |
+
# Get the most specific title
|
25 |
+
if result['subsection'][n] != 'nan':
|
26 |
+
title = result['subsection'][n]
|
27 |
+
elif result['section'][n] != 'nan':
|
28 |
+
title = result['section'][n]
|
29 |
+
else:
|
30 |
+
title = result['chapter'][n]
|
31 |
+
|
32 |
+
reader_result['titles'].append(title)
|
33 |
+
reader_result['texts'].append(result['text'][n])
|
34 |
+
|
35 |
+
return reader_result
|