GGroenendaal
committed on
Commit
·
51dabd6
1
Parent(s):
aa426fb
refactoring of code
Browse files- .env.example +4 -0
- README.md +3 -0
- base_model/main.py → main.py +10 -7
- poetry.lock +121 -2
- pyproject.toml +18 -0
- src/es_retriever.py +9 -0
- base_model/evaluate.py → src/evaluation.py +10 -9
- base_model/retriever.py → src/fais_retriever.py +15 -18
- {base_model → src}/reader.py +0 -0
- src/utils/log.py +31 -0
- {base_model → src/utils}/string_utils.py +0 -0
.env.example
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ELASTIC_USERNAME=elastic
|
2 |
+
ELASTIC_PASSWORD=<password>
|
3 |
+
|
4 |
+
LOG_LEVEL=INFO
|
README.md
CHANGED
@@ -73,3 +73,6 @@ poetry run python main.py
|
|
73 |
> shows that MT systems perform worse when they are asked to translate sentences
|
74 |
> that describe people with non-stereotypical gender roles, like "The doctor
|
75 |
> asked the nurse to help her in the > operation".
|
|
|
|
|
|
|
|
73 |
> shows that MT systems perform worse when they are asked to translate sentences
|
74 |
> that describe people with non-stereotypical gender roles, like "The doctor
|
75 |
> asked the nurse to help her in the > operation".
|
76 |
+
|
77 |
+
|
78 |
+
## Setting up Elasticsearch
|
base_model/main.py → main.py
RENAMED
@@ -1,20 +1,23 @@
|
|
1 |
-
from
|
|
|
|
|
|
|
|
|
2 |
|
3 |
|
4 |
if __name__ == '__main__':
|
5 |
# Initialize retriever
|
6 |
-
r =
|
7 |
|
8 |
# Retrieve example
|
9 |
scores, result = r.retrieve(
|
10 |
"What is the perplexity of a language model?")
|
11 |
|
12 |
for i, score in enumerate(scores):
|
13 |
-
|
14 |
-
|
15 |
-
print() # Newline
|
16 |
|
17 |
# Compute overall performance
|
18 |
exact_match, f1_score = r.evaluate()
|
19 |
-
|
20 |
-
|
|
|
1 |
+
from src.fais_retriever import FAISRetriever
|
2 |
+
from src.utils.log import get_logger
|
3 |
+
|
4 |
+
|
5 |
+
logger = get_logger()
|
6 |
|
7 |
|
8 |
if __name__ == '__main__':
    # Build the retriever once; it is reused for the demo query and for the
    # evaluation run below.
    retriever = FAISRetriever()

    # Demo: run one example question through the retriever.
    question = "What is the perplexity of a language model?"
    scores, result = retriever.retrieve(question)

    # Log each hit (1-based rank and score) followed by its text.
    for i, score in enumerate(scores):
        logger.info(f"Result {i+1} (score: {score:.02f}):")
        logger.info(result['text'][i])

    # Report the overall evaluation metrics in a single log record.
    exact_match, f1_score = retriever.evaluate()
    summary = (f"Exact match: {exact_match:.02f}\n"
               f"F1-score: {f1_score:.02f}")
    logger.info(summary)
|
poetry.lock
CHANGED
@@ -149,6 +149,36 @@ python-versions = ">=2.7, !=3.0.*"
|
|
149 |
[package.extras]
|
150 |
graph = ["objgraph (>=1.7.2)"]
|
151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
[[package]]
|
153 |
name = "faiss-cpu"
|
154 |
version = "1.7.2"
|
@@ -291,6 +321,32 @@ python-versions = "*"
|
|
291 |
[package.dependencies]
|
292 |
dill = ">=0.3.4"
|
293 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
[[package]]
|
295 |
name = "numpy"
|
296 |
version = "1.22.3"
|
@@ -380,6 +436,17 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
|
|
380 |
[package.dependencies]
|
381 |
six = ">=1.5"
|
382 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
383 |
[[package]]
|
384 |
name = "pytz"
|
385 |
version = "2021.3"
|
@@ -480,6 +547,14 @@ category = "dev"
|
|
480 |
optional = false
|
481 |
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
|
482 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
483 |
[[package]]
|
484 |
name = "torch"
|
485 |
version = "1.11.0"
|
@@ -610,7 +685,7 @@ multidict = ">=4.0"
|
|
610 |
[metadata]
|
611 |
lock-version = "1.1"
|
612 |
python-versions = "^3.8"
|
613 |
-
content-hash = "
|
614 |
|
615 |
[metadata.files]
|
616 |
aiohttp = [
|
@@ -727,6 +802,14 @@ dill = [
|
|
727 |
{file = "dill-0.3.4-py2.py3-none-any.whl", hash = "sha256:7e40e4a70304fd9ceab3535d36e58791d9c4a776b38ec7f7ec9afc8d3dca4d4f"},
|
728 |
{file = "dill-0.3.4.zip", hash = "sha256:9f9734205146b2b353ab3fec9af0070237b6ddae78452af83d2fca84d739e675"},
|
729 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
730 |
faiss-cpu = [
|
731 |
{file = "faiss-cpu-1.7.2.tar.gz", hash = "sha256:f7ea89de997f55764e3710afaf0a457b2529252f99ee63510d4d9348d5b419dd"},
|
732 |
{file = "faiss_cpu-1.7.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b7461f989d757917a3e6dc81eb171d0b563eb98d23ebaf7fc6684d0093ba267e"},
|
@@ -918,12 +1001,40 @@ multiprocess = [
|
|
918 |
{file = "multiprocess-0.70.12.2-py39-none-any.whl", hash = "sha256:6f812a1d3f198b7cacd63983f60e2dc1338bd4450893f90c435067b5a3127e6f"},
|
919 |
{file = "multiprocess-0.70.12.2.zip", hash = "sha256:206bb9b97b73f87fec1ed15a19f8762950256aa84225450abc7150d02855a083"},
|
920 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
921 |
numpy = [
|
922 |
{file = "numpy-1.22.3-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:92bfa69cfbdf7dfc3040978ad09a48091143cffb778ec3b03fa170c494118d75"},
|
923 |
{file = "numpy-1.22.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8251ed96f38b47b4295b1ae51631de7ffa8260b5b087808ef09a39a9d66c97ab"},
|
924 |
{file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48a3aecd3b997bf452a2dedb11f4e79bc5bfd21a1d4cc760e703c31d57c84b3e"},
|
925 |
{file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3bae1a2ed00e90b3ba5f7bd0a7c7999b55d609e0c54ceb2b076a25e345fa9f4"},
|
926 |
-
{file = "numpy-1.22.3-cp310-cp310-win32.whl", hash = "sha256:f950f8845b480cffe522913d35567e29dd381b0dc7e4ce6a4a9f9156417d2430"},
|
927 |
{file = "numpy-1.22.3-cp310-cp310-win_amd64.whl", hash = "sha256:08d9b008d0156c70dc392bb3ab3abb6e7a711383c3247b410b39962263576cd4"},
|
928 |
{file = "numpy-1.22.3-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:201b4d0552831f7250a08d3b38de0d989d6f6e4658b709a02a73c524ccc6ffce"},
|
929 |
{file = "numpy-1.22.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f8c1f39caad2c896bc0018f699882b345b2a63708008be29b1f355ebf6f933fe"},
|
@@ -1015,6 +1126,10 @@ python-dateutil = [
|
|
1015 |
{file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
|
1016 |
{file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
|
1017 |
]
|
|
|
|
|
|
|
|
|
1018 |
pytz = [
|
1019 |
{file = "pytz-2021.3-py2.py3-none-any.whl", hash = "sha256:3672058bc3453457b622aab7a1c3bfd5ab0bdae451512f6cf25f64ed37f5b87c"},
|
1020 |
{file = "pytz-2021.3.tar.gz", hash = "sha256:acad2d8b20a1af07d4e4c9d2e9285c5ed9104354062f275f3fcd88dcef4f1326"},
|
@@ -1189,6 +1304,10 @@ toml = [
|
|
1189 |
{file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
|
1190 |
{file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
|
1191 |
]
|
|
|
|
|
|
|
|
|
1192 |
torch = [
|
1193 |
{file = "torch-1.11.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:62052b50fffc29ca7afc0c04ef8206b6f1ca9d10629cb543077e12967e8d0398"},
|
1194 |
{file = "torch-1.11.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:866bfba29ac98dec35d893d8e17eaec149d0ac7a53be7baae5c98069897db667"},
|
|
|
149 |
[package.extras]
|
150 |
graph = ["objgraph (>=1.7.2)"]
|
151 |
|
152 |
+
[[package]]
|
153 |
+
name = "elastic-transport"
|
154 |
+
version = "8.1.0"
|
155 |
+
description = "Transport classes and utilities shared among Python Elastic client libraries"
|
156 |
+
category = "main"
|
157 |
+
optional = false
|
158 |
+
python-versions = ">=3.6"
|
159 |
+
|
160 |
+
[package.dependencies]
|
161 |
+
certifi = "*"
|
162 |
+
urllib3 = ">=1.26.2,<2"
|
163 |
+
|
164 |
+
[package.extras]
|
165 |
+
develop = ["pytest", "pytest-cov", "pytest-mock", "pytest-asyncio", "mock", "requests", "aiohttp"]
|
166 |
+
|
167 |
+
[[package]]
|
168 |
+
name = "elasticsearch"
|
169 |
+
version = "8.1.0"
|
170 |
+
description = "Python client for Elasticsearch"
|
171 |
+
category = "main"
|
172 |
+
optional = false
|
173 |
+
python-versions = ">=3.6, <4"
|
174 |
+
|
175 |
+
[package.dependencies]
|
176 |
+
elastic-transport = ">=8,<9"
|
177 |
+
|
178 |
+
[package.extras]
|
179 |
+
async = ["aiohttp (>=3,<4)"]
|
180 |
+
requests = ["requests (>=2.4.0,<3.0.0)"]
|
181 |
+
|
182 |
[[package]]
|
183 |
name = "faiss-cpu"
|
184 |
version = "1.7.2"
|
|
|
321 |
[package.dependencies]
|
322 |
dill = ">=0.3.4"
|
323 |
|
324 |
+
[[package]]
|
325 |
+
name = "mypy"
|
326 |
+
version = "0.941"
|
327 |
+
description = "Optional static typing for Python"
|
328 |
+
category = "dev"
|
329 |
+
optional = false
|
330 |
+
python-versions = ">=3.6"
|
331 |
+
|
332 |
+
[package.dependencies]
|
333 |
+
mypy-extensions = ">=0.4.3"
|
334 |
+
tomli = ">=1.1.0"
|
335 |
+
typing-extensions = ">=3.10"
|
336 |
+
|
337 |
+
[package.extras]
|
338 |
+
dmypy = ["psutil (>=4.0)"]
|
339 |
+
python2 = ["typed-ast (>=1.4.0,<2)"]
|
340 |
+
reports = ["lxml"]
|
341 |
+
|
342 |
+
[[package]]
|
343 |
+
name = "mypy-extensions"
|
344 |
+
version = "0.4.3"
|
345 |
+
description = "Experimental type system extensions for programs checked with the mypy typechecker."
|
346 |
+
category = "dev"
|
347 |
+
optional = false
|
348 |
+
python-versions = "*"
|
349 |
+
|
350 |
[[package]]
|
351 |
name = "numpy"
|
352 |
version = "1.22.3"
|
|
|
436 |
[package.dependencies]
|
437 |
six = ">=1.5"
|
438 |
|
439 |
+
[[package]]
|
440 |
+
name = "python-dotenv"
|
441 |
+
version = "0.19.2"
|
442 |
+
description = "Read key-value pairs from a .env file and set them as environment variables"
|
443 |
+
category = "main"
|
444 |
+
optional = false
|
445 |
+
python-versions = ">=3.5"
|
446 |
+
|
447 |
+
[package.extras]
|
448 |
+
cli = ["click (>=5.0)"]
|
449 |
+
|
450 |
[[package]]
|
451 |
name = "pytz"
|
452 |
version = "2021.3"
|
|
|
547 |
optional = false
|
548 |
python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
|
549 |
|
550 |
+
[[package]]
|
551 |
+
name = "tomli"
|
552 |
+
version = "2.0.1"
|
553 |
+
description = "A lil' TOML parser"
|
554 |
+
category = "dev"
|
555 |
+
optional = false
|
556 |
+
python-versions = ">=3.7"
|
557 |
+
|
558 |
[[package]]
|
559 |
name = "torch"
|
560 |
version = "1.11.0"
|
|
|
685 |
[metadata]
|
686 |
lock-version = "1.1"
|
687 |
python-versions = "^3.8"
|
688 |
+
content-hash = "7fadbb5aabac268ecd27c257e2c8f651d26896e78c9cc0ea7e61a8b6ec61c84c"
|
689 |
|
690 |
[metadata.files]
|
691 |
aiohttp = [
|
|
|
802 |
{file = "dill-0.3.4-py2.py3-none-any.whl", hash = "sha256:7e40e4a70304fd9ceab3535d36e58791d9c4a776b38ec7f7ec9afc8d3dca4d4f"},
|
803 |
{file = "dill-0.3.4.zip", hash = "sha256:9f9734205146b2b353ab3fec9af0070237b6ddae78452af83d2fca84d739e675"},
|
804 |
]
|
805 |
+
elastic-transport = [
|
806 |
+
{file = "elastic-transport-8.1.0.tar.gz", hash = "sha256:769ee4c7b28d270cdbce71359973b88129ac312b13be95b4f7479e35c49d9455"},
|
807 |
+
{file = "elastic_transport-8.1.0-py3-none-any.whl", hash = "sha256:0bb2ae3d13348e9e4587ca1f17cd813a528a7cc07f879505f56d69c81823b660"},
|
808 |
+
]
|
809 |
+
elasticsearch = [
|
810 |
+
{file = "elasticsearch-8.1.0-py3-none-any.whl", hash = "sha256:11e36565dfdf649b7911c2d3cb1f15b99267acfb7f82e94e7613c0323a9936e9"},
|
811 |
+
{file = "elasticsearch-8.1.0.tar.gz", hash = "sha256:648d1c707a632279535356d2762cbc63ae728c4633211fe160f43f87a3e1cdcd"},
|
812 |
+
]
|
813 |
faiss-cpu = [
|
814 |
{file = "faiss-cpu-1.7.2.tar.gz", hash = "sha256:f7ea89de997f55764e3710afaf0a457b2529252f99ee63510d4d9348d5b419dd"},
|
815 |
{file = "faiss_cpu-1.7.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b7461f989d757917a3e6dc81eb171d0b563eb98d23ebaf7fc6684d0093ba267e"},
|
|
|
1001 |
{file = "multiprocess-0.70.12.2-py39-none-any.whl", hash = "sha256:6f812a1d3f198b7cacd63983f60e2dc1338bd4450893f90c435067b5a3127e6f"},
|
1002 |
{file = "multiprocess-0.70.12.2.zip", hash = "sha256:206bb9b97b73f87fec1ed15a19f8762950256aa84225450abc7150d02855a083"},
|
1003 |
]
|
1004 |
+
mypy = [
|
1005 |
+
{file = "mypy-0.941-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:98f61aad0bb54f797b17da5b82f419e6ce214de0aa7e92211ebee9e40eb04276"},
|
1006 |
+
{file = "mypy-0.941-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6a8e1f63357851444940351e98fb3252956a15f2cabe3d698316d7a2d1f1f896"},
|
1007 |
+
{file = "mypy-0.941-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b30d29251dff4c59b2e5a1fa1bab91ff3e117b4658cb90f76d97702b7a2ae699"},
|
1008 |
+
{file = "mypy-0.941-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8eaf55fdf99242a1c8c792247c455565447353914023878beadb79600aac4a2a"},
|
1009 |
+
{file = "mypy-0.941-cp310-cp310-win_amd64.whl", hash = "sha256:080097eee5393fd740f32c63f9343580aaa0fb1cda0128fd859dfcf081321c3d"},
|
1010 |
+
{file = "mypy-0.941-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:f79137d012ff3227866222049af534f25354c07a0d6b9a171dba9f1d6a1fdef4"},
|
1011 |
+
{file = "mypy-0.941-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8e5974583a77d630a5868eee18f85ac3093caf76e018c510aeb802b9973304ce"},
|
1012 |
+
{file = "mypy-0.941-cp36-cp36m-win_amd64.whl", hash = "sha256:0dd441fbacf48e19dc0c5c42fafa72b8e1a0ba0a39309c1af9c84b9397d9b15a"},
|
1013 |
+
{file = "mypy-0.941-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:0d3bcbe146247997e03bf030122000998b076b3ac6925b0b6563f46d1ce39b50"},
|
1014 |
+
{file = "mypy-0.941-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:3bada0cf7b6965627954b3a128903a87cac79a79ccd83b6104912e723ef16c7b"},
|
1015 |
+
{file = "mypy-0.941-cp37-cp37m-win_amd64.whl", hash = "sha256:eea10982b798ff0ccc3b9e7e42628f932f552c5845066970e67cd6858655d52c"},
|
1016 |
+
{file = "mypy-0.941-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:108f3c7e14a038cf097d2444fa0155462362c6316e3ecb2d70f6dd99cd36084d"},
|
1017 |
+
{file = "mypy-0.941-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d61b73c01fc1de799226963f2639af831307fe1556b04b7c25e2b6c267a3bc76"},
|
1018 |
+
{file = "mypy-0.941-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:42c216a33d2bdba08098acaf5bae65b0c8196afeb535ef4b870919a788a27259"},
|
1019 |
+
{file = "mypy-0.941-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:fc5ecff5a3bbfbe20091b1cad82815507f5ae9c380a3a9bf40f740c70ce30a9b"},
|
1020 |
+
{file = "mypy-0.941-cp38-cp38-win_amd64.whl", hash = "sha256:bf446223b2e0e4f0a4792938e8d885e8a896834aded5f51be5c3c69566495540"},
|
1021 |
+
{file = "mypy-0.941-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:745071762f32f65e77de6df699366d707fad6c132a660d1342077cbf671ef589"},
|
1022 |
+
{file = "mypy-0.941-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:465a6ce9ca6268cadfbc27a2a94ddf0412568a6b27640ced229270be4f5d394d"},
|
1023 |
+
{file = "mypy-0.941-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d051ce0946521eba48e19b25f27f98e5ce4dbc91fff296de76240c46b4464df0"},
|
1024 |
+
{file = "mypy-0.941-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:818cfc51c25a5dbfd0705f3ac1919fff6971eb0c02e6f1a1f6a017a42405a7c0"},
|
1025 |
+
{file = "mypy-0.941-cp39-cp39-win_amd64.whl", hash = "sha256:b2ce2788df0c066c2ff4ba7190fa84f18937527c477247e926abeb9b1168b8cc"},
|
1026 |
+
{file = "mypy-0.941-py3-none-any.whl", hash = "sha256:3cf77f138efb31727ee7197bc824c9d6d7039204ed96756cc0f9ca7d8e8fc2a4"},
|
1027 |
+
{file = "mypy-0.941.tar.gz", hash = "sha256:cbcc691d8b507d54cb2b8521f0a2a3d4daa477f62fe77f0abba41e5febb377b7"},
|
1028 |
+
]
|
1029 |
+
mypy-extensions = [
|
1030 |
+
{file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"},
|
1031 |
+
{file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"},
|
1032 |
+
]
|
1033 |
numpy = [
|
1034 |
{file = "numpy-1.22.3-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:92bfa69cfbdf7dfc3040978ad09a48091143cffb778ec3b03fa170c494118d75"},
|
1035 |
{file = "numpy-1.22.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8251ed96f38b47b4295b1ae51631de7ffa8260b5b087808ef09a39a9d66c97ab"},
|
1036 |
{file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48a3aecd3b997bf452a2dedb11f4e79bc5bfd21a1d4cc760e703c31d57c84b3e"},
|
1037 |
{file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3bae1a2ed00e90b3ba5f7bd0a7c7999b55d609e0c54ceb2b076a25e345fa9f4"},
|
|
|
1038 |
{file = "numpy-1.22.3-cp310-cp310-win_amd64.whl", hash = "sha256:08d9b008d0156c70dc392bb3ab3abb6e7a711383c3247b410b39962263576cd4"},
|
1039 |
{file = "numpy-1.22.3-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:201b4d0552831f7250a08d3b38de0d989d6f6e4658b709a02a73c524ccc6ffce"},
|
1040 |
{file = "numpy-1.22.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f8c1f39caad2c896bc0018f699882b345b2a63708008be29b1f355ebf6f933fe"},
|
|
|
1126 |
{file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
|
1127 |
{file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
|
1128 |
]
|
1129 |
+
python-dotenv = [
|
1130 |
+
{file = "python-dotenv-0.19.2.tar.gz", hash = "sha256:a5de49a31e953b45ff2d2fd434bbc2670e8db5273606c1e737cc6b93eff3655f"},
|
1131 |
+
{file = "python_dotenv-0.19.2-py2.py3-none-any.whl", hash = "sha256:32b2bdc1873fd3a3c346da1c6db83d0053c3c62f28f1f38516070c4c8971b1d3"},
|
1132 |
+
]
|
1133 |
pytz = [
|
1134 |
{file = "pytz-2021.3-py2.py3-none-any.whl", hash = "sha256:3672058bc3453457b622aab7a1c3bfd5ab0bdae451512f6cf25f64ed37f5b87c"},
|
1135 |
{file = "pytz-2021.3.tar.gz", hash = "sha256:acad2d8b20a1af07d4e4c9d2e9285c5ed9104354062f275f3fcd88dcef4f1326"},
|
|
|
1304 |
{file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
|
1305 |
{file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
|
1306 |
]
|
1307 |
+
tomli = [
|
1308 |
+
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
|
1309 |
+
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
|
1310 |
+
]
|
1311 |
torch = [
|
1312 |
{file = "torch-1.11.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:62052b50fffc29ca7afc0c04ef8206b6f1ca9d10629cb543077e12967e8d0398"},
|
1313 |
{file = "torch-1.11.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:866bfba29ac98dec35d893d8e17eaec149d0ac7a53be7baae5c98069897db667"},
|
pyproject.toml
CHANGED
@@ -11,10 +11,28 @@ transformers = "^4.17.0"
|
|
11 |
torch = "^1.11.0"
|
12 |
datasets = "^1.18.4"
|
13 |
faiss-cpu = "^1.7.2"
|
|
|
|
|
14 |
|
15 |
[tool.poetry.dev-dependencies]
|
16 |
flake8 = "^4.0.1"
|
17 |
autopep8 = "^1.6.0"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
[build-system]
|
20 |
requires = ["poetry-core>=1.0.0"]
|
|
|
11 |
torch = "^1.11.0"
|
12 |
datasets = "^1.18.4"
|
13 |
faiss-cpu = "^1.7.2"
|
14 |
+
python-dotenv = "^0.19.2"
|
15 |
+
elasticsearch = "^8.1.0"
|
16 |
|
17 |
[tool.poetry.dev-dependencies]
|
18 |
flake8 = "^4.0.1"
|
19 |
autopep8 = "^1.6.0"
|
20 |
+
mypy = "^0.941"
|
21 |
+
|
22 |
+
[tool.mypy]
|
23 |
+
no_implicit_optional=true
|
24 |
+
|
25 |
+
[[tool.mypy.overrides]]
|
26 |
+
module = [
|
27 |
+
"transformers",
|
28 |
+
"datasets",
|
29 |
+
]
|
30 |
+
ignore_missing_imports = true
|
31 |
+
|
32 |
+
|
33 |
+
[tool.isort]
|
34 |
+
profile = "black"
|
35 |
+
|
36 |
|
37 |
[build-system]
|
38 |
requires = ["poetry-core>=1.0.0"]
|
src/es_retriever.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class ESRetriever:
    """Skeleton of a retriever (presumably Elasticsearch-backed, going by
    the name — confirm once implemented).

    Only the dataset name is stored at the moment; the data set-up and
    retrieval steps are stubs that return ``None``.
    """

    def __init__(self, dataset_name: str = "GroNLP/ik-nlp-22_slp"):
        """Store the name of the dataset this retriever should work on.

        Args:
            dataset_name: Identifier of the dataset; defaults to
                ``"GroNLP/ik-nlp-22_slp"``.
        """
        self.dataset_name = dataset_name

    def _setup_data(self):
        """Prepare the data for retrieval. Not implemented yet."""
        return None

    def retrieve(self, query: str, k: int):
        """Retrieve results for ``query``. Not implemented yet.

        Args:
            query: The query string.
            k: NOTE(review): presumably the number of results to return —
                confirm when the method is implemented.

        Returns:
            None, because the method is still a stub.
        """
        return None
|
base_model/evaluate.py → src/evaluation.py
RENAMED
@@ -1,15 +1,16 @@
|
|
1 |
from typing import Callable, List
|
2 |
|
3 |
-
from
|
|
|
4 |
|
5 |
|
6 |
-
def
|
7 |
for fun in preprocessing_functions:
|
8 |
inp = fun(inp)
|
9 |
return inp
|
10 |
|
11 |
|
12 |
-
def
|
13 |
"""Preprocesses the sentence string by normalizing.
|
14 |
|
15 |
Args:
|
@@ -21,10 +22,10 @@ def normalize_text_default(inp: str) -> str:
|
|
21 |
|
22 |
steps = [remove_articles, white_space_fix, remove_punc, lower]
|
23 |
|
24 |
-
return
|
25 |
|
26 |
|
27 |
-
def
|
28 |
"""Computes exact match for sentences.
|
29 |
|
30 |
Args:
|
@@ -34,10 +35,10 @@ def compute_exact_match(prediction: str, answer: str) -> int:
|
|
34 |
Returns:
|
35 |
int: 1 for exact match, 0 for not
|
36 |
"""
|
37 |
-
return int(
|
38 |
|
39 |
|
40 |
-
def
|
41 |
"""Computes F1-score on token overlap for sentences.
|
42 |
|
43 |
Args:
|
@@ -47,8 +48,8 @@ def compute_f1(prediction: str, answer: str) -> float:
|
|
47 |
Returns:
|
48 |
boolean: the f1 score
|
49 |
"""
|
50 |
-
pred_tokens =
|
51 |
-
answer_tokens =
|
52 |
|
53 |
if len(pred_tokens) == 0 or len(answer_tokens) == 0:
|
54 |
return int(pred_tokens == answer_tokens)
|
|
|
1 |
from typing import Callable, List
|
2 |
|
3 |
+
from src.utils.string_utils import (lower, remove_articles, remove_punc,
|
4 |
+
white_space_fix)
|
5 |
|
6 |
|
7 |
+
def _normalize_text(inp: str, preprocessing_functions: List[Callable[[str], str]]):
|
8 |
for fun in preprocessing_functions:
|
9 |
inp = fun(inp)
|
10 |
return inp
|
11 |
|
12 |
|
13 |
+
def _normalize_text_default(inp: str) -> str:
|
14 |
"""Preprocesses the sentence string by normalizing.
|
15 |
|
16 |
Args:
|
|
|
22 |
|
23 |
steps = [remove_articles, white_space_fix, remove_punc, lower]
|
24 |
|
25 |
+
return _normalize_text(inp, steps)
|
26 |
|
27 |
|
28 |
+
def exact_match(prediction: str, answer: str) -> int:
|
29 |
"""Computes exact match for sentences.
|
30 |
|
31 |
Args:
|
|
|
35 |
Returns:
|
36 |
int: 1 for exact match, 0 for not
|
37 |
"""
|
38 |
+
return int(_normalize_text_default(prediction) == _normalize_text_default(answer))
|
39 |
|
40 |
|
41 |
+
def f1(prediction: str, answer: str) -> float:
|
42 |
"""Computes F1-score on token overlap for sentences.
|
43 |
|
44 |
Args:
|
|
|
48 |
Returns:
|
49 |
boolean: the f1 score
|
50 |
"""
|
51 |
+
pred_tokens = _normalize_text_default(prediction).split()
|
52 |
+
answer_tokens = _normalize_text_default(answer).split()
|
53 |
|
54 |
if len(pred_tokens) == 0 or len(answer_tokens) == 0:
|
55 |
return int(pred_tokens == answer_tokens)
|
base_model/retriever.py → src/fais_retriever.py
RENAMED
@@ -1,23 +1,19 @@
|
|
1 |
-
from transformers import (
|
2 |
-
DPRContextEncoder,
|
3 |
-
DPRContextEncoderTokenizer,
|
4 |
-
DPRQuestionEncoder,
|
5 |
-
DPRQuestionEncoderTokenizer,
|
6 |
-
)
|
7 |
-
from datasets import load_dataset
|
8 |
-
import torch
|
9 |
-
import os.path
|
10 |
-
|
11 |
-
import evaluate
|
12 |
-
|
13 |
# Hacky fix for FAISS error on macOS
|
14 |
# See https://stackoverflow.com/a/63374568/4545692
|
15 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
|
18 |
|
19 |
|
20 |
-
class
|
21 |
"""A class used to retrieve relevant documents based on some query.
|
22 |
based on https://huggingface.co/docs/datasets/faiss_es#faiss.
|
23 |
"""
|
@@ -67,12 +63,13 @@ class Retriever:
|
|
67 |
embeddings.
|
68 |
"""
|
69 |
# Load dataset
|
70 |
-
ds = load_dataset(dataset_name, name="paragraphs")[
|
|
|
71 |
print(ds)
|
72 |
|
73 |
if os.path.exists(embedding_path):
|
74 |
# If we already have FAISS embeddings, load them from disk
|
75 |
-
ds.load_faiss_index('embeddings', embedding_path)
|
76 |
return ds
|
77 |
else:
|
78 |
# If there are no FAISS embeddings, generate them
|
@@ -85,7 +82,7 @@ class Retriever:
|
|
85 |
return {"embeddings": enc}
|
86 |
|
87 |
# Add FAISS embeddings
|
88 |
-
ds_with_embeddings = ds.map(embed)
|
89 |
|
90 |
ds_with_embeddings.add_faiss_index(column="embeddings")
|
91 |
|
@@ -141,9 +138,9 @@ class Retriever:
|
|
141 |
scores += score[0]
|
142 |
predictions.append(result['text'][0])
|
143 |
|
144 |
-
exact_matches = [
|
145 |
predictions[i], answers[i]) for i in range(len(answers))]
|
146 |
-
f1_scores = [
|
147 |
predictions[i], answers[i]) for i in range(len(answers))]
|
148 |
|
149 |
return sum(exact_matches) / len(exact_matches), sum(f1_scores) / len(f1_scores)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# Hacky fix for FAISS error on macOS
|
2 |
# See https://stackoverflow.com/a/63374568/4545692
|
3 |
import os
|
4 |
+
import os.path
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from datasets import load_dataset
|
8 |
+
from transformers import (DPRContextEncoder, DPRContextEncoderTokenizer,
|
9 |
+
DPRQuestionEncoder, DPRQuestionEncoderTokenizer)
|
10 |
+
|
11 |
+
from src.evaluation import exact_match, f1
|
12 |
|
13 |
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
|
14 |
|
15 |
|
16 |
+
class FAISRetriever:
|
17 |
"""A class used to retrieve relevant documents based on some query.
|
18 |
based on https://huggingface.co/docs/datasets/faiss_es#faiss.
|
19 |
"""
|
|
|
63 |
embeddings.
|
64 |
"""
|
65 |
# Load dataset
|
66 |
+
ds = load_dataset(dataset_name, name="paragraphs")[
|
67 |
+
"train"] # type: ignore
|
68 |
print(ds)
|
69 |
|
70 |
if os.path.exists(embedding_path):
|
71 |
# If we already have FAISS embeddings, load them from disk
|
72 |
+
ds.load_faiss_index('embeddings', embedding_path) # type: ignore
|
73 |
return ds
|
74 |
else:
|
75 |
# If there are no FAISS embeddings, generate them
|
|
|
82 |
return {"embeddings": enc}
|
83 |
|
84 |
# Add FAISS embeddings
|
85 |
+
ds_with_embeddings = ds.map(embed) # type: ignore
|
86 |
|
87 |
ds_with_embeddings.add_faiss_index(column="embeddings")
|
88 |
|
|
|
138 |
scores += score[0]
|
139 |
predictions.append(result['text'][0])
|
140 |
|
141 |
+
exact_matches = [exact_match(
|
142 |
predictions[i], answers[i]) for i in range(len(answers))]
|
143 |
+
f1_scores = [f1(
|
144 |
predictions[i], answers[i]) for i in range(len(answers))]
|
145 |
|
146 |
return sum(exact_matches) / len(exact_matches), sum(f1_scores) / len(f1_scores)
|
{base_model → src}/reader.py
RENAMED
File without changes
|
src/utils/log.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
|
4 |
+
from dotenv import load_dotenv
|
5 |
+
|
6 |
+
load_dotenv()
|
7 |
+
|
8 |
+
|
9 |
+
def get_logger() -> logging.Logger:
    """Return the project-wide ``"Flashcards"`` logger.

    The level is read from the ``LOG_LEVEL`` environment variable (default
    ``"INFO"``) on every call. The first call also attaches a file handler
    writing to ``logs.log`` and a stream handler, both using the same
    record format.

    ``logging.getLogger`` returns the same logger object for the same name,
    so handlers are attached only while the logger has none. Otherwise each
    extra call would add another pair of handlers, duplicating every log
    line and opening ``logs.log`` again.

    Returns:
        The configured ``logging.Logger`` instance.

    Raises:
        ValueError: If ``LOG_LEVEL`` is not a level name known to
            ``logging``.
    """
    logger = logging.getLogger("Flashcards")

    # Level names are upper-case in ``logging``; accept e.g. "info" as well.
    log_level = os.getenv("LOG_LEVEL", "INFO").upper()
    logger.setLevel(log_level)

    # Already configured by an earlier call: do not stack more handlers.
    if logger.handlers:
        return logger

    # Shared record format for both handlers
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    # File handler
    fh = logging.FileHandler("logs.log")
    fh.setFormatter(formatter)

    # Console handler (a StreamHandler without arguments writes to stderr)
    ch = logging.StreamHandler()
    ch.setFormatter(formatter)

    logger.addHandler(fh)
    logger.addHandler(ch)

    return logger
|
{base_model → src/utils}/string_utils.py
RENAMED
File without changes
|