Ramon Meffert commited on
Commit
ab5dfc2
·
1 Parent(s): fa8dc75

Add reader

Browse files
main.py CHANGED
@@ -1,12 +1,21 @@
1
  from datasets import DatasetDict, load_dataset
2
 
3
- from src.retrievers.fais_retriever import FAISRetriever
 
4
  from src.utils.log import get_logger
5
- from src.evaluation import evaluate
6
  from typing import cast
7
 
8
- logger = get_logger()
 
 
 
 
9
 
 
 
 
 
10
 
11
  if __name__ == '__main__':
12
  dataset_name = "GroNLP/ik-nlp-22_slp"
@@ -15,24 +24,44 @@ if __name__ == '__main__':
15
 
16
  questions_test = questions["test"]
17
 
18
- logger.info(questions)
19
 
20
  # Initialize retriever
21
- r = FAISRetriever()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- # # Retrieve example
24
- example_q = "What is the perplexity of a language model?"
25
- scores, result = r.retrieve(example_q)
 
 
 
 
 
26
 
27
- logger.info(
28
- f"Example q: {example_q} answer: {result['text'][0]}")
29
 
30
- for i, score in enumerate(scores):
31
- logger.info(f"Result {i+1} (score: {score:.02f}):")
32
- logger.info(result['text'][i])
33
 
34
- # Compute overall performance
35
- exact_match, f1_score = evaluate(
36
- r, questions_test["question"], questions_test["answer"])
37
- logger.info(f"Exact match: {exact_match:.02f}\n"
38
- f"F1-score: {f1_score:.02f}")
 
1
  from datasets import DatasetDict, load_dataset
2
 
3
+ from src.readers.dpr_reader import DprReader
4
+ from src.retrievers.faiss_retriever import FaissRetriever
5
  from src.utils.log import get_logger
6
+ # from src.evaluation import evaluate
7
  from typing import cast
8
 
9
+ from src.utils.preprocessing import result_to_reader_input
10
+
11
+ import torch
12
+ import transformers
13
+ import os
14
 
15
+ os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
16
+
17
+ logger = get_logger()
18
+ transformers.logging.set_verbosity_error()
19
 
20
  if __name__ == '__main__':
21
  dataset_name = "GroNLP/ik-nlp-22_slp"
 
24
 
25
  questions_test = questions["test"]
26
 
27
+ # logger.info(questions)
28
 
29
  # Initialize retriever
30
+ retriever = FaissRetriever()
31
+
32
+ # Retrieve example
33
+ example_q = questions_test.shuffle()["question"][0]
34
+ scores, result = retriever.retrieve(example_q)
35
+
36
+ reader_input = result_to_reader_input(result)
37
+
38
+ # Initialize reader
39
+ reader = DprReader()
40
+ answers = reader.read(example_q, reader_input)
41
+
42
+ # Calculate softmaxed scores for readable output
43
+ sm = torch.nn.Softmax(dim=0)
44
+ document_scores = sm(torch.Tensor(
45
+ [pred.relevance_score for pred in answers]))
46
+ span_scores = sm(torch.Tensor(
47
+ [pred.span_score for pred in answers]))
48
 
49
+ print(example_q)
50
+ for answer_i, answer in enumerate(answers):
51
+ print(f"[{answer_i + 1}]: {answer.text}")
52
+ print(f"\tDocument {answer.doc_id}", end='')
53
+ print(f"\t(score {document_scores[answer_i] * 100:.02f})")
54
+ print(f"\tSpan {answer.start_index}-{answer.end_index}", end='')
55
+ print(f"\t(score {span_scores[answer_i] * 100:.02f})")
56
+ print() # Newline
57
 
58
+ # print(f"Example q: {example_q} answer: {result['text'][0]}")
 
59
 
60
+ # for i, score in enumerate(scores):
61
+ # print(f"Result {i+1} (score: {score:.02f}):")
62
+ # print(result['text'][i])
63
 
64
+ # # Compute overall performance
65
+ # exact_match, f1_score = evaluate(
66
+ # r, questions_test["question"], questions_test["answer"])
67
+ # print(f"Exact match: {exact_match:.02f}\n", f"F1-score: {f1_score:.02f}")
 
src/readers/dpr_reader.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import DPRReader, DPRReaderTokenizer
2
+ from typing import List, Dict, Tuple
3
+
4
+
5
+ class DprReader():
6
+ def __init__(self) -> None:
7
+ self._tokenizer = DPRReaderTokenizer.from_pretrained(
8
+ "facebook/dpr-reader-single-nq-base")
9
+ self._model = DPRReader.from_pretrained(
10
+ "facebook/dpr-reader-single-nq-base"
11
+ )
12
+
13
+ def read(self, query: str, context: Dict[str, List[str]]) -> List[Tuple]:
14
+ encoded_inputs = self._tokenizer(
15
+ questions=query,
16
+ titles=context['titles'],
17
+ texts=context['texts'],
18
+ return_tensors='pt',
19
+ truncation=True,
20
+ padding=True
21
+ )
22
+ outputs = self._model(**encoded_inputs)
23
+
24
+ predicted_spans = self._tokenizer.decode_best_spans(
25
+ encoded_inputs, outputs)
26
+
27
+ return predicted_spans
src/retrievers/{fais_retriever.py → faiss_retriever.py} RENAMED
@@ -13,15 +13,15 @@ from transformers import (
13
  from src.retrievers.base_retriever import Retriever
14
  from src.utils.log import get_logger
15
 
16
- os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
17
  # Hacky fix for FAISS error on macOS
18
  # See https://stackoverflow.com/a/63374568/4545692
 
19
 
20
 
21
  logger = get_logger()
22
 
23
 
24
- class FAISRetriever(Retriever):
25
  """A class used to retrieve relevant documents based on some query.
26
  based on https://huggingface.co/docs/datasets/faiss_es#faiss.
27
  """
@@ -56,14 +56,16 @@ class FAISRetriever(Retriever):
56
  self.dataset_name = dataset_name
57
  self.dataset = self._init_dataset(dataset_name)
58
 
59
- def _init_dataset(self,
60
- dataset_name: str,
61
- embedding_path: str = "./models/paragraphs_embedding.faiss"):
 
 
62
  """Loads the dataset and adds FAISS embeddings.
63
 
64
  Args:
65
  dataset (str): A HuggingFace dataset name.
66
- fname (str): The name to use to save the embeddings to disk for
67
  faster loading after the first run.
68
 
69
  Returns:
@@ -73,9 +75,8 @@ class FAISRetriever(Retriever):
73
  # Load dataset
74
  ds = load_dataset(dataset_name, name="paragraphs")[
75
  "train"] # type: ignore
76
- logger.info(ds)
77
 
78
- if os.path.exists(embedding_path):
79
  # If we already have FAISS embeddings, load them from disk
80
  ds.load_faiss_index('embeddings', embedding_path) # type: ignore
81
  return ds
@@ -95,7 +96,7 @@ class FAISRetriever(Retriever):
95
  ds_with_embeddings.add_faiss_index(column="embeddings")
96
 
97
  # save dataset w/ embeddings
98
- os.makedirs("./models/", exist_ok=True)
99
  ds_with_embeddings.save_faiss_index("embeddings", embedding_path)
100
 
101
  return ds_with_embeddings
 
13
  from src.retrievers.base_retriever import Retriever
14
  from src.utils.log import get_logger
15
 
 
16
  # Hacky fix for FAISS error on macOS
17
  # See https://stackoverflow.com/a/63374568/4545692
18
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
19
 
20
 
21
  logger = get_logger()
22
 
23
 
24
+ class FaissRetriever(Retriever):
25
  """A class used to retrieve relevant documents based on some query.
26
  based on https://huggingface.co/docs/datasets/faiss_es#faiss.
27
  """
 
56
  self.dataset_name = dataset_name
57
  self.dataset = self._init_dataset(dataset_name)
58
 
59
+ def _init_dataset(
60
+ self,
61
+ dataset_name: str,
62
+ embedding_path: str = "./src/models/paragraphs_embedding.faiss",
63
+ force_new_embedding: bool = False):
64
  """Loads the dataset and adds FAISS embeddings.
65
 
66
  Args:
67
  dataset (str): A HuggingFace dataset name.
68
+ fname (str): The name to use to save the embeddings to disk for
69
  faster loading after the first run.
70
 
71
  Returns:
 
75
  # Load dataset
76
  ds = load_dataset(dataset_name, name="paragraphs")[
77
  "train"] # type: ignore
 
78
 
79
+ if not force_new_embedding and os.path.exists(embedding_path):
80
  # If we already have FAISS embeddings, load them from disk
81
  ds.load_faiss_index('embeddings', embedding_path) # type: ignore
82
  return ds
 
96
  ds_with_embeddings.add_faiss_index(column="embeddings")
97
 
98
  # save dataset w/ embeddings
99
+ os.makedirs("./src/models/", exist_ok=True)
100
  ds_with_embeddings.save_faiss_index("embeddings", embedding_path)
101
 
102
  return ds_with_embeddings
src/utils/preprocessing.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+
3
+
4
+ def result_to_reader_input(result: Dict[str, List[str]]) \
5
+ -> Dict[str, List[str]]:
6
+ """Takes the output of the retriever and turns it into a format the reader
7
+ understands.
8
+
9
+ Args:
10
+ result (Dict[str, List[str]]): The result from the retriever
11
+ """
12
+
13
+ # Take the number of valeus of an arbitrary item as the number of entries
14
+ # (This should always be valid)
15
+ num_entries = len(result['n_chapter'])
16
+
17
+ # Prepare result
18
+ reader_result = {
19
+ 'titles': [],
20
+ 'texts': []
21
+ }
22
+
23
+ for n in range(num_entries):
24
+ # Get the most specific title
25
+ if result['subsection'][n] != 'nan':
26
+ title = result['subsection'][n]
27
+ elif result['section'][n] != 'nan':
28
+ title = result['section'][n]
29
+ else:
30
+ title = result['chapter'][n]
31
+
32
+ reader_result['titles'].append(title)
33
+ reader_result['texts'].append(result['text'][n])
34
+
35
+ return reader_result