Commit be1f224 · Parent(s): b06298d
Ramon Meffert committed: Add longformer
Files changed:

- .gitattributes (+2, -0)
- README.md (+7, -2)
- query.py (+61, -18)
- src/models/{paragraphs_embedding.faiss → dpr.faiss} (+1, -1)
- src/models/longformer.faiss (+3, -0)
- src/readers/base_reader.py (+9, -0)
- src/readers/dpr_reader.py (+3, -1)
- src/readers/longformer_reader.py (+41, -0)
- src/retrievers/faiss_retriever.py (+89, -33)
.gitattributes CHANGED

```diff
@@ -28,3 +28,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+src/models/dpr.faiss filter=lfs diff=lfs merge=lfs -text
+src/models/longformer.faiss filter=lfs diff=lfs merge=lfs -text
```
README.md CHANGED

````diff
@@ -75,7 +75,10 @@ By default, the best answer along with its location in the book will be
 returned. If you want to generate more answers (say, a top-5), you can supply
 the `--top=5` option. The default retriever uses [FAISS](https://faiss.ai/), but
 you can also use [ElasticSearch](https://www.elastic.co/elastic-stack/) using
-the `--retriever=es` option.
+the `--retriever=es` option. You can also pick a language model using the
+`--lm` option, which accepts either `dpr` (Dense Passage Retrieval) or
+`longformer`. The language model is used to generate embeddings for FAISS, and
+is used to generate the answer.
 
 ### CLI overview
 
@@ -83,7 +86,7 @@ To get an overview of all available options, run `python query.py --help`. The
 options are also printed below.
 
 ```sh
-usage: query.py [-h] [--top int] [--retriever {faiss,es}] str
+usage: query.py [-h] [--top int] [--retriever {faiss,es}] [--lm {dpr,longformer}] str
 
 positional arguments:
   str                   The question to feed to the QA system
@@ -93,6 +96,8 @@ options:
   --top int, -t int     The number of answers to retrieve
   --retriever {faiss,es}, -r {faiss,es}
                         The retrieval method to use
+  --lm {dpr,longformer}, -l {dpr,longformer}
+                        The language model to use for the FAISS retriever
 ```
````
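With the new flag, a full invocation looks like `python query.py --retriever=faiss --lm=longformer --top=3 "What is a language model?"` (the question text here is only illustrative).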
query.py CHANGED

```diff
@@ -2,21 +2,48 @@ import argparse
 import torch
 import transformers
 
-from typing import List, Literal, cast
+from typing import Dict, List, Literal, Tuple, cast
 from datasets import load_dataset, DatasetDict
 from dotenv import load_dotenv
 
+from src.readers.base_reader import Reader
+from src.readers.longformer_reader import LongformerReader
 from src.readers.dpr_reader import DprReader
 from src.retrievers.base_retriever import Retriever
 from src.retrievers.es_retriever import ESRetriever
-from src.retrievers.faiss_retriever import FaissRetriever
+from src.retrievers.faiss_retriever import (
+    FaissRetriever,
+    FaissRetrieverOptions
+)
 from src.utils.preprocessing import context_to_reader_input
 from src.utils.log import get_logger
 
 
-def get_retriever(…):
-    …
-    …
+def get_retriever(paragraphs: DatasetDict,
+                  r: Literal["es", "faiss"],
+                  lm: Literal["dpr", "longformer"]) -> Retriever:
+    match (r, lm):
+        case "es", _:
+            return ESRetriever()
+        case "faiss", "dpr":
+            options = FaissRetrieverOptions.dpr("./src/models/dpr.faiss")
+            return FaissRetriever(paragraphs, options)
+        case "faiss", "longformer":
+            options = FaissRetrieverOptions.longformer(
+                "./src/models/longformer.faiss")
+            return FaissRetriever(paragraphs, options)
+        case _:
+            raise ValueError("Retriever options not recognized")
+
+
+def get_reader(lm: Literal["dpr", "longformer"]) -> Reader:
+    match lm:
+        case "dpr":
+            return DprReader()
+        case "longformer":
+            return LongformerReader()
+        case _:
+            raise ValueError("Language model not recognized")
 
 
 def print_name(contexts: dict, section: str, id: int):
@@ -51,7 +78,11 @@ def print_answers(answers: List[tuple], scores: List[float], contexts: dict):
     print()
 
 
-def probe(query: str, …):
+def probe(query: str,
+          retriever: Retriever,
+          reader: Reader,
+          num_answers: int = 5) \
+        -> Tuple[List[tuple], List[float], Dict[str, List[str]]]:
     scores, contexts = retriever.retrieve(query)
     reader_input = context_to_reader_input(contexts)
     answers = reader.read(query, reader_input, num_answers)
@@ -63,7 +94,7 @@ def default_probe(query: str):
     # default probe is a probe that prints 5 answers with faiss
     paragraphs = cast(DatasetDict, load_dataset(
         "GroNLP/ik-nlp-22_slp", "paragraphs"))
-    retriever = get_retriever("faiss", …)
+    retriever = get_retriever(paragraphs, "faiss", "dpr")
     reader = DprReader()
 
     return probe(query, retriever, reader)
@@ -75,13 +106,20 @@ def main(args: argparse.Namespace):
         "GroNLP/ik-nlp-22_slp", "paragraphs"))
 
     # Retrieve
-    retriever = get_retriever(args.retriever, …)
-    reader = …
+    retriever = get_retriever(paragraphs, args.retriever, args.lm)
+    reader = get_reader(args.lm)
     answers, scores, contexts = probe(
-        args.query, retriever, reader, args.…)
+        args.query, retriever, reader, args.top)
 
     # Print output
-    …
+    print("Question: " + args.query)
+    print("Answer(s):")
+    if args.lm == "dpr":
+        print_answers(answers, scores, contexts)
+    else:
+        answers = filter(lambda a: len(a[0].strip()) > 0, answers)
+        for pos, answer in enumerate(answers, start=1):
+            print(f"  - {answer[0].strip()}")
 
 
 if __name__ == "__main__":
@@ -94,13 +132,18 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         formatter_class=argparse.MetavarTypeHelpFormatter
     )
-    parser.add_argument(…)
-    parser.add_argument(…)
-    …
+    parser.add_argument(
+        "query", type=str, help="The question to feed to the QA system")
+    parser.add_argument(
+        "--top", "-t", type=int, default=1,
+        help="The number of answers to retrieve")
+    parser.add_argument(
+        "--retriever", "-r", type=str.lower, choices=["faiss", "es"],
+        default="faiss", help="The retrieval method to use")
+    parser.add_argument(
+        "--lm", "-l", type=str.lower,
+        choices=["dpr", "longformer"], default="dpr",
+        help="The language model to use for the FAISS retriever")
 
     args = parser.parse_args()
     main(args)
```
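Taken together, the new `get_retriever`/`get_reader` helpers make the retriever and reader swappable behind shared interfaces. A minimal sketch of driving them from Python rather than the CLI (assuming the dataset and model checkpoints download successfully; the question string is illustrative):

```python
from typing import cast

from datasets import load_dataset, DatasetDict

from query import get_retriever, get_reader, probe

# Load the paragraphs split that backs the FAISS index
paragraphs = cast(DatasetDict, load_dataset(
    "GroNLP/ik-nlp-22_slp", "paragraphs"))

# Swap in the Longformer retriever/reader pair and ask a question
retriever = get_retriever(paragraphs, "faiss", "longformer")
reader = get_reader("longformer")
answers, scores, contexts = probe(
    "What is a language model?", retriever, reader, num_answers=3)
```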
src/models/{paragraphs_embedding.faiss → dpr.faiss} RENAMED

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
+oid sha256:6bc0e5c38ddeb0a6a4daaf3ae98cd3e564f22ff9a263bc8867d0b363e828ccce
 size 5213229
```
src/models/longformer.faiss ADDED

```diff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:56b2616392540f4d2d8fa34d313a59c41572dca3ef5a683c7a8dbd2691418ea6
+size 5213229
```
src/readers/base_reader.py ADDED

```diff
@@ -0,0 +1,9 @@
+from typing import Dict, List, Tuple
+
+
+class Reader():
+    def read(self,
+             query: str,
+             context: Dict[str, List[str]],
+             num_answers: int) -> List[Tuple]:
+        raise NotImplementedError()
```
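This base class pins down the reader contract that `DprReader` and `LongformerReader` implement: `read` takes a query plus a dict of context columns and returns a list of answer tuples. A hypothetical toy subclass, purely to illustrate the shape of the interface (not part of the commit):

```python
from typing import Dict, List, Tuple

from src.readers.base_reader import Reader


class EchoReader(Reader):
    """Hypothetical reader that returns the first sentence of each context."""

    def read(self,
             query: str,
             context: Dict[str, List[str]],
             num_answers: int) -> List[Tuple]:
        # Follow the [answer, spans, scores] shape LongformerReader produces
        return [(text.split(".")[0], [], [])
                for text in context["texts"][:num_answers]]
```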
src/readers/dpr_reader.py CHANGED

```diff
@@ -1,8 +1,10 @@
 from transformers import DPRReader, DPRReaderTokenizer
 from typing import List, Dict, Tuple
 
+from src.readers.base_reader import Reader
 
-class DprReader():
+
+class DprReader(Reader):
     def __init__(self) -> None:
         self._tokenizer = DPRReaderTokenizer.from_pretrained(
             "facebook/dpr-reader-single-nq-base")
```
src/readers/longformer_reader.py ADDED

```diff
@@ -0,0 +1,41 @@
+import torch
+from transformers import (
+    LongformerTokenizerFast,
+    LongformerForQuestionAnswering
+)
+from typing import List, Dict, Tuple
+
+from src.readers.base_reader import Reader
+
+
+class LongformerReader(Reader):
+    def __init__(self) -> None:
+        checkpoint = "valhalla/longformer-base-4096-finetuned-squadv1"
+        self.tokenizer = LongformerTokenizerFast.from_pretrained(checkpoint)
+        self.model = LongformerForQuestionAnswering.from_pretrained(checkpoint)
+
+    def read(self,
+             query: str,
+             context: Dict[str, List[str]],
+             num_answers=5) -> List[Tuple]:
+        answers = []
+
+        for text in context['texts']:
+            encoding = self.tokenizer(
+                query, text, return_tensors="pt")
+            input_ids = encoding["input_ids"]
+            attention_mask = encoding["attention_mask"]
+            outputs = self.model(input_ids, attention_mask=attention_mask)
+
+            start_logits = outputs.start_logits
+            end_logits = outputs.end_logits
+            all_tokens = self.tokenizer.convert_ids_to_tokens(
+                input_ids[0].tolist())
+            answer_tokens = all_tokens[
+                torch.argmax(start_logits):torch.argmax(end_logits) + 1]
+            answer = self.tokenizer.decode(
+                self.tokenizer.convert_tokens_to_ids(answer_tokens)
+            )
+            answers.append([answer, [], []])
+
+        return answers
```
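The extraction step here is the standard greedy span-selection recipe for extractive QA: take the argmax of the start and end logits and decode the tokens between them. A self-contained sketch of just that step, with hand-made logits standing in for model outputs:

```python
import torch

# Toy stand-ins for tokenizer output and model logits
tokens = ["<s>", "what", "is", "qa", "</s>",
          "qa", "is", "question", "answering", "</s>"]
start_logits = torch.tensor([0.1, 0.0, 0.0, 0.0, 0.0,
                             0.2, 0.1, 3.5, 0.3, 0.0])
end_logits = torch.tensor([0.1, 0.0, 0.0, 0.0, 0.0,
                           0.0, 0.1, 0.4, 3.9, 0.0])

# Greedy span selection: argmax start to argmax end, inclusive
start = torch.argmax(start_logits)
end = torch.argmax(end_logits)
answer_tokens = tokens[start:end + 1]
print(answer_tokens)  # ['question', 'answering']
```

Note that greedy selection can put the end index before the start index, in which case the slice (both above and in `LongformerReader.read`) yields an empty answer; the filtering in `query.py` drops such empty strings before printing.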
src/retrievers/faiss_retriever.py CHANGED

```diff
@@ -1,14 +1,19 @@
 import os
 import os.path
-
 import torch
-
+
+from datasets import DatasetDict
+from dataclasses import dataclass
 from transformers import (
     DPRContextEncoder,
-    DPRContextEncoderTokenizer,
+    DPRContextEncoderTokenizerFast,
     DPRQuestionEncoder,
-    DPRQuestionEncoderTokenizer,
+    DPRQuestionEncoderTokenizerFast,
+    LongformerModel,
+    LongformerTokenizerFast
 )
+from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 
 from src.retrievers.base_retriever import RetrieveType, Retriever
 from src.utils.log import get_logger
@@ -23,35 +28,99 @@ os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
 logger = get_logger()
 
 
+@dataclass
+class FaissRetrieverOptions:
+    ctx_encoder: PreTrainedModel
+    ctx_tokenizer: PreTrainedTokenizerFast
+    q_encoder: PreTrainedModel
+    q_tokenizer: PreTrainedTokenizerFast
+    embedding_path: str
+    lm: str
+
+    @staticmethod
+    def dpr(embedding_path: str):
+        return FaissRetrieverOptions(
+            ctx_encoder=DPRContextEncoder.from_pretrained(
+                "facebook/dpr-ctx_encoder-single-nq-base"
+            ),
+            ctx_tokenizer=DPRContextEncoderTokenizerFast.from_pretrained(
+                "facebook/dpr-ctx_encoder-single-nq-base"
+            ),
+            q_encoder=DPRQuestionEncoder.from_pretrained(
+                "facebook/dpr-question_encoder-single-nq-base"
+            ),
+            q_tokenizer=DPRQuestionEncoderTokenizerFast.from_pretrained(
+                "facebook/dpr-question_encoder-single-nq-base"
+            ),
+            embedding_path=embedding_path,
+            lm="dpr"
+        )
+
+    @staticmethod
+    def longformer(embedding_path: str):
+        encoder = LongformerModel.from_pretrained(
+            "allenai/longformer-base-4096"
+        )
+        tokenizer = LongformerTokenizerFast.from_pretrained(
+            "allenai/longformer-base-4096"
+        )
+        return FaissRetrieverOptions(
+            ctx_encoder=encoder,
+            ctx_tokenizer=tokenizer,
+            q_encoder=encoder,
+            q_tokenizer=tokenizer,
+            embedding_path=embedding_path,
+            lm="longformer"
+        )
+
+
 class FaissRetriever(Retriever):
     """A class used to retrieve relevant documents based on some query.
     based on https://huggingface.co/docs/datasets/faiss_es#faiss.
     """
 
-    def __init__(self, paragraphs: DatasetDict,
-                 embedding_path: str) -> None:
+    def __init__(self, paragraphs: DatasetDict,
+                 options: FaissRetrieverOptions) -> None:
         torch.set_grad_enabled(False)
 
+        self.lm = options.lm
+
         # Context encoding and tokenization
-        self.ctx_encoder = DPRContextEncoder.from_pretrained(
-            "facebook/dpr-ctx_encoder-single-nq-base"
-        )
-        self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
-            "facebook/dpr-ctx_encoder-single-nq-base"
-        )
+        self.ctx_encoder = options.ctx_encoder
+        self.ctx_tokenizer = options.ctx_tokenizer
 
         # Question encoding and tokenization
-        self.q_encoder = DPRQuestionEncoder.from_pretrained(
-            "facebook/dpr-question_encoder-single-nq-base"
-        )
-        self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
-            "facebook/dpr-question_encoder-single-nq-base"
-        )
+        self.q_encoder = options.q_encoder
+        self.q_tokenizer = options.q_tokenizer
 
         self.paragraphs = paragraphs
-        self.embedding_path = embedding_path
+        self.embedding_path = options.embedding_path
 
         self.index = self._init_index()
 
+    def _embed_question(self, q):
+        match self.lm:
+            case "dpr":
+                tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
+                return self.q_encoder(**tok)[0][0].numpy()
+            case "longformer":
+                tok = self.q_tokenizer(q, return_tensors="pt")
+                return self.q_encoder(**tok).last_hidden_state[0][0].numpy()
+
+    def _embed_context(self, row):
+        p = row["text"]
+
+        match self.lm:
+            case "dpr":
+                tok = self.ctx_tokenizer(
+                    p, return_tensors="pt", truncation=True)
+                enc = self.ctx_encoder(**tok)[0][0].numpy()
+                return {"embeddings": enc}
+            case "longformer":
+                tok = self.ctx_tokenizer(p, return_tensors="pt")
+                enc = self.ctx_encoder(**tok).last_hidden_state[0][0].numpy()
+                return {"embeddings": enc}
+
     def _init_index(
             self,
             force_new_embedding: bool = False):
@@ -64,16 +133,8 @@ class FaissRetriever(Retriever):
                 'embeddings', self.embedding_path)  # type: ignore
             return ds
         else:
-            def embed(row):
-                # Inline helper function to perform embedding
-                p = row["text"]
-                tok = self.ctx_tokenizer(
-                    p, return_tensors="pt", truncation=True)
-                enc = self.ctx_encoder(**tok)[0][0].numpy()
-                return {"embeddings": enc}
-
             # Add FAISS embeddings
-            index = ds.map(embed)  # type: ignore
+            index = ds.map(self._embed_context)  # type: ignore
 
             index.add_faiss_index(column="embeddings")
 
@@ -86,12 +147,7 @@ class FaissRetriever(Retriever):
 
     @timeit("faissretriever.retrieve")
     def retrieve(self, query: str, k: int = 5) -> RetrieveType:
-        def embed(q):
-            # Inline helper function to perform embedding
-            tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
-            return self.q_encoder(**tok)[0][0].numpy()
-
-        question_embedding = embed(query)
+        question_embedding = self._embed_question(query)
         scores, results = self.index.get_nearest_examples(
             "embeddings", question_embedding, k=k
         )
```
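Since `FaissRetrieverOptions` now carries the encoders, tokenizers, and embedding path, constructing a retriever no longer hard-codes DPR. A minimal construction sketch (assuming the LFS-tracked index files from this commit are checked out; the query is illustrative):

```python
from typing import cast

from datasets import load_dataset, DatasetDict

from src.retrievers.faiss_retriever import FaissRetriever, FaissRetrieverOptions

# Load the paragraphs the index was built over
paragraphs = cast(DatasetDict, load_dataset(
    "GroNLP/ik-nlp-22_slp", "paragraphs"))

# Longformer shares one encoder/tokenizer for questions and contexts
options = FaissRetrieverOptions.longformer("./src/models/longformer.faiss")
retriever = FaissRetriever(paragraphs, options)

# Nearest-neighbour search over the precomputed embeddings
scores, contexts = retriever.retrieve("What is a language model?", k=5)
```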