Ramon Meffert
commited on
Commit
·
0157dfd
1
Parent(s):
be1f224
Fix timings and add timing results
Browse files- .env.example +2 -2
- main.py +69 -29
- poetry.lock +56 -9
- pyproject.toml +1 -0
- query.py +6 -6
- results/timings.csv +60 -0
- src/reader.py +0 -2
- src/readers/dpr_reader.py +4 -0
- src/readers/longformer_reader.py +7 -4
- src/retrievers/es_retriever.py +13 -3
- src/retrievers/faiss_retriever.py +7 -11
- src/utils/log.py +19 -18
- src/utils/timing.py +3 -6
.env.example
CHANGED
@@ -3,6 +3,6 @@ ELASTIC_PASSWORD=<password>
|
|
3 |
ELASTIC_HOST=https://localhost:9200
|
4 |
|
5 |
LOG_LEVEL=INFO
|
6 |
-
TRANSFORMERS_NO_ADVISORY_WARNINGS
|
7 |
-
|
8 |
ENABLE_TIMING=TRUE
|
|
|
3 |
ELASTIC_HOST=https://localhost:9200
|
4 |
|
5 |
LOG_LEVEL=INFO
|
6 |
+
TRANSFORMERS_NO_ADVISORY_WARNINGS=true
|
7 |
+
KMP_DUPLICATE_LIB_OK=true
|
8 |
ENABLE_TIMING=TRUE
|
main.py
CHANGED
@@ -1,25 +1,34 @@
|
|
1 |
-
import
|
2 |
-
|
|
|
|
|
|
|
|
|
3 |
|
4 |
-
import
|
5 |
-
import
|
6 |
from datasets import DatasetDict, load_dataset
|
7 |
-
from dotenv import load_dotenv
|
8 |
-
from query import print_answers
|
9 |
|
|
|
10 |
from src.evaluation import evaluate
|
11 |
from src.readers.dpr_reader import DprReader
|
|
|
12 |
from src.retrievers.base_retriever import Retriever
|
13 |
from src.retrievers.es_retriever import ESRetriever
|
14 |
-
from src.retrievers.faiss_retriever import
|
15 |
-
|
|
|
|
|
|
|
16 |
from src.utils.preprocessing import context_to_reader_input
|
17 |
from src.utils.timing import get_times, timeit
|
18 |
|
19 |
-
logger = get_logger()
|
20 |
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
23 |
|
24 |
if __name__ == '__main__':
|
25 |
dataset_name = "GroNLP/ik-nlp-22_slp"
|
@@ -28,41 +37,72 @@ if __name__ == '__main__':
|
|
28 |
questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))
|
29 |
|
30 |
# Only doing a few questions for speed
|
31 |
-
subset_idx =
|
32 |
questions_test = questions["test"][:subset_idx]
|
33 |
|
34 |
-
experiments: Dict[str,
|
35 |
-
"
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
}
|
38 |
|
39 |
-
for experiment_name,
|
40 |
-
|
41 |
-
|
42 |
for idx in range(subset_idx):
|
43 |
question = questions_test["question"][idx]
|
44 |
answer = questions_test["answer"][idx]
|
45 |
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
reader_input = context_to_reader_input(context)
|
48 |
|
49 |
-
# workaround so we can use the decorator with a dynamic name for
|
50 |
-
|
51 |
-
answers =
|
52 |
|
53 |
# Calculate softmaxed scores for readable output
|
54 |
-
sm = torch.nn.Softmax(dim=0)
|
55 |
-
document_scores = sm(torch.Tensor(
|
56 |
-
|
57 |
-
span_scores = sm(torch.Tensor(
|
58 |
-
|
59 |
|
60 |
-
print_answers(answers, scores, context)
|
61 |
|
62 |
# TODO evaluation and storing of results
|
|
|
63 |
|
64 |
times = get_times()
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
66 |
# TODO evaluation and storing of results
|
67 |
|
68 |
# # Initialize retriever
|
|
|
1 |
+
from dotenv import load_dotenv
|
2 |
+
# needs to happen as very first thing, otherwise HF ignores env vars
|
3 |
+
load_dotenv()
|
4 |
+
|
5 |
+
import os
|
6 |
+
import pandas as pd
|
7 |
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from typing import Dict, cast
|
10 |
from datasets import DatasetDict, load_dataset
|
|
|
|
|
11 |
|
12 |
+
from src.readers.base_reader import Reader
|
13 |
from src.evaluation import evaluate
|
14 |
from src.readers.dpr_reader import DprReader
|
15 |
+
from src.readers.longformer_reader import LongformerReader
|
16 |
from src.retrievers.base_retriever import Retriever
|
17 |
from src.retrievers.es_retriever import ESRetriever
|
18 |
+
from src.retrievers.faiss_retriever import (
|
19 |
+
FaissRetriever,
|
20 |
+
FaissRetrieverOptions
|
21 |
+
)
|
22 |
+
from src.utils.log import logger
|
23 |
from src.utils.preprocessing import context_to_reader_input
|
24 |
from src.utils.timing import get_times, timeit
|
25 |
|
|
|
26 |
|
27 |
+
@dataclass
|
28 |
+
class Experiment:
|
29 |
+
retriever: Retriever
|
30 |
+
reader: Reader
|
31 |
+
|
32 |
|
33 |
if __name__ == '__main__':
|
34 |
dataset_name = "GroNLP/ik-nlp-22_slp"
|
|
|
37 |
questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))
|
38 |
|
39 |
# Only doing a few questions for speed
|
40 |
+
subset_idx = len(questions["test"])
|
41 |
questions_test = questions["test"][:subset_idx]
|
42 |
|
43 |
+
experiments: Dict[str, Experiment] = {
|
44 |
+
"faiss_dpr": Experiment(
|
45 |
+
retriever=FaissRetriever(
|
46 |
+
paragraphs,
|
47 |
+
FaissRetrieverOptions.dpr("./src/models/dpr.faiss")),
|
48 |
+
reader=DprReader()
|
49 |
+
),
|
50 |
+
"faiss_longformer": Experiment(
|
51 |
+
retriever=FaissRetriever(
|
52 |
+
paragraphs,
|
53 |
+
FaissRetrieverOptions.longformer("./src/models/longformer.faiss")),
|
54 |
+
reader=LongformerReader()
|
55 |
+
),
|
56 |
+
"es_dpr": Experiment(
|
57 |
+
retriever=ESRetriever(paragraphs),
|
58 |
+
reader=DprReader()
|
59 |
+
),
|
60 |
+
"es_longformer": Experiment(
|
61 |
+
retriever=ESRetriever(paragraphs),
|
62 |
+
reader=LongformerReader()
|
63 |
+
),
|
64 |
}
|
65 |
|
66 |
+
for experiment_name, experiment in experiments.items():
|
67 |
+
logger.info(f"Running experiment {experiment_name}...")
|
|
|
68 |
for idx in range(subset_idx):
|
69 |
question = questions_test["question"][idx]
|
70 |
answer = questions_test["answer"][idx]
|
71 |
|
72 |
+
retrieve_timer = timeit(f"{experiment_name}.retrieve")
|
73 |
+
t_retrieve = retrieve_timer(experiment.retriever.retrieve)
|
74 |
+
|
75 |
+
read_timer = timeit(f"{experiment_name}.read")
|
76 |
+
t_read = read_timer(experiment.reader.read)
|
77 |
+
|
78 |
+
print(f"\x1b[1K\r[{idx+1:03}] - \"{question}\"", end='')
|
79 |
+
|
80 |
+
scores, context = t_retrieve(question, 5)
|
81 |
reader_input = context_to_reader_input(context)
|
82 |
|
83 |
+
# workaround so we can use the decorator with a dynamic name for
|
84 |
+
# time recording
|
85 |
+
answers = t_read(question, reader_input, 5)
|
86 |
|
87 |
# Calculate softmaxed scores for readable output
|
88 |
+
# sm = torch.nn.Softmax(dim=0)
|
89 |
+
# document_scores = sm(torch.Tensor(
|
90 |
+
# [pred.relevance_score for pred in answers]))
|
91 |
+
# span_scores = sm(torch.Tensor(
|
92 |
+
# [pred.span_score for pred in answers]))
|
93 |
|
94 |
+
# print_answers(answers, scores, context)
|
95 |
|
96 |
# TODO evaluation and storing of results
|
97 |
+
print()
|
98 |
|
99 |
times = get_times()
|
100 |
+
|
101 |
+
df = pd.DataFrame(times)
|
102 |
+
os.makedirs("./results/", exist_ok=True)
|
103 |
+
df.to_csv("./results/timings.csv")
|
104 |
+
|
105 |
+
|
106 |
# TODO evaluation and storing of results
|
107 |
|
108 |
# # Initialize retriever
|
poetry.lock
CHANGED
@@ -212,6 +212,20 @@ category = "main"
|
|
212 |
optional = false
|
213 |
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
[[package]]
|
216 |
name = "cryptography"
|
217 |
version = "36.0.2"
|
@@ -266,13 +280,13 @@ xxhash = "*"
|
|
266 |
apache-beam = ["apache-beam (>=2.26.0)"]
|
267 |
audio = ["librosa"]
|
268 |
benchmarks = ["numpy (==1.18.5)", "tensorflow (==2.3.0)", "torch (==1.6.0)", "transformers (==3.0.2)"]
|
269 |
-
dev = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore", "boto3", "botocore", "faiss-cpu (>=1.6.4)", "fsspec", "moto[s3
|
270 |
docs = ["docutils (==0.16.0)", "recommonmark", "sphinx (==3.1.2)", "sphinx-markdown-tables", "sphinx-rtd-theme (==0.4.3)", "sphinxext-opengraph (==0.4.1)", "sphinx-copybutton", "fsspec (<2021.9.0)", "s3fs", "sphinx-panels", "sphinx-inline-tabs", "myst-parser", "Markdown (!=3.3.5)"]
|
271 |
quality = ["black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)"]
|
272 |
s3 = ["fsspec", "boto3", "botocore", "s3fs"]
|
273 |
tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)"]
|
274 |
tensorflow_gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"]
|
275 |
-
tests = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore", "boto3", "botocore", "faiss-cpu (>=1.6.4)", "fsspec", "moto[s3
|
276 |
torch = ["torch"]
|
277 |
vision = ["Pillow (>=6.2.1)"]
|
278 |
|
@@ -439,7 +453,7 @@ python-versions = ">=3.7"
|
|
439 |
|
440 |
[[package]]
|
441 |
name = "fsspec"
|
442 |
-
version = "2022.
|
443 |
description = "File-system specification"
|
444 |
category = "main"
|
445 |
optional = false
|
@@ -470,10 +484,11 @@ s3 = ["s3fs"]
|
|
470 |
sftp = ["paramiko"]
|
471 |
smb = ["smbprotocol"]
|
472 |
ssh = ["paramiko"]
|
|
|
473 |
|
474 |
[[package]]
|
475 |
name = "gradio"
|
476 |
-
version = "2.9.
|
477 |
description = "Python library for easily interacting with trained machine learning models"
|
478 |
category = "main"
|
479 |
optional = false
|
@@ -529,6 +544,17 @@ all = ["pytest", "datasets", "black (>=20.8b1)", "isort (>=5.5.4)", "flake8 (>=3
|
|
529 |
dev = ["pytest", "datasets", "black (>=20.8b1)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"]
|
530 |
quality = ["black (>=20.8b1)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"]
|
531 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
532 |
[[package]]
|
533 |
name = "idna"
|
534 |
version = "3.3"
|
@@ -1099,6 +1125,14 @@ python-versions = ">=3.6"
|
|
1099 |
[package.extras]
|
1100 |
diagrams = ["jinja2", "railroad-diagrams"]
|
1101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1102 |
[[package]]
|
1103 |
name = "python-dateutil"
|
1104 |
version = "2.8.2"
|
@@ -1498,7 +1532,7 @@ multidict = ">=4.0"
|
|
1498 |
[metadata]
|
1499 |
lock-version = "1.1"
|
1500 |
python-versions = "^3.8"
|
1501 |
-
content-hash = "
|
1502 |
|
1503 |
[metadata.files]
|
1504 |
aiohttp = [
|
@@ -1699,6 +1733,10 @@ colorama = [
|
|
1699 |
{file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"},
|
1700 |
{file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"},
|
1701 |
]
|
|
|
|
|
|
|
|
|
1702 |
cryptography = [
|
1703 |
{file = "cryptography-36.0.2-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:4e2dddd38a5ba733be6a025a1475a9f45e4e41139d1321f412c6b360b19070b6"},
|
1704 |
{file = "cryptography-36.0.2-cp36-abi3-macosx_10_10_x86_64.whl", hash = "sha256:4881d09298cd0b669bb15b9cfe6166f16fc1277b4ed0d04a22f3d6430cb30f1d"},
|
@@ -1880,12 +1918,12 @@ frozenlist = [
|
|
1880 |
{file = "frozenlist-1.3.0.tar.gz", hash = "sha256:ce6f2ba0edb7b0c1d8976565298ad2deba6f8064d2bebb6ffce2ca896eb35b0b"},
|
1881 |
]
|
1882 |
fsspec = [
|
1883 |
-
{file = "fsspec-2022.
|
1884 |
-
{file = "fsspec-2022.
|
1885 |
]
|
1886 |
gradio = [
|
1887 |
-
{file = "gradio-2.9.
|
1888 |
-
{file = "gradio-2.9.
|
1889 |
]
|
1890 |
h11 = [
|
1891 |
{file = "h11-0.13.0-py3-none-any.whl", hash = "sha256:8ddd78563b633ca55346c8cd41ec0af27d3c79931828beffb46ce70a379e7442"},
|
@@ -1895,6 +1933,10 @@ huggingface-hub = [
|
|
1895 |
{file = "huggingface_hub-0.4.0-py3-none-any.whl", hash = "sha256:808021af1ce1111104973ae54d81738eaf40be6d1e82fc6bdedb82f81c6206e7"},
|
1896 |
{file = "huggingface_hub-0.4.0.tar.gz", hash = "sha256:f0e3389f8988eb7781b17de520ae7fd0aa50d9823534e3ae55344d943a88ac87"},
|
1897 |
]
|
|
|
|
|
|
|
|
|
1898 |
idna = [
|
1899 |
{file = "idna-3.3-py3-none-any.whl", hash = "sha256:84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff"},
|
1900 |
{file = "idna-3.3.tar.gz", hash = "sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d"},
|
@@ -2193,6 +2235,7 @@ numpy = [
|
|
2193 |
{file = "numpy-1.22.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8251ed96f38b47b4295b1ae51631de7ffa8260b5b087808ef09a39a9d66c97ab"},
|
2194 |
{file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48a3aecd3b997bf452a2dedb11f4e79bc5bfd21a1d4cc760e703c31d57c84b3e"},
|
2195 |
{file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3bae1a2ed00e90b3ba5f7bd0a7c7999b55d609e0c54ceb2b076a25e345fa9f4"},
|
|
|
2196 |
{file = "numpy-1.22.3-cp310-cp310-win_amd64.whl", hash = "sha256:08d9b008d0156c70dc392bb3ab3abb6e7a711383c3247b410b39962263576cd4"},
|
2197 |
{file = "numpy-1.22.3-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:201b4d0552831f7250a08d3b38de0d989d6f6e4658b709a02a73c524ccc6ffce"},
|
2198 |
{file = "numpy-1.22.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f8c1f39caad2c896bc0018f699882b345b2a63708008be29b1f355ebf6f933fe"},
|
@@ -2510,6 +2553,10 @@ pyparsing = [
|
|
2510 |
{file = "pyparsing-3.0.7-py3-none-any.whl", hash = "sha256:a6c06a88f252e6c322f65faf8f418b16213b51bdfaece0524c1c1bc30c63c484"},
|
2511 |
{file = "pyparsing-3.0.7.tar.gz", hash = "sha256:18ee9022775d270c55187733956460083db60b37d0d0fb357445f3094eed3eea"},
|
2512 |
]
|
|
|
|
|
|
|
|
|
2513 |
python-dateutil = [
|
2514 |
{file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
|
2515 |
{file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
|
|
|
212 |
optional = false
|
213 |
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
214 |
|
215 |
+
[[package]]
|
216 |
+
name = "coloredlogs"
|
217 |
+
version = "15.0.1"
|
218 |
+
description = "Colored terminal output for Python's logging module"
|
219 |
+
category = "main"
|
220 |
+
optional = false
|
221 |
+
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
222 |
+
|
223 |
+
[package.dependencies]
|
224 |
+
humanfriendly = ">=9.1"
|
225 |
+
|
226 |
+
[package.extras]
|
227 |
+
cron = ["capturer (>=2.4)"]
|
228 |
+
|
229 |
[[package]]
|
230 |
name = "cryptography"
|
231 |
version = "36.0.2"
|
|
|
280 |
apache-beam = ["apache-beam (>=2.26.0)"]
|
281 |
audio = ["librosa"]
|
282 |
benchmarks = ["numpy (==1.18.5)", "tensorflow (==2.3.0)", "torch (==1.6.0)", "transformers (==3.0.2)"]
|
283 |
+
dev = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore", "boto3", "botocore", "faiss-cpu (>=1.6.4)", "fsspec", "moto[server,s3] (==2.0.4)", "rarfile (>=4.0)", "s3fs (==2021.08.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "conllu", "h5py", "langdetect", "lxml", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "bert-score (>=0.3.6)", "rouge-score", "sacrebleu", "scipy", "seqeval", "scikit-learn", "jiwer", "sentencepiece", "torchmetrics (==0.6.0)", "mauve-text", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "wget (>=3.2)", "pytorch-nlp (==0.5.0)", "pytorch-lightning", "fastBPE (==0.1.0)", "fairseq", "black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)", "importlib-resources"]
|
284 |
docs = ["docutils (==0.16.0)", "recommonmark", "sphinx (==3.1.2)", "sphinx-markdown-tables", "sphinx-rtd-theme (==0.4.3)", "sphinxext-opengraph (==0.4.1)", "sphinx-copybutton", "fsspec (<2021.9.0)", "s3fs", "sphinx-panels", "sphinx-inline-tabs", "myst-parser", "Markdown (!=3.3.5)"]
|
285 |
quality = ["black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)"]
|
286 |
s3 = ["fsspec", "boto3", "botocore", "s3fs"]
|
287 |
tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)"]
|
288 |
tensorflow_gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"]
|
289 |
+
tests = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore", "boto3", "botocore", "faiss-cpu (>=1.6.4)", "fsspec", "moto[server,s3] (==2.0.4)", "rarfile (>=4.0)", "s3fs (==2021.08.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "conllu", "h5py", "langdetect", "lxml", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "bert-score (>=0.3.6)", "rouge-score", "sacrebleu", "scipy", "seqeval", "scikit-learn", "jiwer", "sentencepiece", "torchmetrics (==0.6.0)", "mauve-text", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "wget (>=3.2)", "pytorch-nlp (==0.5.0)", "pytorch-lightning", "fastBPE (==0.1.0)", "fairseq", "importlib-resources"]
|
290 |
torch = ["torch"]
|
291 |
vision = ["Pillow (>=6.2.1)"]
|
292 |
|
|
|
453 |
|
454 |
[[package]]
|
455 |
name = "fsspec"
|
456 |
+
version = "2022.3.0"
|
457 |
description = "File-system specification"
|
458 |
category = "main"
|
459 |
optional = false
|
|
|
484 |
sftp = ["paramiko"]
|
485 |
smb = ["smbprotocol"]
|
486 |
ssh = ["paramiko"]
|
487 |
+
tqdm = ["tqdm"]
|
488 |
|
489 |
[[package]]
|
490 |
name = "gradio"
|
491 |
+
version = "2.9.1"
|
492 |
description = "Python library for easily interacting with trained machine learning models"
|
493 |
category = "main"
|
494 |
optional = false
|
|
|
544 |
dev = ["pytest", "datasets", "black (>=20.8b1)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"]
|
545 |
quality = ["black (>=20.8b1)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"]
|
546 |
|
547 |
+
[[package]]
|
548 |
+
name = "humanfriendly"
|
549 |
+
version = "10.0"
|
550 |
+
description = "Human friendly output for text interfaces using Python"
|
551 |
+
category = "main"
|
552 |
+
optional = false
|
553 |
+
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
554 |
+
|
555 |
+
[package.dependencies]
|
556 |
+
pyreadline3 = {version = "*", markers = "sys_platform == \"win32\" and python_version >= \"3.8\""}
|
557 |
+
|
558 |
[[package]]
|
559 |
name = "idna"
|
560 |
version = "3.3"
|
|
|
1125 |
[package.extras]
|
1126 |
diagrams = ["jinja2", "railroad-diagrams"]
|
1127 |
|
1128 |
+
[[package]]
|
1129 |
+
name = "pyreadline3"
|
1130 |
+
version = "3.4.1"
|
1131 |
+
description = "A python implementation of GNU readline."
|
1132 |
+
category = "main"
|
1133 |
+
optional = false
|
1134 |
+
python-versions = "*"
|
1135 |
+
|
1136 |
[[package]]
|
1137 |
name = "python-dateutil"
|
1138 |
version = "2.8.2"
|
|
|
1532 |
[metadata]
|
1533 |
lock-version = "1.1"
|
1534 |
python-versions = "^3.8"
|
1535 |
+
content-hash = "881ba67f914b3c0690bcb34810061252ee77cebc0dac49b5ae76348394d810a8"
|
1536 |
|
1537 |
[metadata.files]
|
1538 |
aiohttp = [
|
|
|
1733 |
{file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"},
|
1734 |
{file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"},
|
1735 |
]
|
1736 |
+
coloredlogs = [
|
1737 |
+
{file = "coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934"},
|
1738 |
+
{file = "coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0"},
|
1739 |
+
]
|
1740 |
cryptography = [
|
1741 |
{file = "cryptography-36.0.2-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:4e2dddd38a5ba733be6a025a1475a9f45e4e41139d1321f412c6b360b19070b6"},
|
1742 |
{file = "cryptography-36.0.2-cp36-abi3-macosx_10_10_x86_64.whl", hash = "sha256:4881d09298cd0b669bb15b9cfe6166f16fc1277b4ed0d04a22f3d6430cb30f1d"},
|
|
|
1918 |
{file = "frozenlist-1.3.0.tar.gz", hash = "sha256:ce6f2ba0edb7b0c1d8976565298ad2deba6f8064d2bebb6ffce2ca896eb35b0b"},
|
1919 |
]
|
1920 |
fsspec = [
|
1921 |
+
{file = "fsspec-2022.3.0-py3-none-any.whl", hash = "sha256:a53491b003210fce6911dd8f2d37e20c41a27ce52a655eef11b885d1578ed4cf"},
|
1922 |
+
{file = "fsspec-2022.3.0.tar.gz", hash = "sha256:fd582cc4aa0db5968bad9317cae513450eddd08b2193c4428d9349265a995523"},
|
1923 |
]
|
1924 |
gradio = [
|
1925 |
+
{file = "gradio-2.9.1-py3-none-any.whl", hash = "sha256:877616dcda82e0e13bc04404c13f084c7b3a06cccc314a4db06b21c5f15f6190"},
|
1926 |
+
{file = "gradio-2.9.1.tar.gz", hash = "sha256:d9dfde81f064f38bcd95967316501ab40698fec0bcc4435dd00ea4578f695042"},
|
1927 |
]
|
1928 |
h11 = [
|
1929 |
{file = "h11-0.13.0-py3-none-any.whl", hash = "sha256:8ddd78563b633ca55346c8cd41ec0af27d3c79931828beffb46ce70a379e7442"},
|
|
|
1933 |
{file = "huggingface_hub-0.4.0-py3-none-any.whl", hash = "sha256:808021af1ce1111104973ae54d81738eaf40be6d1e82fc6bdedb82f81c6206e7"},
|
1934 |
{file = "huggingface_hub-0.4.0.tar.gz", hash = "sha256:f0e3389f8988eb7781b17de520ae7fd0aa50d9823534e3ae55344d943a88ac87"},
|
1935 |
]
|
1936 |
+
humanfriendly = [
|
1937 |
+
{file = "humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477"},
|
1938 |
+
{file = "humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc"},
|
1939 |
+
]
|
1940 |
idna = [
|
1941 |
{file = "idna-3.3-py3-none-any.whl", hash = "sha256:84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff"},
|
1942 |
{file = "idna-3.3.tar.gz", hash = "sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d"},
|
|
|
2235 |
{file = "numpy-1.22.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8251ed96f38b47b4295b1ae51631de7ffa8260b5b087808ef09a39a9d66c97ab"},
|
2236 |
{file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48a3aecd3b997bf452a2dedb11f4e79bc5bfd21a1d4cc760e703c31d57c84b3e"},
|
2237 |
{file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3bae1a2ed00e90b3ba5f7bd0a7c7999b55d609e0c54ceb2b076a25e345fa9f4"},
|
2238 |
+
{file = "numpy-1.22.3-cp310-cp310-win32.whl", hash = "sha256:f950f8845b480cffe522913d35567e29dd381b0dc7e4ce6a4a9f9156417d2430"},
|
2239 |
{file = "numpy-1.22.3-cp310-cp310-win_amd64.whl", hash = "sha256:08d9b008d0156c70dc392bb3ab3abb6e7a711383c3247b410b39962263576cd4"},
|
2240 |
{file = "numpy-1.22.3-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:201b4d0552831f7250a08d3b38de0d989d6f6e4658b709a02a73c524ccc6ffce"},
|
2241 |
{file = "numpy-1.22.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f8c1f39caad2c896bc0018f699882b345b2a63708008be29b1f355ebf6f933fe"},
|
|
|
2553 |
{file = "pyparsing-3.0.7-py3-none-any.whl", hash = "sha256:a6c06a88f252e6c322f65faf8f418b16213b51bdfaece0524c1c1bc30c63c484"},
|
2554 |
{file = "pyparsing-3.0.7.tar.gz", hash = "sha256:18ee9022775d270c55187733956460083db60b37d0d0fb357445f3094eed3eea"},
|
2555 |
]
|
2556 |
+
pyreadline3 = [
|
2557 |
+
{file = "pyreadline3-3.4.1-py3-none-any.whl", hash = "sha256:b0efb6516fd4fb07b45949053826a62fa4cb353db5be2bbb4a7aa1fdd1e345fb"},
|
2558 |
+
{file = "pyreadline3-3.4.1.tar.gz", hash = "sha256:6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae"},
|
2559 |
+
]
|
2560 |
python-dateutil = [
|
2561 |
{file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
|
2562 |
{file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
|
pyproject.toml
CHANGED
@@ -15,6 +15,7 @@ python-dotenv = "^0.19.2"
|
|
15 |
elasticsearch = "^8.1.0"
|
16 |
gradio = {extras = ["Jinja2"], version = "^2.9.0"}
|
17 |
Jinja2 = "^3.1.1"
|
|
|
18 |
|
19 |
[tool.poetry.dev-dependencies]
|
20 |
flake8 = "^4.0.1"
|
|
|
15 |
elasticsearch = "^8.1.0"
|
16 |
gradio = {extras = ["Jinja2"], version = "^2.9.0"}
|
17 |
Jinja2 = "^3.1.1"
|
18 |
+
coloredlogs = "^15.0.1"
|
19 |
|
20 |
[tool.poetry.dev-dependencies]
|
21 |
flake8 = "^4.0.1"
|
query.py
CHANGED
@@ -16,7 +16,12 @@ from src.retrievers.faiss_retriever import (
|
|
16 |
FaissRetrieverOptions
|
17 |
)
|
18 |
from src.utils.preprocessing import context_to_reader_input
|
19 |
-
from src.utils.log import
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
|
22 |
def get_retriever(paragraphs: DatasetDict,
|
@@ -123,11 +128,6 @@ def main(args: argparse.Namespace):
|
|
123 |
|
124 |
|
125 |
if __name__ == "__main__":
|
126 |
-
# Setup environment
|
127 |
-
load_dotenv()
|
128 |
-
logger = get_logger()
|
129 |
-
transformers.logging.set_verbosity_error()
|
130 |
-
|
131 |
# Set up CLI arguments
|
132 |
parser = argparse.ArgumentParser(
|
133 |
formatter_class=argparse.MetavarTypeHelpFormatter
|
|
|
16 |
FaissRetrieverOptions
|
17 |
)
|
18 |
from src.utils.preprocessing import context_to_reader_input
|
19 |
+
from src.utils.log import logger
|
20 |
+
|
21 |
+
|
22 |
+
# Setup environment
|
23 |
+
load_dotenv()
|
24 |
+
transformers.logging.set_verbosity_error()
|
25 |
|
26 |
|
27 |
def get_retriever(paragraphs: DatasetDict,
|
|
|
128 |
|
129 |
|
130 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
131 |
# Set up CLI arguments
|
132 |
parser = argparse.ArgumentParser(
|
133 |
formatter_class=argparse.MetavarTypeHelpFormatter
|
results/timings.csv
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
,faiss_dpr.retrieve,faiss_dpr.read,faiss_longformer.retrieve,faiss_longformer.read,es_dpr.retrieve,es_dpr.read,es_longformer.retrieve,es_longformer.read
|
2 |
+
0,0.30384302139282227,4.566400051116943,0.9227948188781738,5.768368244171143,0.01930093765258789,2.7453649044036865,0.010576009750366211,4.998417854309082
|
3 |
+
1,0.04573678970336914,1.9288370609283447,0.8380529880523682,5.916611671447754,0.018373966217041016,1.4845240116119385,0.012102842330932617,5.1692070960998535
|
4 |
+
2,0.04764819145202637,0.6628780364990234,0.7756149768829346,5.4998250007629395,0.015324831008911133,1.7706871032714844,0.012642860412597656,5.202448844909668
|
5 |
+
3,0.04507589340209961,1.219634771347046,0.8142738342285156,5.726102113723755,0.021118879318237305,1.987663984298706,0.012515068054199219,5.1083409786224365
|
6 |
+
4,0.04347515106201172,1.5222840309143066,0.814906120300293,5.672412872314453,0.013732194900512695,1.660247802734375,0.011805057525634766,5.313212156295776
|
7 |
+
5,0.07470989227294922,1.5599188804626465,0.8422539234161377,5.75390100479126,0.018023014068603516,1.5782928466796875,0.013046741485595703,5.419210195541382
|
8 |
+
6,0.06162095069885254,1.4178202152252197,0.7837569713592529,4.765166282653809,0.014074325561523438,0.7626080513000488,0.010712146759033203,4.976129055023193
|
9 |
+
7,0.0451970100402832,1.134779691696167,0.7723889350891113,4.5784592628479,0.012156963348388672,1.959972858428955,0.011015892028808594,4.5342161655426025
|
10 |
+
8,0.03589582443237305,0.8912148475646973,0.8142461776733398,5.212156295776367,0.009800195693969727,1.820624828338623,0.009167194366455078,4.468229293823242
|
11 |
+
9,0.06033587455749512,0.37888431549072266,1.1510162353515625,5.395290851593018,0.015815258026123047,1.0247371196746826,0.010970830917358398,5.1864588260650635
|
12 |
+
10,0.056854963302612305,0.6068317890167236,0.7839999198913574,4.668170928955078,0.013895988464355469,0.8482949733734131,0.011837005615234375,4.493913173675537
|
13 |
+
11,0.04697012901306152,0.341174840927124,0.8089311122894287,5.535298109054565,0.01624298095703125,1.4673452377319336,0.01172780990600586,5.264245986938477
|
14 |
+
12,0.0444178581237793,1.4774150848388672,0.7612121105194092,5.504917860031128,0.01690196990966797,2.42773699760437,0.011591196060180664,4.463428974151611
|
15 |
+
13,0.06343889236450195,0.8000409603118896,0.9072589874267578,6.015661954879761,0.010895252227783203,2.21577787399292,0.012058019638061523,5.302258253097534
|
16 |
+
14,0.05022692680358887,0.9474368095397949,0.8324599266052246,6.4684131145477295,0.016222000122070312,3.1302390098571777,0.010799169540405273,5.240310907363892
|
17 |
+
15,0.08313488960266113,0.746314287185669,0.8373219966888428,6.741006851196289,0.017467975616455078,1.353593111038208,0.020203113555908203,5.192620038986206
|
18 |
+
16,0.17216706275939941,1.136448860168457,0.7760000228881836,5.48329496383667,0.018398046493530273,0.5325403213500977,0.012536287307739258,5.264941215515137
|
19 |
+
17,0.04575324058532715,0.5853927135467529,0.7855441570281982,5.960904121398926,0.019134998321533203,2.6092309951782227,0.012146949768066406,5.331835031509399
|
20 |
+
18,0.04746294021606445,0.9219169616699219,0.9516820907592773,10.146074295043945,0.011114835739135742,2.220487117767334,0.011821985244750977,5.161020994186401
|
21 |
+
19,0.0443730354309082,0.4667840003967285,1.3496372699737549,8.213943719863892,0.01042485237121582,2.841907024383545,0.011536121368408203,5.236280918121338
|
22 |
+
20,0.06004190444946289,0.6129250526428223,1.3677341938018799,7.00742769241333,0.022186994552612305,1.6846930980682373,0.010824918746948242,5.377984046936035
|
23 |
+
21,0.06920814514160156,0.6232960224151611,1.4656860828399658,6.424375057220459,0.011613845825195312,1.0811800956726074,0.014858007431030273,5.279160022735596
|
24 |
+
22,0.04999184608459473,0.6539132595062256,0.8720510005950928,5.889069080352783,0.016654014587402344,1.6599159240722656,0.012172698974609375,5.177525043487549
|
25 |
+
23,0.05750322341918945,1.0169367790222168,0.9728169441223145,6.934185028076172,0.01772904396057129,1.2837882041931152,0.011108160018920898,5.186945199966431
|
26 |
+
24,0.06264281272888184,1.7151312828063965,1.3927390575408936,7.122100114822388,0.016143798828125,1.5387201309204102,0.011415958404541016,4.558846950531006
|
27 |
+
25,0.04831504821777344,0.7839398384094238,1.1007087230682373,5.4652369022369385,0.01099395751953125,1.5678913593292236,0.011976242065429688,4.612828969955444
|
28 |
+
26,0.048091888427734375,0.9228200912475586,0.8567941188812256,4.832158803939819,0.013817787170410156,2.0290918350219727,0.015846967697143555,4.845104217529297
|
29 |
+
27,0.04568672180175781,0.8964569568634033,0.7873432636260986,4.592561960220337,0.010241985321044922,0.3145887851715088,0.014873743057250977,4.759660720825195
|
30 |
+
28,0.04340720176696777,0.5004391670227051,0.8122010231018066,4.68702507019043,0.012717008590698242,0.9207170009613037,0.014780759811401367,4.955734968185425
|
31 |
+
29,0.045496225357055664,2.106112003326416,0.7901277542114258,5.48145604133606,0.009778976440429688,1.1795310974121094,0.011364936828613281,5.397800922393799
|
32 |
+
30,0.05589914321899414,3.3801350593566895,0.7913417816162109,4.76953387260437,0.01169896125793457,2.8297739028930664,0.012899160385131836,4.7149817943573
|
33 |
+
31,0.038469791412353516,1.2037632465362549,0.812114953994751,4.819751977920532,0.010591983795166016,1.0633080005645752,0.011631011962890625,4.603592157363892
|
34 |
+
32,0.043640851974487305,0.7455379962921143,0.7684001922607422,5.490149021148682,0.010446786880493164,1.509342908859253,0.01111912727355957,5.431332111358643
|
35 |
+
33,0.0411829948425293,0.7775781154632568,0.7725949287414551,5.5284202098846436,0.011181116104125977,1.4173851013183594,0.01881098747253418,5.2474939823150635
|
36 |
+
34,0.04268312454223633,1.3576858043670654,0.7971670627593994,5.488955974578857,0.016661882400512695,0.6669139862060547,0.011193990707397461,5.231971263885498
|
37 |
+
35,0.0432438850402832,0.49681520462036133,0.7736399173736572,4.675936698913574,0.013994932174682617,0.7481560707092285,0.01053619384765625,4.871787071228027
|
38 |
+
36,0.038790225982666016,1.9925789833068848,0.7900221347808838,4.716547012329102,0.010754108428955078,0.8104310035705566,0.011471986770629883,4.582187175750732
|
39 |
+
37,0.04674410820007324,0.8766942024230957,0.8192441463470459,5.454381704330444,0.012632131576538086,3.3098862171173096,0.01573491096496582,5.5617289543151855
|
40 |
+
38,0.04983806610107422,0.5784440040588379,0.768744945526123,5.399757146835327,0.017091035842895508,1.0388100147247314,0.020289897918701172,5.327627897262573
|
41 |
+
39,0.039936065673828125,0.9906370639801025,0.7951750755310059,4.816935062408447,0.009315729141235352,0.8949270248413086,0.012948989868164062,4.823601007461548
|
42 |
+
40,0.04812121391296387,5.3651018142700195,0.7833847999572754,4.673122882843018,0.010359048843383789,1.6986067295074463,0.012405872344970703,4.822720050811768
|
43 |
+
41,0.037177085876464844,0.8579537868499756,0.768902063369751,4.705405950546265,0.01087808609008789,1.1154420375823975,0.009827136993408203,5.295310020446777
|
44 |
+
42,0.03615593910217285,0.6045210361480713,0.7767770290374756,4.721595048904419,0.012170076370239258,1.168515920639038,0.014606952667236328,4.778914213180542
|
45 |
+
43,0.04032111167907715,1.0840678215026855,0.8039369583129883,5.5514678955078125,0.011640071868896484,3.7264089584350586,0.015080928802490234,6.431236028671265
|
46 |
+
44,0.12291288375854492,2.7860946655273438,0.7999370098114014,4.700652122497559,0.010669708251953125,3.5256330966949463,0.00997614860534668,5.010454893112183
|
47 |
+
45,0.03981208801269531,0.8575420379638672,0.7781379222869873,4.649600028991699,0.011057853698730469,3.4576022624969482,0.011123895645141602,5.414888143539429
|
48 |
+
46,0.046558380126953125,0.6096041202545166,0.839914083480835,5.3846352100372314,0.0264890193939209,3.282578945159912,0.013241052627563477,6.356001853942871
|
49 |
+
47,0.044730186462402344,0.6428439617156982,0.7774860858917236,5.471776962280273,0.009460926055908203,3.3428800106048584,0.012679100036621094,5.476663112640381
|
50 |
+
48,0.04798007011413574,1.3710291385650635,0.7838289737701416,5.5646140575408936,0.011425018310546875,1.5621020793914795,0.019647836685180664,5.403181076049805
|
51 |
+
49,0.06305599212646484,1.7375829219818115,0.7764248847961426,5.582126140594482,0.010413169860839844,1.6502351760864258,0.011098146438598633,6.15350604057312
|
52 |
+
50,0.04781007766723633,0.919248104095459,0.8292880058288574,4.79367995262146,0.01233983039855957,4.761476755142212,0.01306009292602539,4.901428937911987
|
53 |
+
51,0.04294776916503906,0.9060940742492676,0.7503399848937988,4.69527006149292,0.010550737380981445,1.3250057697296143,0.012276887893676758,4.790279388427734
|
54 |
+
52,0.0449681282043457,0.74688720703125,0.7592051029205322,4.672075986862183,0.010587215423583984,1.7173192501068115,0.012059926986694336,4.9025468826293945
|
55 |
+
53,0.04381895065307617,1.2078793048858643,0.8653321266174316,4.4878456592559814,0.008989810943603516,4.782422065734863,0.012001752853393555,5.331949949264526
|
56 |
+
54,0.06584286689758301,1.0724549293518066,0.7348787784576416,5.094892740249634,0.00992584228515625,1.6900959014892578,0.018785953521728516,6.253833293914795
|
57 |
+
55,0.05147194862365723,2.172264337539673,0.7367160320281982,5.056357145309448,0.009489059448242188,2.6061501502990723,0.011726140975952148,6.203222036361694
|
58 |
+
56,0.05095195770263672,2.0959391593933105,0.7292170524597168,5.11798882484436,0.011085987091064453,1.258976936340332,0.02020883560180664,6.149781942367554
|
59 |
+
57,0.05691885948181152,0.7286462783813477,0.7636628150939941,5.1169517040252686,0.01094675064086914,0.7379579544067383,0.012721061706542969,5.3568830490112305
|
60 |
+
58,0.04192709922790527,0.8154990673065186,0.7308712005615234,5.066887140274048,0.010490894317626953,2.724623203277588,0.01871800422668457,5.368160009384155
|
src/reader.py
DELETED
@@ -1,2 +0,0 @@
|
|
1 |
-
class Reader():
|
2 |
-
pass
|
|
|
|
|
|
src/readers/dpr_reader.py
CHANGED
@@ -1,9 +1,13 @@
|
|
1 |
from transformers import DPRReader, DPRReaderTokenizer
|
2 |
from typing import List, Dict, Tuple
|
|
|
3 |
|
4 |
from src.readers.base_reader import Reader
|
5 |
|
6 |
|
|
|
|
|
|
|
7 |
class DprReader(Reader):
|
8 |
def __init__(self) -> None:
|
9 |
self._tokenizer = DPRReaderTokenizer.from_pretrained(
|
|
|
1 |
from transformers import DPRReader, DPRReaderTokenizer
|
2 |
from typing import List, Dict, Tuple
|
3 |
+
from dotenv import load_dotenv
|
4 |
|
5 |
from src.readers.base_reader import Reader
|
6 |
|
7 |
|
8 |
+
load_dotenv()
|
9 |
+
|
10 |
+
|
11 |
class DprReader(Reader):
|
12 |
def __init__(self) -> None:
|
13 |
self._tokenizer = DPRReaderTokenizer.from_pretrained(
|
src/readers/longformer_reader.py
CHANGED
@@ -1,17 +1,21 @@
|
|
1 |
import torch
|
2 |
from transformers import (
|
3 |
-
|
4 |
LongformerForQuestionAnswering
|
5 |
)
|
6 |
from typing import List, Dict, Tuple
|
|
|
7 |
|
8 |
from src.readers.base_reader import Reader
|
9 |
|
10 |
|
|
|
|
|
|
|
11 |
class LongformerReader(Reader):
|
12 |
def __init__(self) -> None:
|
13 |
checkpoint = "valhalla/longformer-base-4096-finetuned-squadv1"
|
14 |
-
self.tokenizer =
|
15 |
self.model = LongformerForQuestionAnswering.from_pretrained(checkpoint)
|
16 |
|
17 |
def read(self,
|
@@ -21,8 +25,7 @@ class LongformerReader(Reader):
|
|
21 |
answers = []
|
22 |
|
23 |
for text in context['texts']:
|
24 |
-
encoding = self.tokenizer(
|
25 |
-
query, text, return_tensors="pt")
|
26 |
input_ids = encoding["input_ids"]
|
27 |
attention_mask = encoding["attention_mask"]
|
28 |
outputs = self.model(input_ids, attention_mask=attention_mask)
|
|
|
1 |
import torch
|
2 |
from transformers import (
|
3 |
+
LongformerTokenizer,
|
4 |
LongformerForQuestionAnswering
|
5 |
)
|
6 |
from typing import List, Dict, Tuple
|
7 |
+
from dotenv import load_dotenv
|
8 |
|
9 |
from src.readers.base_reader import Reader
|
10 |
|
11 |
|
12 |
+
load_dotenv()
|
13 |
+
|
14 |
+
|
15 |
class LongformerReader(Reader):
|
16 |
def __init__(self) -> None:
|
17 |
checkpoint = "valhalla/longformer-base-4096-finetuned-squadv1"
|
18 |
+
self.tokenizer = LongformerTokenizer.from_pretrained(checkpoint)
|
19 |
self.model = LongformerForQuestionAnswering.from_pretrained(checkpoint)
|
20 |
|
21 |
def read(self,
|
|
|
25 |
answers = []
|
26 |
|
27 |
for text in context['texts']:
|
28 |
+
encoding = self.tokenizer(query, text, return_tensors="pt")
|
|
|
29 |
input_ids = encoding["input_ids"]
|
30 |
attention_mask = encoding["attention_mask"]
|
31 |
outputs = self.model(input_ids, attention_mask=attention_mask)
|
src/retrievers/es_retriever.py
CHANGED
@@ -1,13 +1,17 @@
|
|
|
|
1 |
import os
|
2 |
|
3 |
from datasets import DatasetDict
|
4 |
from elasticsearch import Elasticsearch
|
|
|
|
|
5 |
|
6 |
from src.retrievers.base_retriever import RetrieveType, Retriever
|
7 |
-
from src.utils.log import
|
8 |
from src.utils.timing import timeit
|
9 |
|
10 |
-
|
|
|
11 |
|
12 |
|
13 |
class ESRetriever(Retriever):
|
@@ -23,6 +27,13 @@ class ESRetriever(Retriever):
|
|
23 |
http_auth=(es_username, es_password),
|
24 |
ca_certs="./http_ca.crt")
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
if self.client.indices.exists(index="paragraphs"):
|
27 |
self.paragraphs.load_elasticsearch_index(
|
28 |
"paragraphs", es_index_name="paragraphs",
|
@@ -34,6 +45,5 @@ class ESRetriever(Retriever):
|
|
34 |
es_index_name="paragraphs",
|
35 |
es_client=self.client)
|
36 |
|
37 |
-
@timeit("esretriever.retrieve")
|
38 |
def retrieve(self, query: str, k: int = 5) -> RetrieveType:
|
39 |
return self.paragraphs.get_nearest_examples("paragraphs", query, k)
|
|
|
1 |
+
import imp
|
2 |
import os
|
3 |
|
4 |
from datasets import DatasetDict
|
5 |
from elasticsearch import Elasticsearch
|
6 |
+
from elastic_transport import ConnectionError
|
7 |
+
from dotenv import load_dotenv
|
8 |
|
9 |
from src.retrievers.base_retriever import RetrieveType, Retriever
|
10 |
+
from src.utils.log import logger
|
11 |
from src.utils.timing import timeit
|
12 |
|
13 |
+
|
14 |
+
load_dotenv()
|
15 |
|
16 |
|
17 |
class ESRetriever(Retriever):
|
|
|
27 |
http_auth=(es_username, es_password),
|
28 |
ca_certs="./http_ca.crt")
|
29 |
|
30 |
+
try:
|
31 |
+
self.client.info()
|
32 |
+
except ConnectionError:
|
33 |
+
logger.error("Could not connect to ElasticSearch. " +
|
34 |
+
"Make sure it is running. Exiting now...")
|
35 |
+
exit()
|
36 |
+
|
37 |
if self.client.indices.exists(index="paragraphs"):
|
38 |
self.paragraphs.load_elasticsearch_index(
|
39 |
"paragraphs", es_index_name="paragraphs",
|
|
|
45 |
es_index_name="paragraphs",
|
46 |
es_client=self.client)
|
47 |
|
|
|
48 |
def retrieve(self, query: str, k: int = 5) -> RetrieveType:
|
49 |
return self.paragraphs.get_nearest_examples("paragraphs", query, k)
|
src/retrievers/faiss_retriever.py
CHANGED
@@ -2,6 +2,7 @@ import os
|
|
2 |
import os.path
|
3 |
import torch
|
4 |
|
|
|
5 |
from datasets import DatasetDict
|
6 |
from dataclasses import dataclass
|
7 |
from transformers import (
|
@@ -10,22 +11,18 @@ from transformers import (
|
|
10 |
DPRQuestionEncoder,
|
11 |
DPRQuestionEncoderTokenizerFast,
|
12 |
LongformerModel,
|
13 |
-
|
14 |
)
|
15 |
from transformers.modeling_utils import PreTrainedModel
|
16 |
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
17 |
|
18 |
from src.retrievers.base_retriever import RetrieveType, Retriever
|
19 |
-
from src.utils.log import
|
20 |
from src.utils.preprocessing import remove_formulas
|
21 |
from src.utils.timing import timeit
|
22 |
|
23 |
-
# Hacky fix for FAISS error on macOS
|
24 |
-
# See https://stackoverflow.com/a/63374568/4545692
|
25 |
-
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
|
26 |
|
27 |
-
|
28 |
-
logger = get_logger()
|
29 |
|
30 |
|
31 |
@dataclass
|
@@ -59,10 +56,10 @@ class FaissRetrieverOptions:
|
|
59 |
@staticmethod
|
60 |
def longformer(embedding_path: str):
|
61 |
encoder = LongformerModel.from_pretrained(
|
62 |
-
"
|
63 |
)
|
64 |
-
tokenizer =
|
65 |
-
"
|
66 |
)
|
67 |
return FaissRetrieverOptions(
|
68 |
ctx_encoder=encoder,
|
@@ -145,7 +142,6 @@ class FaissRetriever(Retriever):
|
|
145 |
|
146 |
return index
|
147 |
|
148 |
-
@timeit("faissretriever.retrieve")
|
149 |
def retrieve(self, query: str, k: int = 5) -> RetrieveType:
|
150 |
question_embedding = self._embed_question(query)
|
151 |
scores, results = self.index.get_nearest_examples(
|
|
|
2 |
import os.path
|
3 |
import torch
|
4 |
|
5 |
+
from dotenv import load_dotenv
|
6 |
from datasets import DatasetDict
|
7 |
from dataclasses import dataclass
|
8 |
from transformers import (
|
|
|
11 |
DPRQuestionEncoder,
|
12 |
DPRQuestionEncoderTokenizerFast,
|
13 |
LongformerModel,
|
14 |
+
LongformerTokenizer
|
15 |
)
|
16 |
from transformers.modeling_utils import PreTrainedModel
|
17 |
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
18 |
|
19 |
from src.retrievers.base_retriever import RetrieveType, Retriever
|
20 |
+
from src.utils.log import logger
|
21 |
from src.utils.preprocessing import remove_formulas
|
22 |
from src.utils.timing import timeit
|
23 |
|
|
|
|
|
|
|
24 |
|
25 |
+
load_dotenv()
|
|
|
26 |
|
27 |
|
28 |
@dataclass
|
|
|
56 |
@staticmethod
|
57 |
def longformer(embedding_path: str):
|
58 |
encoder = LongformerModel.from_pretrained(
|
59 |
+
"valhalla/longformer-base-4096-finetuned-squadv1"
|
60 |
)
|
61 |
+
tokenizer = LongformerTokenizer.from_pretrained(
|
62 |
+
"valhalla/longformer-base-4096-finetuned-squadv1"
|
63 |
)
|
64 |
return FaissRetrieverOptions(
|
65 |
ctx_encoder=encoder,
|
|
|
142 |
|
143 |
return index
|
144 |
|
|
|
145 |
def retrieve(self, query: str, k: int = 5) -> RetrieveType:
|
146 |
question_embedding = self._embed_question(query)
|
147 |
scores, results = self.index.get_nearest_examples(
|
src/utils/log.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import logging
|
2 |
import os
|
3 |
|
@@ -5,27 +6,27 @@ from dotenv import load_dotenv
|
|
5 |
|
6 |
load_dotenv()
|
7 |
|
|
|
|
|
|
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
logger = logging.getLogger("Flashcards")
|
12 |
|
13 |
-
|
14 |
-
|
|
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
fh.setFormatter(formatter)
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
return logger
|
|
|
1 |
+
import coloredlogs
|
2 |
import logging
|
3 |
import os
|
4 |
|
|
|
6 |
|
7 |
load_dotenv()
|
8 |
|
9 |
+
# creates a default logger for the project. We declare it in the global scope
|
10 |
+
# so it acts like a singleton
|
11 |
+
logger = logging.getLogger("Flashcards")
|
12 |
|
13 |
+
log_level = os.getenv("LOG_LEVEL", "INFO")
|
14 |
+
logger.setLevel(log_level)
|
|
|
15 |
|
16 |
+
# Log format
|
17 |
+
formatter = coloredlogs.ColoredFormatter(
|
18 |
+
"%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
19 |
|
20 |
+
# stout
|
21 |
+
ch = logging.StreamHandler()
|
22 |
+
ch.setFormatter(formatter)
|
23 |
|
24 |
+
# colored output so log messages stand out more
|
25 |
+
# coloredlogs.install(level=log_level, logger=logger)
|
|
|
26 |
|
27 |
+
# file handler
|
28 |
+
fh = logging.FileHandler("logs.log")
|
29 |
+
fh.setFormatter(formatter)
|
30 |
|
31 |
+
logger.addHandler(fh)
|
32 |
+
logger.addHandler(ch)
|
|
|
|
src/utils/timing.py
CHANGED
@@ -1,11 +1,8 @@
|
|
1 |
import time
|
2 |
-
from typing import Dict
|
3 |
from dotenv import load_dotenv
|
4 |
import os
|
5 |
-
from src.utils.log import
|
6 |
-
|
7 |
-
|
8 |
-
logger = get_logger()
|
9 |
|
10 |
|
11 |
load_dotenv()
|
@@ -17,7 +14,7 @@ if ENABLE_TIMING:
|
|
17 |
logger.info("Timing is enabled")
|
18 |
|
19 |
|
20 |
-
TimingType = Dict[str, float]
|
21 |
|
22 |
TIMES: TimingType = {}
|
23 |
|
|
|
1 |
import time
|
2 |
+
from typing import Dict, List
|
3 |
from dotenv import load_dotenv
|
4 |
import os
|
5 |
+
from src.utils.log import logger
|
|
|
|
|
|
|
6 |
|
7 |
|
8 |
load_dotenv()
|
|
|
14 |
logger.info("Timing is enabled")
|
15 |
|
16 |
|
17 |
+
TimingType = Dict[str, List[float]]
|
18 |
|
19 |
TIMES: TimingType = {}
|
20 |
|