GGroenendaal committed on
Commit
51dabd6
·
1 Parent(s): aa426fb

refactoring of code

Browse files
.env.example ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ELASTIC_USERNAME=elastic
2
+ ELASTIC_PASSWORD=<password>
3
+
4
+ LOG_LEVEL=INFO
README.md CHANGED
@@ -73,3 +73,6 @@ poetry run python main.py
73
  > shows that MT systems perform worse when they are asked to translate sentences
74
  > that describe people with non-stereotypical gender roles, like "The doctor
75
  > asked the nurse to help her in the operation".
 
 
 
 
73
  > shows that MT systems perform worse when they are asked to translate sentences
74
  > that describe people with non-stereotypical gender roles, like "The doctor
75
  > asked the nurse to help her in the operation".
76
+
77
+
78
+ ## Setting up Elasticsearch
base_model/main.py → main.py RENAMED
@@ -1,20 +1,23 @@
1
- from retriever import Retriever
 
 
 
 
2
 
3
 
4
  if __name__ == '__main__':
5
  # Initialize retriever
6
- r = Retriever()
7
 
8
  # Retrieve example
9
  scores, result = r.retrieve(
10
  "What is the perplexity of a language model?")
11
 
12
  for i, score in enumerate(scores):
13
- print(f"Result {i+1} (score: {score:.02f}):")
14
- print(result['text'][i])
15
- print() # Newline
16
 
17
  # Compute overall performance
18
  exact_match, f1_score = r.evaluate()
19
- print(f"Exact match: {exact_match:.02f}\n"
20
- f"F1-score: {f1_score:.02f}")
 
1
+ from src.fais_retriever import FAISRetriever
2
+ from src.utils.log import get_logger
3
+
4
+
5
+ logger = get_logger()
6
 
7
 
8
  if __name__ == '__main__':
9
  # Initialize retriever
10
+ r = FAISRetriever()
11
 
12
  # Retrieve example
13
  scores, result = r.retrieve(
14
  "What is the perplexity of a language model?")
15
 
16
  for i, score in enumerate(scores):
17
+ logger.info(f"Result {i+1} (score: {score:.02f}):")
18
+ logger.info(result['text'][i])
 
19
 
20
  # Compute overall performance
21
  exact_match, f1_score = r.evaluate()
22
+ logger.info(f"Exact match: {exact_match:.02f}\n"
23
+ f"F1-score: {f1_score:.02f}")
poetry.lock CHANGED
@@ -149,6 +149,36 @@ python-versions = ">=2.7, !=3.0.*"
149
  [package.extras]
150
  graph = ["objgraph (>=1.7.2)"]
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  [[package]]
153
  name = "faiss-cpu"
154
  version = "1.7.2"
@@ -291,6 +321,32 @@ python-versions = "*"
291
  [package.dependencies]
292
  dill = ">=0.3.4"
293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  [[package]]
295
  name = "numpy"
296
  version = "1.22.3"
@@ -380,6 +436,17 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
380
  [package.dependencies]
381
  six = ">=1.5"
382
 
 
 
 
 
 
 
 
 
 
 
 
383
  [[package]]
384
  name = "pytz"
385
  version = "2021.3"
@@ -480,6 +547,14 @@ category = "dev"
480
  optional = false
481
  python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
482
 
 
 
 
 
 
 
 
 
483
  [[package]]
484
  name = "torch"
485
  version = "1.11.0"
@@ -610,7 +685,7 @@ multidict = ">=4.0"
610
  [metadata]
611
  lock-version = "1.1"
612
  python-versions = "^3.8"
613
- content-hash = "227b922ee14abf36ca75bb238d239d712bed9213d54c567996566d465e465733"
614
 
615
  [metadata.files]
616
  aiohttp = [
@@ -727,6 +802,14 @@ dill = [
727
  {file = "dill-0.3.4-py2.py3-none-any.whl", hash = "sha256:7e40e4a70304fd9ceab3535d36e58791d9c4a776b38ec7f7ec9afc8d3dca4d4f"},
728
  {file = "dill-0.3.4.zip", hash = "sha256:9f9734205146b2b353ab3fec9af0070237b6ddae78452af83d2fca84d739e675"},
729
  ]
 
 
 
 
 
 
 
 
730
  faiss-cpu = [
731
  {file = "faiss-cpu-1.7.2.tar.gz", hash = "sha256:f7ea89de997f55764e3710afaf0a457b2529252f99ee63510d4d9348d5b419dd"},
732
  {file = "faiss_cpu-1.7.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b7461f989d757917a3e6dc81eb171d0b563eb98d23ebaf7fc6684d0093ba267e"},
@@ -918,12 +1001,40 @@ multiprocess = [
918
  {file = "multiprocess-0.70.12.2-py39-none-any.whl", hash = "sha256:6f812a1d3f198b7cacd63983f60e2dc1338bd4450893f90c435067b5a3127e6f"},
919
  {file = "multiprocess-0.70.12.2.zip", hash = "sha256:206bb9b97b73f87fec1ed15a19f8762950256aa84225450abc7150d02855a083"},
920
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
921
  numpy = [
922
  {file = "numpy-1.22.3-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:92bfa69cfbdf7dfc3040978ad09a48091143cffb778ec3b03fa170c494118d75"},
923
  {file = "numpy-1.22.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8251ed96f38b47b4295b1ae51631de7ffa8260b5b087808ef09a39a9d66c97ab"},
924
  {file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48a3aecd3b997bf452a2dedb11f4e79bc5bfd21a1d4cc760e703c31d57c84b3e"},
925
  {file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3bae1a2ed00e90b3ba5f7bd0a7c7999b55d609e0c54ceb2b076a25e345fa9f4"},
926
- {file = "numpy-1.22.3-cp310-cp310-win32.whl", hash = "sha256:f950f8845b480cffe522913d35567e29dd381b0dc7e4ce6a4a9f9156417d2430"},
927
  {file = "numpy-1.22.3-cp310-cp310-win_amd64.whl", hash = "sha256:08d9b008d0156c70dc392bb3ab3abb6e7a711383c3247b410b39962263576cd4"},
928
  {file = "numpy-1.22.3-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:201b4d0552831f7250a08d3b38de0d989d6f6e4658b709a02a73c524ccc6ffce"},
929
  {file = "numpy-1.22.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f8c1f39caad2c896bc0018f699882b345b2a63708008be29b1f355ebf6f933fe"},
@@ -1015,6 +1126,10 @@ python-dateutil = [
1015
  {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
1016
  {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
1017
  ]
 
 
 
 
1018
  pytz = [
1019
  {file = "pytz-2021.3-py2.py3-none-any.whl", hash = "sha256:3672058bc3453457b622aab7a1c3bfd5ab0bdae451512f6cf25f64ed37f5b87c"},
1020
  {file = "pytz-2021.3.tar.gz", hash = "sha256:acad2d8b20a1af07d4e4c9d2e9285c5ed9104354062f275f3fcd88dcef4f1326"},
@@ -1189,6 +1304,10 @@ toml = [
1189
  {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
1190
  {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
1191
  ]
 
 
 
 
1192
  torch = [
1193
  {file = "torch-1.11.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:62052b50fffc29ca7afc0c04ef8206b6f1ca9d10629cb543077e12967e8d0398"},
1194
  {file = "torch-1.11.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:866bfba29ac98dec35d893d8e17eaec149d0ac7a53be7baae5c98069897db667"},
 
149
  [package.extras]
150
  graph = ["objgraph (>=1.7.2)"]
151
 
152
+ [[package]]
153
+ name = "elastic-transport"
154
+ version = "8.1.0"
155
+ description = "Transport classes and utilities shared among Python Elastic client libraries"
156
+ category = "main"
157
+ optional = false
158
+ python-versions = ">=3.6"
159
+
160
+ [package.dependencies]
161
+ certifi = "*"
162
+ urllib3 = ">=1.26.2,<2"
163
+
164
+ [package.extras]
165
+ develop = ["pytest", "pytest-cov", "pytest-mock", "pytest-asyncio", "mock", "requests", "aiohttp"]
166
+
167
+ [[package]]
168
+ name = "elasticsearch"
169
+ version = "8.1.0"
170
+ description = "Python client for Elasticsearch"
171
+ category = "main"
172
+ optional = false
173
+ python-versions = ">=3.6, <4"
174
+
175
+ [package.dependencies]
176
+ elastic-transport = ">=8,<9"
177
+
178
+ [package.extras]
179
+ async = ["aiohttp (>=3,<4)"]
180
+ requests = ["requests (>=2.4.0,<3.0.0)"]
181
+
182
  [[package]]
183
  name = "faiss-cpu"
184
  version = "1.7.2"
 
321
  [package.dependencies]
322
  dill = ">=0.3.4"
323
 
324
+ [[package]]
325
+ name = "mypy"
326
+ version = "0.941"
327
+ description = "Optional static typing for Python"
328
+ category = "dev"
329
+ optional = false
330
+ python-versions = ">=3.6"
331
+
332
+ [package.dependencies]
333
+ mypy-extensions = ">=0.4.3"
334
+ tomli = ">=1.1.0"
335
+ typing-extensions = ">=3.10"
336
+
337
+ [package.extras]
338
+ dmypy = ["psutil (>=4.0)"]
339
+ python2 = ["typed-ast (>=1.4.0,<2)"]
340
+ reports = ["lxml"]
341
+
342
+ [[package]]
343
+ name = "mypy-extensions"
344
+ version = "0.4.3"
345
+ description = "Experimental type system extensions for programs checked with the mypy typechecker."
346
+ category = "dev"
347
+ optional = false
348
+ python-versions = "*"
349
+
350
  [[package]]
351
  name = "numpy"
352
  version = "1.22.3"
 
436
  [package.dependencies]
437
  six = ">=1.5"
438
 
439
+ [[package]]
440
+ name = "python-dotenv"
441
+ version = "0.19.2"
442
+ description = "Read key-value pairs from a .env file and set them as environment variables"
443
+ category = "main"
444
+ optional = false
445
+ python-versions = ">=3.5"
446
+
447
+ [package.extras]
448
+ cli = ["click (>=5.0)"]
449
+
450
  [[package]]
451
  name = "pytz"
452
  version = "2021.3"
 
547
  optional = false
548
  python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
549
 
550
+ [[package]]
551
+ name = "tomli"
552
+ version = "2.0.1"
553
+ description = "A lil' TOML parser"
554
+ category = "dev"
555
+ optional = false
556
+ python-versions = ">=3.7"
557
+
558
  [[package]]
559
  name = "torch"
560
  version = "1.11.0"
 
685
  [metadata]
686
  lock-version = "1.1"
687
  python-versions = "^3.8"
688
+ content-hash = "7fadbb5aabac268ecd27c257e2c8f651d26896e78c9cc0ea7e61a8b6ec61c84c"
689
 
690
  [metadata.files]
691
  aiohttp = [
 
802
  {file = "dill-0.3.4-py2.py3-none-any.whl", hash = "sha256:7e40e4a70304fd9ceab3535d36e58791d9c4a776b38ec7f7ec9afc8d3dca4d4f"},
803
  {file = "dill-0.3.4.zip", hash = "sha256:9f9734205146b2b353ab3fec9af0070237b6ddae78452af83d2fca84d739e675"},
804
  ]
805
+ elastic-transport = [
806
+ {file = "elastic-transport-8.1.0.tar.gz", hash = "sha256:769ee4c7b28d270cdbce71359973b88129ac312b13be95b4f7479e35c49d9455"},
807
+ {file = "elastic_transport-8.1.0-py3-none-any.whl", hash = "sha256:0bb2ae3d13348e9e4587ca1f17cd813a528a7cc07f879505f56d69c81823b660"},
808
+ ]
809
+ elasticsearch = [
810
+ {file = "elasticsearch-8.1.0-py3-none-any.whl", hash = "sha256:11e36565dfdf649b7911c2d3cb1f15b99267acfb7f82e94e7613c0323a9936e9"},
811
+ {file = "elasticsearch-8.1.0.tar.gz", hash = "sha256:648d1c707a632279535356d2762cbc63ae728c4633211fe160f43f87a3e1cdcd"},
812
+ ]
813
  faiss-cpu = [
814
  {file = "faiss-cpu-1.7.2.tar.gz", hash = "sha256:f7ea89de997f55764e3710afaf0a457b2529252f99ee63510d4d9348d5b419dd"},
815
  {file = "faiss_cpu-1.7.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b7461f989d757917a3e6dc81eb171d0b563eb98d23ebaf7fc6684d0093ba267e"},
 
1001
  {file = "multiprocess-0.70.12.2-py39-none-any.whl", hash = "sha256:6f812a1d3f198b7cacd63983f60e2dc1338bd4450893f90c435067b5a3127e6f"},
1002
  {file = "multiprocess-0.70.12.2.zip", hash = "sha256:206bb9b97b73f87fec1ed15a19f8762950256aa84225450abc7150d02855a083"},
1003
  ]
1004
+ mypy = [
1005
+ {file = "mypy-0.941-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:98f61aad0bb54f797b17da5b82f419e6ce214de0aa7e92211ebee9e40eb04276"},
1006
+ {file = "mypy-0.941-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6a8e1f63357851444940351e98fb3252956a15f2cabe3d698316d7a2d1f1f896"},
1007
+ {file = "mypy-0.941-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b30d29251dff4c59b2e5a1fa1bab91ff3e117b4658cb90f76d97702b7a2ae699"},
1008
+ {file = "mypy-0.941-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8eaf55fdf99242a1c8c792247c455565447353914023878beadb79600aac4a2a"},
1009
+ {file = "mypy-0.941-cp310-cp310-win_amd64.whl", hash = "sha256:080097eee5393fd740f32c63f9343580aaa0fb1cda0128fd859dfcf081321c3d"},
1010
+ {file = "mypy-0.941-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:f79137d012ff3227866222049af534f25354c07a0d6b9a171dba9f1d6a1fdef4"},
1011
+ {file = "mypy-0.941-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8e5974583a77d630a5868eee18f85ac3093caf76e018c510aeb802b9973304ce"},
1012
+ {file = "mypy-0.941-cp36-cp36m-win_amd64.whl", hash = "sha256:0dd441fbacf48e19dc0c5c42fafa72b8e1a0ba0a39309c1af9c84b9397d9b15a"},
1013
+ {file = "mypy-0.941-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:0d3bcbe146247997e03bf030122000998b076b3ac6925b0b6563f46d1ce39b50"},
1014
+ {file = "mypy-0.941-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:3bada0cf7b6965627954b3a128903a87cac79a79ccd83b6104912e723ef16c7b"},
1015
+ {file = "mypy-0.941-cp37-cp37m-win_amd64.whl", hash = "sha256:eea10982b798ff0ccc3b9e7e42628f932f552c5845066970e67cd6858655d52c"},
1016
+ {file = "mypy-0.941-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:108f3c7e14a038cf097d2444fa0155462362c6316e3ecb2d70f6dd99cd36084d"},
1017
+ {file = "mypy-0.941-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d61b73c01fc1de799226963f2639af831307fe1556b04b7c25e2b6c267a3bc76"},
1018
+ {file = "mypy-0.941-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:42c216a33d2bdba08098acaf5bae65b0c8196afeb535ef4b870919a788a27259"},
1019
+ {file = "mypy-0.941-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:fc5ecff5a3bbfbe20091b1cad82815507f5ae9c380a3a9bf40f740c70ce30a9b"},
1020
+ {file = "mypy-0.941-cp38-cp38-win_amd64.whl", hash = "sha256:bf446223b2e0e4f0a4792938e8d885e8a896834aded5f51be5c3c69566495540"},
1021
+ {file = "mypy-0.941-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:745071762f32f65e77de6df699366d707fad6c132a660d1342077cbf671ef589"},
1022
+ {file = "mypy-0.941-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:465a6ce9ca6268cadfbc27a2a94ddf0412568a6b27640ced229270be4f5d394d"},
1023
+ {file = "mypy-0.941-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d051ce0946521eba48e19b25f27f98e5ce4dbc91fff296de76240c46b4464df0"},
1024
+ {file = "mypy-0.941-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:818cfc51c25a5dbfd0705f3ac1919fff6971eb0c02e6f1a1f6a017a42405a7c0"},
1025
+ {file = "mypy-0.941-cp39-cp39-win_amd64.whl", hash = "sha256:b2ce2788df0c066c2ff4ba7190fa84f18937527c477247e926abeb9b1168b8cc"},
1026
+ {file = "mypy-0.941-py3-none-any.whl", hash = "sha256:3cf77f138efb31727ee7197bc824c9d6d7039204ed96756cc0f9ca7d8e8fc2a4"},
1027
+ {file = "mypy-0.941.tar.gz", hash = "sha256:cbcc691d8b507d54cb2b8521f0a2a3d4daa477f62fe77f0abba41e5febb377b7"},
1028
+ ]
1029
+ mypy-extensions = [
1030
+ {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"},
1031
+ {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"},
1032
+ ]
1033
  numpy = [
1034
  {file = "numpy-1.22.3-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:92bfa69cfbdf7dfc3040978ad09a48091143cffb778ec3b03fa170c494118d75"},
1035
  {file = "numpy-1.22.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8251ed96f38b47b4295b1ae51631de7ffa8260b5b087808ef09a39a9d66c97ab"},
1036
  {file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48a3aecd3b997bf452a2dedb11f4e79bc5bfd21a1d4cc760e703c31d57c84b3e"},
1037
  {file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3bae1a2ed00e90b3ba5f7bd0a7c7999b55d609e0c54ceb2b076a25e345fa9f4"},
 
1038
  {file = "numpy-1.22.3-cp310-cp310-win_amd64.whl", hash = "sha256:08d9b008d0156c70dc392bb3ab3abb6e7a711383c3247b410b39962263576cd4"},
1039
  {file = "numpy-1.22.3-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:201b4d0552831f7250a08d3b38de0d989d6f6e4658b709a02a73c524ccc6ffce"},
1040
  {file = "numpy-1.22.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f8c1f39caad2c896bc0018f699882b345b2a63708008be29b1f355ebf6f933fe"},
 
1126
  {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
1127
  {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
1128
  ]
1129
+ python-dotenv = [
1130
+ {file = "python-dotenv-0.19.2.tar.gz", hash = "sha256:a5de49a31e953b45ff2d2fd434bbc2670e8db5273606c1e737cc6b93eff3655f"},
1131
+ {file = "python_dotenv-0.19.2-py2.py3-none-any.whl", hash = "sha256:32b2bdc1873fd3a3c346da1c6db83d0053c3c62f28f1f38516070c4c8971b1d3"},
1132
+ ]
1133
  pytz = [
1134
  {file = "pytz-2021.3-py2.py3-none-any.whl", hash = "sha256:3672058bc3453457b622aab7a1c3bfd5ab0bdae451512f6cf25f64ed37f5b87c"},
1135
  {file = "pytz-2021.3.tar.gz", hash = "sha256:acad2d8b20a1af07d4e4c9d2e9285c5ed9104354062f275f3fcd88dcef4f1326"},
 
1304
  {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
1305
  {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
1306
  ]
1307
+ tomli = [
1308
+ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
1309
+ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
1310
+ ]
1311
  torch = [
1312
  {file = "torch-1.11.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:62052b50fffc29ca7afc0c04ef8206b6f1ca9d10629cb543077e12967e8d0398"},
1313
  {file = "torch-1.11.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:866bfba29ac98dec35d893d8e17eaec149d0ac7a53be7baae5c98069897db667"},
pyproject.toml CHANGED
@@ -11,10 +11,28 @@ transformers = "^4.17.0"
11
  torch = "^1.11.0"
12
  datasets = "^1.18.4"
13
  faiss-cpu = "^1.7.2"
 
 
14
 
15
  [tool.poetry.dev-dependencies]
16
  flake8 = "^4.0.1"
17
  autopep8 = "^1.6.0"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  [build-system]
20
  requires = ["poetry-core>=1.0.0"]
 
11
  torch = "^1.11.0"
12
  datasets = "^1.18.4"
13
  faiss-cpu = "^1.7.2"
14
+ python-dotenv = "^0.19.2"
15
+ elasticsearch = "^8.1.0"
16
 
17
  [tool.poetry.dev-dependencies]
18
  flake8 = "^4.0.1"
19
  autopep8 = "^1.6.0"
20
+ mypy = "^0.941"
21
+
22
+ [tool.mypy]
23
+ no_implicit_optional=true
24
+
25
+ [[tool.mypy.overrides]]
26
+ module = [
27
+ "transformers",
28
+ "datasets",
29
+ ]
30
+ ignore_missing_imports = true
31
+
32
+
33
+ [tool.isort]
34
+ profile = "black"
35
+
36
 
37
  [build-system]
38
  requires = ["poetry-core>=1.0.0"]
src/es_retriever.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ class ESRetriever:
2
+ def __init__(self, dataset_name: str = "GroNLP/ik-nlp-22_slp"):
3
+ self.dataset_name = dataset_name
4
+
5
+ def _setup_data(self):
6
+ pass
7
+
8
+ def retrieve(self, query: str, k: int):
9
+ pass
base_model/evaluate.py → src/evaluation.py RENAMED
@@ -1,15 +1,16 @@
1
  from typing import Callable, List
2
 
3
- from base_model.string_utils import lower, remove_articles, remove_punc, white_space_fix
 
4
 
5
 
6
- def normalize_text(inp: str, preprocessing_functions: List[Callable[[str], str]]):
7
  for fun in preprocessing_functions:
8
  inp = fun(inp)
9
  return inp
10
 
11
 
12
- def normalize_text_default(inp: str) -> str:
13
  """Preprocesses the sentence string by normalizing.
14
 
15
  Args:
@@ -21,10 +22,10 @@ def normalize_text_default(inp: str) -> str:
21
 
22
  steps = [remove_articles, white_space_fix, remove_punc, lower]
23
 
24
- return normalize_text(inp, steps)
25
 
26
 
27
- def compute_exact_match(prediction: str, answer: str) -> int:
28
  """Computes exact match for sentences.
29
 
30
  Args:
@@ -34,10 +35,10 @@ def compute_exact_match(prediction: str, answer: str) -> int:
34
  Returns:
35
  int: 1 for exact match, 0 for not
36
  """
37
- return int(normalize_text_default(prediction) == normalize_text_default(answer))
38
 
39
 
40
- def compute_f1(prediction: str, answer: str) -> float:
41
  """Computes F1-score on token overlap for sentences.
42
 
43
  Args:
@@ -47,8 +48,8 @@ def compute_f1(prediction: str, answer: str) -> float:
47
  Returns:
48
  boolean: the f1 score
49
  """
50
- pred_tokens = normalize_text_default(prediction).split()
51
- answer_tokens = normalize_text_default(answer).split()
52
 
53
  if len(pred_tokens) == 0 or len(answer_tokens) == 0:
54
  return int(pred_tokens == answer_tokens)
 
1
  from typing import Callable, List
2
 
3
+ from src.utils.string_utils import (lower, remove_articles, remove_punc,
4
+ white_space_fix)
5
 
6
 
7
+ def _normalize_text(inp: str, preprocessing_functions: List[Callable[[str], str]]):
8
  for fun in preprocessing_functions:
9
  inp = fun(inp)
10
  return inp
11
 
12
 
13
+ def _normalize_text_default(inp: str) -> str:
14
  """Preprocesses the sentence string by normalizing.
15
 
16
  Args:
 
22
 
23
  steps = [remove_articles, white_space_fix, remove_punc, lower]
24
 
25
+ return _normalize_text(inp, steps)
26
 
27
 
28
+ def exact_match(prediction: str, answer: str) -> int:
29
  """Computes exact match for sentences.
30
 
31
  Args:
 
35
  Returns:
36
  int: 1 for exact match, 0 for not
37
  """
38
+ return int(_normalize_text_default(prediction) == _normalize_text_default(answer))
39
 
40
 
41
+ def f1(prediction: str, answer: str) -> float:
42
  """Computes F1-score on token overlap for sentences.
43
 
44
  Args:
 
48
  Returns:
49
  float: the f1 score
50
  """
51
+ pred_tokens = _normalize_text_default(prediction).split()
52
+ answer_tokens = _normalize_text_default(answer).split()
53
 
54
  if len(pred_tokens) == 0 or len(answer_tokens) == 0:
55
  return int(pred_tokens == answer_tokens)
base_model/retriever.py → src/fais_retriever.py RENAMED
@@ -1,23 +1,19 @@
1
- from transformers import (
2
- DPRContextEncoder,
3
- DPRContextEncoderTokenizer,
4
- DPRQuestionEncoder,
5
- DPRQuestionEncoderTokenizer,
6
- )
7
- from datasets import load_dataset
8
- import torch
9
- import os.path
10
-
11
- import evaluate
12
-
13
  # Hacky fix for FAISS error on macOS
14
  # See https://stackoverflow.com/a/63374568/4545692
15
  import os
 
 
 
 
 
 
 
 
16
 
17
  os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
18
 
19
 
20
- class Retriever:
21
  """A class used to retrieve relevant documents based on some query.
22
  based on https://huggingface.co/docs/datasets/faiss_es#faiss.
23
  """
@@ -67,12 +63,13 @@ class Retriever:
67
  embeddings.
68
  """
69
  # Load dataset
70
- ds = load_dataset(dataset_name, name="paragraphs")["train"]
 
71
  print(ds)
72
 
73
  if os.path.exists(embedding_path):
74
  # If we already have FAISS embeddings, load them from disk
75
- ds.load_faiss_index('embeddings', embedding_path)
76
  return ds
77
  else:
78
  # If there are no FAISS embeddings, generate them
@@ -85,7 +82,7 @@ class Retriever:
85
  return {"embeddings": enc}
86
 
87
  # Add FAISS embeddings
88
- ds_with_embeddings = ds.map(embed)
89
 
90
  ds_with_embeddings.add_faiss_index(column="embeddings")
91
 
@@ -141,9 +138,9 @@ class Retriever:
141
  scores += score[0]
142
  predictions.append(result['text'][0])
143
 
144
- exact_matches = [evaluate.compute_exact_match(
145
  predictions[i], answers[i]) for i in range(len(answers))]
146
- f1_scores = [evaluate.compute_f1(
147
  predictions[i], answers[i]) for i in range(len(answers))]
148
 
149
  return sum(exact_matches) / len(exact_matches), sum(f1_scores) / len(f1_scores)
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Hacky fix for FAISS error on macOS
2
  # See https://stackoverflow.com/a/63374568/4545692
3
  import os
4
+ import os.path
5
+
6
+ import torch
7
+ from datasets import load_dataset
8
+ from transformers import (DPRContextEncoder, DPRContextEncoderTokenizer,
9
+ DPRQuestionEncoder, DPRQuestionEncoderTokenizer)
10
+
11
+ from src.evaluation import exact_match, f1
12
 
13
  os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
14
 
15
 
16
+ class FAISRetriever:
17
  """A class used to retrieve relevant documents based on some query.
18
  based on https://huggingface.co/docs/datasets/faiss_es#faiss.
19
  """
 
63
  embeddings.
64
  """
65
  # Load dataset
66
+ ds = load_dataset(dataset_name, name="paragraphs")[
67
+ "train"] # type: ignore
68
  print(ds)
69
 
70
  if os.path.exists(embedding_path):
71
  # If we already have FAISS embeddings, load them from disk
72
+ ds.load_faiss_index('embeddings', embedding_path) # type: ignore
73
  return ds
74
  else:
75
  # If there are no FAISS embeddings, generate them
 
82
  return {"embeddings": enc}
83
 
84
  # Add FAISS embeddings
85
+ ds_with_embeddings = ds.map(embed) # type: ignore
86
 
87
  ds_with_embeddings.add_faiss_index(column="embeddings")
88
 
 
138
  scores += score[0]
139
  predictions.append(result['text'][0])
140
 
141
+ exact_matches = [exact_match(
142
  predictions[i], answers[i]) for i in range(len(answers))]
143
+ f1_scores = [f1(
144
  predictions[i], answers[i]) for i in range(len(answers))]
145
 
146
  return sum(exact_matches) / len(exact_matches), sum(f1_scores) / len(f1_scores)
{base_model → src}/reader.py RENAMED
File without changes
src/utils/log.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ from dotenv import load_dotenv
5
+
6
+ load_dotenv()
7
+
8
+
9
+ def get_logger():
10
+ # creates a default logger for the project
11
+ logger = logging.getLogger("Flashcards")
12
+
13
+ log_level = os.getenv("LOG_LEVEL", "INFO")
14
+ logger.setLevel(log_level)
15
+
16
+ # Log format
17
+ formatter = logging.Formatter(
18
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
19
+
20
+ # file handler
21
+ fh = logging.FileHandler("logs.log")
22
+ fh.setFormatter(formatter)
23
+
24
+ # stout
25
+ ch = logging.StreamHandler()
26
+ ch.setFormatter(formatter)
27
+
28
+ logger.addHandler(fh)
29
+ logger.addHandler(ch)
30
+
31
+ return logger
{base_model → src/utils}/string_utils.py RENAMED
File without changes