Ramon Meffert commited on
Commit
0157dfd
·
1 Parent(s): be1f224

Fix timings and add timing results

Browse files
.env.example CHANGED
@@ -3,6 +3,6 @@ ELASTIC_PASSWORD=<password>
3
  ELASTIC_HOST=https://localhost:9200
4
 
5
  LOG_LEVEL=INFO
6
- TRANSFORMERS_NO_ADVISORY_WARNINGS
7
-
8
  ENABLE_TIMING=TRUE
 
3
  ELASTIC_HOST=https://localhost:9200
4
 
5
  LOG_LEVEL=INFO
6
+ TRANSFORMERS_NO_ADVISORY_WARNINGS=true
7
+ KMP_DUPLICATE_LIB_OK=true
8
  ENABLE_TIMING=TRUE
main.py CHANGED
@@ -1,25 +1,34 @@
1
- import random
2
- from typing import Dict, cast
 
 
 
 
3
 
4
- import torch
5
- import transformers
6
  from datasets import DatasetDict, load_dataset
7
- from dotenv import load_dotenv
8
- from query import print_answers
9
 
 
10
  from src.evaluation import evaluate
11
  from src.readers.dpr_reader import DprReader
 
12
  from src.retrievers.base_retriever import Retriever
13
  from src.retrievers.es_retriever import ESRetriever
14
- from src.retrievers.faiss_retriever import FaissRetriever
15
- from src.utils.log import get_logger
 
 
 
16
  from src.utils.preprocessing import context_to_reader_input
17
  from src.utils.timing import get_times, timeit
18
 
19
- logger = get_logger()
20
 
21
- load_dotenv()
22
- transformers.logging.set_verbosity_error()
 
 
 
23
 
24
  if __name__ == '__main__':
25
  dataset_name = "GroNLP/ik-nlp-22_slp"
@@ -28,41 +37,72 @@ if __name__ == '__main__':
28
  questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))
29
 
30
  # Only doing a few questions for speed
31
- subset_idx = 3
32
  questions_test = questions["test"][:subset_idx]
33
 
34
- experiments: Dict[str, Retriever] = {
35
- "faiss": FaissRetriever(paragraphs),
36
- # "es": ESRetriever(paragraphs),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  }
38
 
39
- for experiment_name, retriever in experiments.items():
40
- reader = DprReader()
41
-
42
  for idx in range(subset_idx):
43
  question = questions_test["question"][idx]
44
  answer = questions_test["answer"][idx]
45
 
46
- scores, context = retriever.retrieve(question, 5)
 
 
 
 
 
 
 
 
47
  reader_input = context_to_reader_input(context)
48
 
49
- # workaround so we can use the decorator with a dynamic name for time recording
50
- time_wrapper = timeit(f"{experiment_name}.read")
51
- answers = time_wrapper(reader.read)(question, reader_input, 5)
52
 
53
  # Calculate softmaxed scores for readable output
54
- sm = torch.nn.Softmax(dim=0)
55
- document_scores = sm(torch.Tensor(
56
- [pred.relevance_score for pred in answers]))
57
- span_scores = sm(torch.Tensor(
58
- [pred.span_score for pred in answers]))
59
 
60
- print_answers(answers, scores, context)
61
 
62
  # TODO evaluation and storing of results
 
63
 
64
  times = get_times()
65
- print(times)
 
 
 
 
 
66
  # TODO evaluation and storing of results
67
 
68
  # # Initialize retriever
 
1
+ from dotenv import load_dotenv
2
+ # needs to happen as very first thing, otherwise HF ignores env vars
3
+ load_dotenv()
4
+
5
+ import os
6
+ import pandas as pd
7
 
8
+ from dataclasses import dataclass
9
+ from typing import Dict, cast
10
  from datasets import DatasetDict, load_dataset
 
 
11
 
12
+ from src.readers.base_reader import Reader
13
  from src.evaluation import evaluate
14
  from src.readers.dpr_reader import DprReader
15
+ from src.readers.longformer_reader import LongformerReader
16
  from src.retrievers.base_retriever import Retriever
17
  from src.retrievers.es_retriever import ESRetriever
18
+ from src.retrievers.faiss_retriever import (
19
+ FaissRetriever,
20
+ FaissRetrieverOptions
21
+ )
22
+ from src.utils.log import logger
23
  from src.utils.preprocessing import context_to_reader_input
24
  from src.utils.timing import get_times, timeit
25
 
 
26
 
27
+ @dataclass
28
+ class Experiment:
29
+ retriever: Retriever
30
+ reader: Reader
31
+
32
 
33
  if __name__ == '__main__':
34
  dataset_name = "GroNLP/ik-nlp-22_slp"
 
37
  questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))
38
 
39
  # Only doing a few questions for speed
40
+ subset_idx = len(questions["test"])
41
  questions_test = questions["test"][:subset_idx]
42
 
43
+ experiments: Dict[str, Experiment] = {
44
+ "faiss_dpr": Experiment(
45
+ retriever=FaissRetriever(
46
+ paragraphs,
47
+ FaissRetrieverOptions.dpr("./src/models/dpr.faiss")),
48
+ reader=DprReader()
49
+ ),
50
+ "faiss_longformer": Experiment(
51
+ retriever=FaissRetriever(
52
+ paragraphs,
53
+ FaissRetrieverOptions.longformer("./src/models/longformer.faiss")),
54
+ reader=LongformerReader()
55
+ ),
56
+ "es_dpr": Experiment(
57
+ retriever=ESRetriever(paragraphs),
58
+ reader=DprReader()
59
+ ),
60
+ "es_longformer": Experiment(
61
+ retriever=ESRetriever(paragraphs),
62
+ reader=LongformerReader()
63
+ ),
64
  }
65
 
66
+ for experiment_name, experiment in experiments.items():
67
+ logger.info(f"Running experiment {experiment_name}...")
 
68
  for idx in range(subset_idx):
69
  question = questions_test["question"][idx]
70
  answer = questions_test["answer"][idx]
71
 
72
+ retrieve_timer = timeit(f"{experiment_name}.retrieve")
73
+ t_retrieve = retrieve_timer(experiment.retriever.retrieve)
74
+
75
+ read_timer = timeit(f"{experiment_name}.read")
76
+ t_read = read_timer(experiment.reader.read)
77
+
78
+ print(f"\x1b[1K\r[{idx+1:03}] - \"{question}\"", end='')
79
+
80
+ scores, context = t_retrieve(question, 5)
81
  reader_input = context_to_reader_input(context)
82
 
83
+ # workaround so we can use the decorator with a dynamic name for
84
+ # time recording
85
+ answers = t_read(question, reader_input, 5)
86
 
87
  # Calculate softmaxed scores for readable output
88
+ # sm = torch.nn.Softmax(dim=0)
89
+ # document_scores = sm(torch.Tensor(
90
+ # [pred.relevance_score for pred in answers]))
91
+ # span_scores = sm(torch.Tensor(
92
+ # [pred.span_score for pred in answers]))
93
 
94
+ # print_answers(answers, scores, context)
95
 
96
  # TODO evaluation and storing of results
97
+ print()
98
 
99
  times = get_times()
100
+
101
+ df = pd.DataFrame(times)
102
+ os.makedirs("./results/", exist_ok=True)
103
+ df.to_csv("./results/timings.csv")
104
+
105
+
106
  # TODO evaluation and storing of results
107
 
108
  # # Initialize retriever
poetry.lock CHANGED
@@ -212,6 +212,20 @@ category = "main"
212
  optional = false
213
  python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  [[package]]
216
  name = "cryptography"
217
  version = "36.0.2"
@@ -266,13 +280,13 @@ xxhash = "*"
266
  apache-beam = ["apache-beam (>=2.26.0)"]
267
  audio = ["librosa"]
268
  benchmarks = ["numpy (==1.18.5)", "tensorflow (==2.3.0)", "torch (==1.6.0)", "transformers (==3.0.2)"]
269
- dev = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore", "boto3", "botocore", "faiss-cpu (>=1.6.4)", "fsspec", "moto[s3,server] (==2.0.4)", "rarfile (>=4.0)", "s3fs (==2021.08.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "conllu", "h5py", "langdetect", "lxml", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "bert-score (>=0.3.6)", "rouge-score", "sacrebleu", "scipy", "seqeval", "scikit-learn", "jiwer", "sentencepiece", "torchmetrics (==0.6.0)", "mauve-text", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "wget (>=3.2)", "pytorch-nlp (==0.5.0)", "pytorch-lightning", "fastBPE (==0.1.0)", "fairseq", "black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)", "importlib-resources"]
270
  docs = ["docutils (==0.16.0)", "recommonmark", "sphinx (==3.1.2)", "sphinx-markdown-tables", "sphinx-rtd-theme (==0.4.3)", "sphinxext-opengraph (==0.4.1)", "sphinx-copybutton", "fsspec (<2021.9.0)", "s3fs", "sphinx-panels", "sphinx-inline-tabs", "myst-parser", "Markdown (!=3.3.5)"]
271
  quality = ["black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)"]
272
  s3 = ["fsspec", "boto3", "botocore", "s3fs"]
273
  tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)"]
274
  tensorflow_gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"]
275
- tests = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore", "boto3", "botocore", "faiss-cpu (>=1.6.4)", "fsspec", "moto[s3,server] (==2.0.4)", "rarfile (>=4.0)", "s3fs (==2021.08.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "conllu", "h5py", "langdetect", "lxml", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "bert-score (>=0.3.6)", "rouge-score", "sacrebleu", "scipy", "seqeval", "scikit-learn", "jiwer", "sentencepiece", "torchmetrics (==0.6.0)", "mauve-text", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "wget (>=3.2)", "pytorch-nlp (==0.5.0)", "pytorch-lightning", "fastBPE (==0.1.0)", "fairseq", "importlib-resources"]
276
  torch = ["torch"]
277
  vision = ["Pillow (>=6.2.1)"]
278
 
@@ -439,7 +453,7 @@ python-versions = ">=3.7"
439
 
440
  [[package]]
441
  name = "fsspec"
442
- version = "2022.2.0"
443
  description = "File-system specification"
444
  category = "main"
445
  optional = false
@@ -470,10 +484,11 @@ s3 = ["s3fs"]
470
  sftp = ["paramiko"]
471
  smb = ["smbprotocol"]
472
  ssh = ["paramiko"]
 
473
 
474
  [[package]]
475
  name = "gradio"
476
- version = "2.9.0"
477
  description = "Python library for easily interacting with trained machine learning models"
478
  category = "main"
479
  optional = false
@@ -529,6 +544,17 @@ all = ["pytest", "datasets", "black (>=20.8b1)", "isort (>=5.5.4)", "flake8 (>=3
529
  dev = ["pytest", "datasets", "black (>=20.8b1)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"]
530
  quality = ["black (>=20.8b1)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"]
531
 
 
 
 
 
 
 
 
 
 
 
 
532
  [[package]]
533
  name = "idna"
534
  version = "3.3"
@@ -1099,6 +1125,14 @@ python-versions = ">=3.6"
1099
  [package.extras]
1100
  diagrams = ["jinja2", "railroad-diagrams"]
1101
 
 
 
 
 
 
 
 
 
1102
  [[package]]
1103
  name = "python-dateutil"
1104
  version = "2.8.2"
@@ -1498,7 +1532,7 @@ multidict = ">=4.0"
1498
  [metadata]
1499
  lock-version = "1.1"
1500
  python-versions = "^3.8"
1501
- content-hash = "a9ce48f30c8568321f3f4576e1c4987ef94a4216201ba4bce2dc719c397d5da6"
1502
 
1503
  [metadata.files]
1504
  aiohttp = [
@@ -1699,6 +1733,10 @@ colorama = [
1699
  {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"},
1700
  {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"},
1701
  ]
 
 
 
 
1702
  cryptography = [
1703
  {file = "cryptography-36.0.2-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:4e2dddd38a5ba733be6a025a1475a9f45e4e41139d1321f412c6b360b19070b6"},
1704
  {file = "cryptography-36.0.2-cp36-abi3-macosx_10_10_x86_64.whl", hash = "sha256:4881d09298cd0b669bb15b9cfe6166f16fc1277b4ed0d04a22f3d6430cb30f1d"},
@@ -1880,12 +1918,12 @@ frozenlist = [
1880
  {file = "frozenlist-1.3.0.tar.gz", hash = "sha256:ce6f2ba0edb7b0c1d8976565298ad2deba6f8064d2bebb6ffce2ca896eb35b0b"},
1881
  ]
1882
  fsspec = [
1883
- {file = "fsspec-2022.2.0-py3-none-any.whl", hash = "sha256:eb9c9d9aee49d23028deefffe53e87c55d3515512c63f57e893710301001449a"},
1884
- {file = "fsspec-2022.2.0.tar.gz", hash = "sha256:20322c659538501f52f6caa73b08b2ff570b7e8ea30a86559721d090e473ad5c"},
1885
  ]
1886
  gradio = [
1887
- {file = "gradio-2.9.0-py3-none-any.whl", hash = "sha256:02c3604d8c662dc35a60e75f55c3de175f8e2c30bf868c39e82f8c20a608d80b"},
1888
- {file = "gradio-2.9.0.tar.gz", hash = "sha256:2cfbde23425c97959291d88ceae55e3d83e1a32915a0e9f7032c8c81bd4f5b63"},
1889
  ]
1890
  h11 = [
1891
  {file = "h11-0.13.0-py3-none-any.whl", hash = "sha256:8ddd78563b633ca55346c8cd41ec0af27d3c79931828beffb46ce70a379e7442"},
@@ -1895,6 +1933,10 @@ huggingface-hub = [
1895
  {file = "huggingface_hub-0.4.0-py3-none-any.whl", hash = "sha256:808021af1ce1111104973ae54d81738eaf40be6d1e82fc6bdedb82f81c6206e7"},
1896
  {file = "huggingface_hub-0.4.0.tar.gz", hash = "sha256:f0e3389f8988eb7781b17de520ae7fd0aa50d9823534e3ae55344d943a88ac87"},
1897
  ]
 
 
 
 
1898
  idna = [
1899
  {file = "idna-3.3-py3-none-any.whl", hash = "sha256:84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff"},
1900
  {file = "idna-3.3.tar.gz", hash = "sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d"},
@@ -2193,6 +2235,7 @@ numpy = [
2193
  {file = "numpy-1.22.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8251ed96f38b47b4295b1ae51631de7ffa8260b5b087808ef09a39a9d66c97ab"},
2194
  {file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48a3aecd3b997bf452a2dedb11f4e79bc5bfd21a1d4cc760e703c31d57c84b3e"},
2195
  {file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3bae1a2ed00e90b3ba5f7bd0a7c7999b55d609e0c54ceb2b076a25e345fa9f4"},
 
2196
  {file = "numpy-1.22.3-cp310-cp310-win_amd64.whl", hash = "sha256:08d9b008d0156c70dc392bb3ab3abb6e7a711383c3247b410b39962263576cd4"},
2197
  {file = "numpy-1.22.3-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:201b4d0552831f7250a08d3b38de0d989d6f6e4658b709a02a73c524ccc6ffce"},
2198
  {file = "numpy-1.22.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f8c1f39caad2c896bc0018f699882b345b2a63708008be29b1f355ebf6f933fe"},
@@ -2510,6 +2553,10 @@ pyparsing = [
2510
  {file = "pyparsing-3.0.7-py3-none-any.whl", hash = "sha256:a6c06a88f252e6c322f65faf8f418b16213b51bdfaece0524c1c1bc30c63c484"},
2511
  {file = "pyparsing-3.0.7.tar.gz", hash = "sha256:18ee9022775d270c55187733956460083db60b37d0d0fb357445f3094eed3eea"},
2512
  ]
 
 
 
 
2513
  python-dateutil = [
2514
  {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
2515
  {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
 
212
  optional = false
213
  python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
214
 
215
+ [[package]]
216
+ name = "coloredlogs"
217
+ version = "15.0.1"
218
+ description = "Colored terminal output for Python's logging module"
219
+ category = "main"
220
+ optional = false
221
+ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
222
+
223
+ [package.dependencies]
224
+ humanfriendly = ">=9.1"
225
+
226
+ [package.extras]
227
+ cron = ["capturer (>=2.4)"]
228
+
229
  [[package]]
230
  name = "cryptography"
231
  version = "36.0.2"
 
280
  apache-beam = ["apache-beam (>=2.26.0)"]
281
  audio = ["librosa"]
282
  benchmarks = ["numpy (==1.18.5)", "tensorflow (==2.3.0)", "torch (==1.6.0)", "transformers (==3.0.2)"]
283
+ dev = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore", "boto3", "botocore", "faiss-cpu (>=1.6.4)", "fsspec", "moto[server,s3] (==2.0.4)", "rarfile (>=4.0)", "s3fs (==2021.08.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "conllu", "h5py", "langdetect", "lxml", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "bert-score (>=0.3.6)", "rouge-score", "sacrebleu", "scipy", "seqeval", "scikit-learn", "jiwer", "sentencepiece", "torchmetrics (==0.6.0)", "mauve-text", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "wget (>=3.2)", "pytorch-nlp (==0.5.0)", "pytorch-lightning", "fastBPE (==0.1.0)", "fairseq", "black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)", "importlib-resources"]
284
  docs = ["docutils (==0.16.0)", "recommonmark", "sphinx (==3.1.2)", "sphinx-markdown-tables", "sphinx-rtd-theme (==0.4.3)", "sphinxext-opengraph (==0.4.1)", "sphinx-copybutton", "fsspec (<2021.9.0)", "s3fs", "sphinx-panels", "sphinx-inline-tabs", "myst-parser", "Markdown (!=3.3.5)"]
285
  quality = ["black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)"]
286
  s3 = ["fsspec", "boto3", "botocore", "s3fs"]
287
  tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)"]
288
  tensorflow_gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"]
289
+ tests = ["absl-py", "pytest", "pytest-datadir", "pytest-xdist", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "aiobotocore", "boto3", "botocore", "faiss-cpu (>=1.6.4)", "fsspec", "moto[server,s3] (==2.0.4)", "rarfile (>=4.0)", "s3fs (==2021.08.1)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1)", "torch", "torchaudio", "soundfile", "transformers", "bs4", "conllu", "h5py", "langdetect", "lxml", "mwparserfromhell", "nltk", "openpyxl", "py7zr", "tldextract", "zstandard", "bert-score (>=0.3.6)", "rouge-score", "sacrebleu", "scipy", "seqeval", "scikit-learn", "jiwer", "sentencepiece", "torchmetrics (==0.6.0)", "mauve-text", "toml (>=0.10.1)", "requests-file (>=1.5.1)", "tldextract (>=3.1.0)", "texttable (>=1.6.3)", "Werkzeug (>=1.0.1)", "six (>=1.15.0,<1.16.0)", "Pillow (>=6.2.1)", "librosa", "wget (>=3.2)", "pytorch-nlp (==0.5.0)", "pytorch-lightning", "fastBPE (==0.1.0)", "fairseq", "importlib-resources"]
290
  torch = ["torch"]
291
  vision = ["Pillow (>=6.2.1)"]
292
 
 
453
 
454
  [[package]]
455
  name = "fsspec"
456
+ version = "2022.3.0"
457
  description = "File-system specification"
458
  category = "main"
459
  optional = false
 
484
  sftp = ["paramiko"]
485
  smb = ["smbprotocol"]
486
  ssh = ["paramiko"]
487
+ tqdm = ["tqdm"]
488
 
489
  [[package]]
490
  name = "gradio"
491
+ version = "2.9.1"
492
  description = "Python library for easily interacting with trained machine learning models"
493
  category = "main"
494
  optional = false
 
544
  dev = ["pytest", "datasets", "black (>=20.8b1)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"]
545
  quality = ["black (>=20.8b1)", "isort (>=5.5.4)", "flake8 (>=3.8.3)"]
546
 
547
+ [[package]]
548
+ name = "humanfriendly"
549
+ version = "10.0"
550
+ description = "Human friendly output for text interfaces using Python"
551
+ category = "main"
552
+ optional = false
553
+ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
554
+
555
+ [package.dependencies]
556
+ pyreadline3 = {version = "*", markers = "sys_platform == \"win32\" and python_version >= \"3.8\""}
557
+
558
  [[package]]
559
  name = "idna"
560
  version = "3.3"
 
1125
  [package.extras]
1126
  diagrams = ["jinja2", "railroad-diagrams"]
1127
 
1128
+ [[package]]
1129
+ name = "pyreadline3"
1130
+ version = "3.4.1"
1131
+ description = "A python implementation of GNU readline."
1132
+ category = "main"
1133
+ optional = false
1134
+ python-versions = "*"
1135
+
1136
  [[package]]
1137
  name = "python-dateutil"
1138
  version = "2.8.2"
 
1532
  [metadata]
1533
  lock-version = "1.1"
1534
  python-versions = "^3.8"
1535
+ content-hash = "881ba67f914b3c0690bcb34810061252ee77cebc0dac49b5ae76348394d810a8"
1536
 
1537
  [metadata.files]
1538
  aiohttp = [
 
1733
  {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"},
1734
  {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"},
1735
  ]
1736
+ coloredlogs = [
1737
+ {file = "coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934"},
1738
+ {file = "coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0"},
1739
+ ]
1740
  cryptography = [
1741
  {file = "cryptography-36.0.2-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:4e2dddd38a5ba733be6a025a1475a9f45e4e41139d1321f412c6b360b19070b6"},
1742
  {file = "cryptography-36.0.2-cp36-abi3-macosx_10_10_x86_64.whl", hash = "sha256:4881d09298cd0b669bb15b9cfe6166f16fc1277b4ed0d04a22f3d6430cb30f1d"},
 
1918
  {file = "frozenlist-1.3.0.tar.gz", hash = "sha256:ce6f2ba0edb7b0c1d8976565298ad2deba6f8064d2bebb6ffce2ca896eb35b0b"},
1919
  ]
1920
  fsspec = [
1921
+ {file = "fsspec-2022.3.0-py3-none-any.whl", hash = "sha256:a53491b003210fce6911dd8f2d37e20c41a27ce52a655eef11b885d1578ed4cf"},
1922
+ {file = "fsspec-2022.3.0.tar.gz", hash = "sha256:fd582cc4aa0db5968bad9317cae513450eddd08b2193c4428d9349265a995523"},
1923
  ]
1924
  gradio = [
1925
+ {file = "gradio-2.9.1-py3-none-any.whl", hash = "sha256:877616dcda82e0e13bc04404c13f084c7b3a06cccc314a4db06b21c5f15f6190"},
1926
+ {file = "gradio-2.9.1.tar.gz", hash = "sha256:d9dfde81f064f38bcd95967316501ab40698fec0bcc4435dd00ea4578f695042"},
1927
  ]
1928
  h11 = [
1929
  {file = "h11-0.13.0-py3-none-any.whl", hash = "sha256:8ddd78563b633ca55346c8cd41ec0af27d3c79931828beffb46ce70a379e7442"},
 
1933
  {file = "huggingface_hub-0.4.0-py3-none-any.whl", hash = "sha256:808021af1ce1111104973ae54d81738eaf40be6d1e82fc6bdedb82f81c6206e7"},
1934
  {file = "huggingface_hub-0.4.0.tar.gz", hash = "sha256:f0e3389f8988eb7781b17de520ae7fd0aa50d9823534e3ae55344d943a88ac87"},
1935
  ]
1936
+ humanfriendly = [
1937
+ {file = "humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477"},
1938
+ {file = "humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc"},
1939
+ ]
1940
  idna = [
1941
  {file = "idna-3.3-py3-none-any.whl", hash = "sha256:84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff"},
1942
  {file = "idna-3.3.tar.gz", hash = "sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d"},
 
2235
  {file = "numpy-1.22.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8251ed96f38b47b4295b1ae51631de7ffa8260b5b087808ef09a39a9d66c97ab"},
2236
  {file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48a3aecd3b997bf452a2dedb11f4e79bc5bfd21a1d4cc760e703c31d57c84b3e"},
2237
  {file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3bae1a2ed00e90b3ba5f7bd0a7c7999b55d609e0c54ceb2b076a25e345fa9f4"},
2238
+ {file = "numpy-1.22.3-cp310-cp310-win32.whl", hash = "sha256:f950f8845b480cffe522913d35567e29dd381b0dc7e4ce6a4a9f9156417d2430"},
2239
  {file = "numpy-1.22.3-cp310-cp310-win_amd64.whl", hash = "sha256:08d9b008d0156c70dc392bb3ab3abb6e7a711383c3247b410b39962263576cd4"},
2240
  {file = "numpy-1.22.3-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:201b4d0552831f7250a08d3b38de0d989d6f6e4658b709a02a73c524ccc6ffce"},
2241
  {file = "numpy-1.22.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f8c1f39caad2c896bc0018f699882b345b2a63708008be29b1f355ebf6f933fe"},
 
2553
  {file = "pyparsing-3.0.7-py3-none-any.whl", hash = "sha256:a6c06a88f252e6c322f65faf8f418b16213b51bdfaece0524c1c1bc30c63c484"},
2554
  {file = "pyparsing-3.0.7.tar.gz", hash = "sha256:18ee9022775d270c55187733956460083db60b37d0d0fb357445f3094eed3eea"},
2555
  ]
2556
+ pyreadline3 = [
2557
+ {file = "pyreadline3-3.4.1-py3-none-any.whl", hash = "sha256:b0efb6516fd4fb07b45949053826a62fa4cb353db5be2bbb4a7aa1fdd1e345fb"},
2558
+ {file = "pyreadline3-3.4.1.tar.gz", hash = "sha256:6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae"},
2559
+ ]
2560
  python-dateutil = [
2561
  {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
2562
  {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
pyproject.toml CHANGED
@@ -15,6 +15,7 @@ python-dotenv = "^0.19.2"
15
  elasticsearch = "^8.1.0"
16
  gradio = {extras = ["Jinja2"], version = "^2.9.0"}
17
  Jinja2 = "^3.1.1"
 
18
 
19
  [tool.poetry.dev-dependencies]
20
  flake8 = "^4.0.1"
 
15
  elasticsearch = "^8.1.0"
16
  gradio = {extras = ["Jinja2"], version = "^2.9.0"}
17
  Jinja2 = "^3.1.1"
18
+ coloredlogs = "^15.0.1"
19
 
20
  [tool.poetry.dev-dependencies]
21
  flake8 = "^4.0.1"
query.py CHANGED
@@ -16,7 +16,12 @@ from src.retrievers.faiss_retriever import (
16
  FaissRetrieverOptions
17
  )
18
  from src.utils.preprocessing import context_to_reader_input
19
- from src.utils.log import get_logger
 
 
 
 
 
20
 
21
 
22
  def get_retriever(paragraphs: DatasetDict,
@@ -123,11 +128,6 @@ def main(args: argparse.Namespace):
123
 
124
 
125
  if __name__ == "__main__":
126
- # Setup environment
127
- load_dotenv()
128
- logger = get_logger()
129
- transformers.logging.set_verbosity_error()
130
-
131
  # Set up CLI arguments
132
  parser = argparse.ArgumentParser(
133
  formatter_class=argparse.MetavarTypeHelpFormatter
 
16
  FaissRetrieverOptions
17
  )
18
  from src.utils.preprocessing import context_to_reader_input
19
+ from src.utils.log import logger
20
+
21
+
22
+ # Setup environment
23
+ load_dotenv()
24
+ transformers.logging.set_verbosity_error()
25
 
26
 
27
  def get_retriever(paragraphs: DatasetDict,
 
128
 
129
 
130
  if __name__ == "__main__":
 
 
 
 
 
131
  # Set up CLI arguments
132
  parser = argparse.ArgumentParser(
133
  formatter_class=argparse.MetavarTypeHelpFormatter
results/timings.csv ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ,faiss_dpr.retrieve,faiss_dpr.read,faiss_longformer.retrieve,faiss_longformer.read,es_dpr.retrieve,es_dpr.read,es_longformer.retrieve,es_longformer.read
2
+ 0,0.30384302139282227,4.566400051116943,0.9227948188781738,5.768368244171143,0.01930093765258789,2.7453649044036865,0.010576009750366211,4.998417854309082
3
+ 1,0.04573678970336914,1.9288370609283447,0.8380529880523682,5.916611671447754,0.018373966217041016,1.4845240116119385,0.012102842330932617,5.1692070960998535
4
+ 2,0.04764819145202637,0.6628780364990234,0.7756149768829346,5.4998250007629395,0.015324831008911133,1.7706871032714844,0.012642860412597656,5.202448844909668
5
+ 3,0.04507589340209961,1.219634771347046,0.8142738342285156,5.726102113723755,0.021118879318237305,1.987663984298706,0.012515068054199219,5.1083409786224365
6
+ 4,0.04347515106201172,1.5222840309143066,0.814906120300293,5.672412872314453,0.013732194900512695,1.660247802734375,0.011805057525634766,5.313212156295776
7
+ 5,0.07470989227294922,1.5599188804626465,0.8422539234161377,5.75390100479126,0.018023014068603516,1.5782928466796875,0.013046741485595703,5.419210195541382
8
+ 6,0.06162095069885254,1.4178202152252197,0.7837569713592529,4.765166282653809,0.014074325561523438,0.7626080513000488,0.010712146759033203,4.976129055023193
9
+ 7,0.0451970100402832,1.134779691696167,0.7723889350891113,4.5784592628479,0.012156963348388672,1.959972858428955,0.011015892028808594,4.5342161655426025
10
+ 8,0.03589582443237305,0.8912148475646973,0.8142461776733398,5.212156295776367,0.009800195693969727,1.820624828338623,0.009167194366455078,4.468229293823242
11
+ 9,0.06033587455749512,0.37888431549072266,1.1510162353515625,5.395290851593018,0.015815258026123047,1.0247371196746826,0.010970830917358398,5.1864588260650635
12
+ 10,0.056854963302612305,0.6068317890167236,0.7839999198913574,4.668170928955078,0.013895988464355469,0.8482949733734131,0.011837005615234375,4.493913173675537
13
+ 11,0.04697012901306152,0.341174840927124,0.8089311122894287,5.535298109054565,0.01624298095703125,1.4673452377319336,0.01172780990600586,5.264245986938477
14
+ 12,0.0444178581237793,1.4774150848388672,0.7612121105194092,5.504917860031128,0.01690196990966797,2.42773699760437,0.011591196060180664,4.463428974151611
15
+ 13,0.06343889236450195,0.8000409603118896,0.9072589874267578,6.015661954879761,0.010895252227783203,2.21577787399292,0.012058019638061523,5.302258253097534
16
+ 14,0.05022692680358887,0.9474368095397949,0.8324599266052246,6.4684131145477295,0.016222000122070312,3.1302390098571777,0.010799169540405273,5.240310907363892
17
+ 15,0.08313488960266113,0.746314287185669,0.8373219966888428,6.741006851196289,0.017467975616455078,1.353593111038208,0.020203113555908203,5.192620038986206
18
+ 16,0.17216706275939941,1.136448860168457,0.7760000228881836,5.48329496383667,0.018398046493530273,0.5325403213500977,0.012536287307739258,5.264941215515137
19
+ 17,0.04575324058532715,0.5853927135467529,0.7855441570281982,5.960904121398926,0.019134998321533203,2.6092309951782227,0.012146949768066406,5.331835031509399
20
+ 18,0.04746294021606445,0.9219169616699219,0.9516820907592773,10.146074295043945,0.011114835739135742,2.220487117767334,0.011821985244750977,5.161020994186401
21
+ 19,0.0443730354309082,0.4667840003967285,1.3496372699737549,8.213943719863892,0.01042485237121582,2.841907024383545,0.011536121368408203,5.236280918121338
22
+ 20,0.06004190444946289,0.6129250526428223,1.3677341938018799,7.00742769241333,0.022186994552612305,1.6846930980682373,0.010824918746948242,5.377984046936035
23
+ 21,0.06920814514160156,0.6232960224151611,1.4656860828399658,6.424375057220459,0.011613845825195312,1.0811800956726074,0.014858007431030273,5.279160022735596
24
+ 22,0.04999184608459473,0.6539132595062256,0.8720510005950928,5.889069080352783,0.016654014587402344,1.6599159240722656,0.012172698974609375,5.177525043487549
25
+ 23,0.05750322341918945,1.0169367790222168,0.9728169441223145,6.934185028076172,0.01772904396057129,1.2837882041931152,0.011108160018920898,5.186945199966431
26
+ 24,0.06264281272888184,1.7151312828063965,1.3927390575408936,7.122100114822388,0.016143798828125,1.5387201309204102,0.011415958404541016,4.558846950531006
27
+ 25,0.04831504821777344,0.7839398384094238,1.1007087230682373,5.4652369022369385,0.01099395751953125,1.5678913593292236,0.011976242065429688,4.612828969955444
28
+ 26,0.048091888427734375,0.9228200912475586,0.8567941188812256,4.832158803939819,0.013817787170410156,2.0290918350219727,0.015846967697143555,4.845104217529297
29
+ 27,0.04568672180175781,0.8964569568634033,0.7873432636260986,4.592561960220337,0.010241985321044922,0.3145887851715088,0.014873743057250977,4.759660720825195
30
+ 28,0.04340720176696777,0.5004391670227051,0.8122010231018066,4.68702507019043,0.012717008590698242,0.9207170009613037,0.014780759811401367,4.955734968185425
31
+ 29,0.045496225357055664,2.106112003326416,0.7901277542114258,5.48145604133606,0.009778976440429688,1.1795310974121094,0.011364936828613281,5.397800922393799
32
+ 30,0.05589914321899414,3.3801350593566895,0.7913417816162109,4.76953387260437,0.01169896125793457,2.8297739028930664,0.012899160385131836,4.7149817943573
33
+ 31,0.038469791412353516,1.2037632465362549,0.812114953994751,4.819751977920532,0.010591983795166016,1.0633080005645752,0.011631011962890625,4.603592157363892
34
+ 32,0.043640851974487305,0.7455379962921143,0.7684001922607422,5.490149021148682,0.010446786880493164,1.509342908859253,0.01111912727355957,5.431332111358643
35
+ 33,0.0411829948425293,0.7775781154632568,0.7725949287414551,5.5284202098846436,0.011181116104125977,1.4173851013183594,0.01881098747253418,5.2474939823150635
36
+ 34,0.04268312454223633,1.3576858043670654,0.7971670627593994,5.488955974578857,0.016661882400512695,0.6669139862060547,0.011193990707397461,5.231971263885498
37
+ 35,0.0432438850402832,0.49681520462036133,0.7736399173736572,4.675936698913574,0.013994932174682617,0.7481560707092285,0.01053619384765625,4.871787071228027
38
+ 36,0.038790225982666016,1.9925789833068848,0.7900221347808838,4.716547012329102,0.010754108428955078,0.8104310035705566,0.011471986770629883,4.582187175750732
39
+ 37,0.04674410820007324,0.8766942024230957,0.8192441463470459,5.454381704330444,0.012632131576538086,3.3098862171173096,0.01573491096496582,5.5617289543151855
40
+ 38,0.04983806610107422,0.5784440040588379,0.768744945526123,5.399757146835327,0.017091035842895508,1.0388100147247314,0.020289897918701172,5.327627897262573
41
+ 39,0.039936065673828125,0.9906370639801025,0.7951750755310059,4.816935062408447,0.009315729141235352,0.8949270248413086,0.012948989868164062,4.823601007461548
42
+ 40,0.04812121391296387,5.3651018142700195,0.7833847999572754,4.673122882843018,0.010359048843383789,1.6986067295074463,0.012405872344970703,4.822720050811768
43
+ 41,0.037177085876464844,0.8579537868499756,0.768902063369751,4.705405950546265,0.01087808609008789,1.1154420375823975,0.009827136993408203,5.295310020446777
44
+ 42,0.03615593910217285,0.6045210361480713,0.7767770290374756,4.721595048904419,0.012170076370239258,1.168515920639038,0.014606952667236328,4.778914213180542
45
+ 43,0.04032111167907715,1.0840678215026855,0.8039369583129883,5.5514678955078125,0.011640071868896484,3.7264089584350586,0.015080928802490234,6.431236028671265
46
+ 44,0.12291288375854492,2.7860946655273438,0.7999370098114014,4.700652122497559,0.010669708251953125,3.5256330966949463,0.00997614860534668,5.010454893112183
47
+ 45,0.03981208801269531,0.8575420379638672,0.7781379222869873,4.649600028991699,0.011057853698730469,3.4576022624969482,0.011123895645141602,5.414888143539429
48
+ 46,0.046558380126953125,0.6096041202545166,0.839914083480835,5.3846352100372314,0.0264890193939209,3.282578945159912,0.013241052627563477,6.356001853942871
49
+ 47,0.044730186462402344,0.6428439617156982,0.7774860858917236,5.471776962280273,0.009460926055908203,3.3428800106048584,0.012679100036621094,5.476663112640381
50
+ 48,0.04798007011413574,1.3710291385650635,0.7838289737701416,5.5646140575408936,0.011425018310546875,1.5621020793914795,0.019647836685180664,5.403181076049805
51
+ 49,0.06305599212646484,1.7375829219818115,0.7764248847961426,5.582126140594482,0.010413169860839844,1.6502351760864258,0.011098146438598633,6.15350604057312
52
+ 50,0.04781007766723633,0.919248104095459,0.8292880058288574,4.79367995262146,0.01233983039855957,4.761476755142212,0.01306009292602539,4.901428937911987
53
+ 51,0.04294776916503906,0.9060940742492676,0.7503399848937988,4.69527006149292,0.010550737380981445,1.3250057697296143,0.012276887893676758,4.790279388427734
54
+ 52,0.0449681282043457,0.74688720703125,0.7592051029205322,4.672075986862183,0.010587215423583984,1.7173192501068115,0.012059926986694336,4.9025468826293945
55
+ 53,0.04381895065307617,1.2078793048858643,0.8653321266174316,4.4878456592559814,0.008989810943603516,4.782422065734863,0.012001752853393555,5.331949949264526
56
+ 54,0.06584286689758301,1.0724549293518066,0.7348787784576416,5.094892740249634,0.00992584228515625,1.6900959014892578,0.018785953521728516,6.253833293914795
57
+ 55,0.05147194862365723,2.172264337539673,0.7367160320281982,5.056357145309448,0.009489059448242188,2.6061501502990723,0.011726140975952148,6.203222036361694
58
+ 56,0.05095195770263672,2.0959391593933105,0.7292170524597168,5.11798882484436,0.011085987091064453,1.258976936340332,0.02020883560180664,6.149781942367554
59
+ 57,0.05691885948181152,0.7286462783813477,0.7636628150939941,5.1169517040252686,0.01094675064086914,0.7379579544067383,0.012721061706542969,5.3568830490112305
60
+ 58,0.04192709922790527,0.8154990673065186,0.7308712005615234,5.066887140274048,0.010490894317626953,2.724623203277588,0.01871800422668457,5.368160009384155
src/reader.py DELETED
@@ -1,2 +0,0 @@
1
- class Reader():
2
- pass
 
 
 
src/readers/dpr_reader.py CHANGED
@@ -1,9 +1,13 @@
1
  from transformers import DPRReader, DPRReaderTokenizer
2
  from typing import List, Dict, Tuple
 
3
 
4
  from src.readers.base_reader import Reader
5
 
6
 
 
 
 
7
  class DprReader(Reader):
8
  def __init__(self) -> None:
9
  self._tokenizer = DPRReaderTokenizer.from_pretrained(
 
1
  from transformers import DPRReader, DPRReaderTokenizer
2
  from typing import List, Dict, Tuple
3
+ from dotenv import load_dotenv
4
 
5
  from src.readers.base_reader import Reader
6
 
7
 
8
+ load_dotenv()
9
+
10
+
11
  class DprReader(Reader):
12
  def __init__(self) -> None:
13
  self._tokenizer = DPRReaderTokenizer.from_pretrained(
src/readers/longformer_reader.py CHANGED
@@ -1,17 +1,21 @@
1
  import torch
2
  from transformers import (
3
- LongformerTokenizerFast,
4
  LongformerForQuestionAnswering
5
  )
6
  from typing import List, Dict, Tuple
 
7
 
8
  from src.readers.base_reader import Reader
9
 
10
 
 
 
 
11
  class LongformerReader(Reader):
12
  def __init__(self) -> None:
13
  checkpoint = "valhalla/longformer-base-4096-finetuned-squadv1"
14
- self.tokenizer = LongformerTokenizerFast.from_pretrained(checkpoint)
15
  self.model = LongformerForQuestionAnswering.from_pretrained(checkpoint)
16
 
17
  def read(self,
@@ -21,8 +25,7 @@ class LongformerReader(Reader):
21
  answers = []
22
 
23
  for text in context['texts']:
24
- encoding = self.tokenizer(
25
- query, text, return_tensors="pt")
26
  input_ids = encoding["input_ids"]
27
  attention_mask = encoding["attention_mask"]
28
  outputs = self.model(input_ids, attention_mask=attention_mask)
 
1
  import torch
2
  from transformers import (
3
+ LongformerTokenizer,
4
  LongformerForQuestionAnswering
5
  )
6
  from typing import List, Dict, Tuple
7
+ from dotenv import load_dotenv
8
 
9
  from src.readers.base_reader import Reader
10
 
11
 
12
+ load_dotenv()
13
+
14
+
15
  class LongformerReader(Reader):
16
  def __init__(self) -> None:
17
  checkpoint = "valhalla/longformer-base-4096-finetuned-squadv1"
18
+ self.tokenizer = LongformerTokenizer.from_pretrained(checkpoint)
19
  self.model = LongformerForQuestionAnswering.from_pretrained(checkpoint)
20
 
21
  def read(self,
 
25
  answers = []
26
 
27
  for text in context['texts']:
28
+ encoding = self.tokenizer(query, text, return_tensors="pt")
 
29
  input_ids = encoding["input_ids"]
30
  attention_mask = encoding["attention_mask"]
31
  outputs = self.model(input_ids, attention_mask=attention_mask)
src/retrievers/es_retriever.py CHANGED
@@ -1,13 +1,17 @@
 
1
  import os
2
 
3
  from datasets import DatasetDict
4
  from elasticsearch import Elasticsearch
 
 
5
 
6
  from src.retrievers.base_retriever import RetrieveType, Retriever
7
- from src.utils.log import get_logger
8
  from src.utils.timing import timeit
9
 
10
- logger = get_logger()
 
11
 
12
 
13
  class ESRetriever(Retriever):
@@ -23,6 +27,13 @@ class ESRetriever(Retriever):
23
  http_auth=(es_username, es_password),
24
  ca_certs="./http_ca.crt")
25
 
 
 
 
 
 
 
 
26
  if self.client.indices.exists(index="paragraphs"):
27
  self.paragraphs.load_elasticsearch_index(
28
  "paragraphs", es_index_name="paragraphs",
@@ -34,6 +45,5 @@ class ESRetriever(Retriever):
34
  es_index_name="paragraphs",
35
  es_client=self.client)
36
 
37
- @timeit("esretriever.retrieve")
38
  def retrieve(self, query: str, k: int = 5) -> RetrieveType:
39
  return self.paragraphs.get_nearest_examples("paragraphs", query, k)
 
1
+ import imp
2
  import os
3
 
4
  from datasets import DatasetDict
5
  from elasticsearch import Elasticsearch
6
+ from elastic_transport import ConnectionError
7
+ from dotenv import load_dotenv
8
 
9
  from src.retrievers.base_retriever import RetrieveType, Retriever
10
+ from src.utils.log import logger
11
  from src.utils.timing import timeit
12
 
13
+
14
+ load_dotenv()
15
 
16
 
17
  class ESRetriever(Retriever):
 
27
  http_auth=(es_username, es_password),
28
  ca_certs="./http_ca.crt")
29
 
30
+ try:
31
+ self.client.info()
32
+ except ConnectionError:
33
+ logger.error("Could not connect to ElasticSearch. " +
34
+ "Make sure it is running. Exiting now...")
35
+ exit()
36
+
37
  if self.client.indices.exists(index="paragraphs"):
38
  self.paragraphs.load_elasticsearch_index(
39
  "paragraphs", es_index_name="paragraphs",
 
45
  es_index_name="paragraphs",
46
  es_client=self.client)
47
 
 
48
  def retrieve(self, query: str, k: int = 5) -> RetrieveType:
49
  return self.paragraphs.get_nearest_examples("paragraphs", query, k)
src/retrievers/faiss_retriever.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import os.path
3
  import torch
4
 
 
5
  from datasets import DatasetDict
6
  from dataclasses import dataclass
7
  from transformers import (
@@ -10,22 +11,18 @@ from transformers import (
10
  DPRQuestionEncoder,
11
  DPRQuestionEncoderTokenizerFast,
12
  LongformerModel,
13
- LongformerTokenizerFast
14
  )
15
  from transformers.modeling_utils import PreTrainedModel
16
  from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
17
 
18
  from src.retrievers.base_retriever import RetrieveType, Retriever
19
- from src.utils.log import get_logger
20
  from src.utils.preprocessing import remove_formulas
21
  from src.utils.timing import timeit
22
 
23
- # Hacky fix for FAISS error on macOS
24
- # See https://stackoverflow.com/a/63374568/4545692
25
- os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
26
 
27
-
28
- logger = get_logger()
29
 
30
 
31
  @dataclass
@@ -59,10 +56,10 @@ class FaissRetrieverOptions:
59
  @staticmethod
60
  def longformer(embedding_path: str):
61
  encoder = LongformerModel.from_pretrained(
62
- "allenai/longformer-base-4096"
63
  )
64
- tokenizer = LongformerTokenizerFast.from_pretrained(
65
- "allenai/longformer-base-4096"
66
  )
67
  return FaissRetrieverOptions(
68
  ctx_encoder=encoder,
@@ -145,7 +142,6 @@ class FaissRetriever(Retriever):
145
 
146
  return index
147
 
148
- @timeit("faissretriever.retrieve")
149
  def retrieve(self, query: str, k: int = 5) -> RetrieveType:
150
  question_embedding = self._embed_question(query)
151
  scores, results = self.index.get_nearest_examples(
 
2
  import os.path
3
  import torch
4
 
5
+ from dotenv import load_dotenv
6
  from datasets import DatasetDict
7
  from dataclasses import dataclass
8
  from transformers import (
 
11
  DPRQuestionEncoder,
12
  DPRQuestionEncoderTokenizerFast,
13
  LongformerModel,
14
+ LongformerTokenizer
15
  )
16
  from transformers.modeling_utils import PreTrainedModel
17
  from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
18
 
19
  from src.retrievers.base_retriever import RetrieveType, Retriever
20
+ from src.utils.log import logger
21
  from src.utils.preprocessing import remove_formulas
22
  from src.utils.timing import timeit
23
 
 
 
 
24
 
25
+ load_dotenv()
 
26
 
27
 
28
  @dataclass
 
56
  @staticmethod
57
  def longformer(embedding_path: str):
58
  encoder = LongformerModel.from_pretrained(
59
+ "valhalla/longformer-base-4096-finetuned-squadv1"
60
  )
61
+ tokenizer = LongformerTokenizer.from_pretrained(
62
+ "valhalla/longformer-base-4096-finetuned-squadv1"
63
  )
64
  return FaissRetrieverOptions(
65
  ctx_encoder=encoder,
 
142
 
143
  return index
144
 
 
145
  def retrieve(self, query: str, k: int = 5) -> RetrieveType:
146
  question_embedding = self._embed_question(query)
147
  scores, results = self.index.get_nearest_examples(
src/utils/log.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import logging
2
  import os
3
 
@@ -5,27 +6,27 @@ from dotenv import load_dotenv
5
 
6
  load_dotenv()
7
 
 
 
 
8
 
9
- def get_logger():
10
- # creates a default logger for the project
11
- logger = logging.getLogger("Flashcards")
12
 
13
- log_level = os.getenv("LOG_LEVEL", "INFO")
14
- logger.setLevel(log_level)
 
15
 
16
- # Log format
17
- formatter = logging.Formatter(
18
- "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
19
 
20
- # file handler
21
- fh = logging.FileHandler("logs.log")
22
- fh.setFormatter(formatter)
23
 
24
- # stout
25
- ch = logging.StreamHandler()
26
- ch.setFormatter(formatter)
27
 
28
- logger.addHandler(fh)
29
- logger.addHandler(ch)
30
-
31
- return logger
 
1
+ import coloredlogs
2
  import logging
3
  import os
4
 
 
6
 
7
  load_dotenv()
8
 
9
+ # creates a default logger for the project. We declare it in the global scope
10
+ # so it acts like a singleton
11
+ logger = logging.getLogger("Flashcards")
12
 
13
+ log_level = os.getenv("LOG_LEVEL", "INFO")
14
+ logger.setLevel(log_level)
 
15
 
16
+ # Log format
17
+ formatter = coloredlogs.ColoredFormatter(
18
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
19
 
20
+ # stout
21
+ ch = logging.StreamHandler()
22
+ ch.setFormatter(formatter)
23
 
24
+ # colored output so log messages stand out more
25
+ # coloredlogs.install(level=log_level, logger=logger)
 
26
 
27
+ # file handler
28
+ fh = logging.FileHandler("logs.log")
29
+ fh.setFormatter(formatter)
30
 
31
+ logger.addHandler(fh)
32
+ logger.addHandler(ch)
 
 
src/utils/timing.py CHANGED
@@ -1,11 +1,8 @@
1
  import time
2
- from typing import Dict
3
  from dotenv import load_dotenv
4
  import os
5
- from src.utils.log import get_logger
6
-
7
-
8
- logger = get_logger()
9
 
10
 
11
  load_dotenv()
@@ -17,7 +14,7 @@ if ENABLE_TIMING:
17
  logger.info("Timing is enabled")
18
 
19
 
20
- TimingType = Dict[str, float]
21
 
22
  TIMES: TimingType = {}
23
 
 
1
  import time
2
+ from typing import Dict, List
3
  from dotenv import load_dotenv
4
  import os
5
+ from src.utils.log import logger
 
 
 
6
 
7
 
8
  load_dotenv()
 
14
  logger.info("Timing is enabled")
15
 
16
 
17
+ TimingType = Dict[str, List[float]]
18
 
19
  TIMES: TimingType = {}
20