lingyit1108 commited on
Commit
5a39f92
·
1 Parent(s): 06f450b

added fine-tuning notebook example

Browse files
notebooks/fine-tuning-embedding-model.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": null,
6
  "id": "ca2c990f-5215-4ab9-8143-1d79db28edc6",
7
  "metadata": {},
8
  "outputs": [],
@@ -16,7 +16,7 @@
16
  },
17
  {
18
  "cell_type": "code",
19
- "execution_count": null,
20
  "id": "2c535ad7-7846-4bef-8ba8-33e182490c3d",
21
  "metadata": {},
22
  "outputs": [],
@@ -30,7 +30,33 @@
30
  },
31
  {
32
  "cell_type": "code",
33
- "execution_count": null,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  "id": "12527049-a5cb-423c-8de5-099aee970c85",
35
  "metadata": {},
36
  "outputs": [],
@@ -40,10 +66,18 @@
40
  },
41
  {
42
  "cell_type": "code",
43
- "execution_count": null,
44
  "id": "abde5e6c-3474-460c-9fac-4f3352c38b53",
45
  "metadata": {},
46
- "outputs": [],
 
 
 
 
 
 
 
 
47
  "source": [
48
  "import llama_index\n",
49
  "print(llama_index.__version__)"
@@ -59,7 +93,7 @@
59
  },
60
  {
61
  "cell_type": "code",
62
- "execution_count": null,
63
  "id": "978cf71f-1ce7-4598-92fe-18fe22ca37c6",
64
  "metadata": {},
65
  "outputs": [],
@@ -81,7 +115,7 @@
81
  },
82
  {
83
  "cell_type": "code",
84
- "execution_count": null,
85
  "id": "26f614c8-eb45-4cc1-b067-2c7299587982",
86
  "metadata": {},
87
  "outputs": [],
@@ -114,7 +148,7 @@
114
  },
115
  {
116
  "cell_type": "code",
117
- "execution_count": null,
118
  "id": "84cc4308-8ac4-4eba-9478-b81d5b645c48",
119
  "metadata": {},
120
  "outputs": [],
@@ -150,7 +184,7 @@
150
  },
151
  {
152
  "cell_type": "code",
153
- "execution_count": null,
154
  "id": "8f17c832-e9ae-477b-8bf7-a9c8410f1ed8",
155
  "metadata": {},
156
  "outputs": [],
@@ -159,23 +193,67 @@
159
  " train_dataset,\n",
160
  " model_id=\"BAAI/bge-small-en-v1.5\",\n",
161
  " model_output_path=\"test_model\",\n",
162
- " val_dataset=val_dataset,\n",
 
163
  ")"
164
  ]
165
  },
166
  {
167
  "cell_type": "code",
168
- "execution_count": null,
169
  "id": "a6498d0b-da9a-4f7f-8c85-c9bf4d772c72",
170
  "metadata": {},
171
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  "source": [
173
  "finetune_engine.finetune()"
174
  ]
175
  },
176
  {
177
  "cell_type": "code",
178
- "execution_count": null,
179
  "id": "e057b405-aa0e-4e78-91e0-9bf40f01c1a9",
180
  "metadata": {},
181
  "outputs": [],
@@ -185,10 +263,21 @@
185
  },
186
  {
187
  "cell_type": "code",
188
- "execution_count": null,
189
  "id": "72d9f97a-0902-4e65-8459-b34613e419f6",
190
  "metadata": {},
191
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
192
  "source": [
193
  "embed_model"
194
  ]
@@ -200,6 +289,1016 @@
200
  "metadata": {},
201
  "outputs": [],
202
  "source": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  }
204
  ],
205
  "metadata": {
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 2,
6
  "id": "ca2c990f-5215-4ab9-8143-1d79db28edc6",
7
  "metadata": {},
8
  "outputs": [],
 
16
  },
17
  {
18
  "cell_type": "code",
19
+ "execution_count": 4,
20
  "id": "2c535ad7-7846-4bef-8ba8-33e182490c3d",
21
  "metadata": {},
22
  "outputs": [],
 
30
  },
31
  {
32
  "cell_type": "code",
33
+ "execution_count": 19,
34
+ "id": "25f0c7a3-c52f-4417-aec8-4b6cfbf7a1b5",
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "from llama_index.embeddings import OpenAIEmbedding\n",
39
+ "from llama_index import ServiceContext, VectorStoreIndex\n",
40
+ "from llama_index.schema import TextNode\n",
41
+ "from tqdm.notebook import tqdm\n",
42
+ "import pandas as pd"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 20,
48
+ "id": "62f4d7f0-748a-405e-b5f1-6520fd02bedc",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "from sentence_transformers.evaluation import InformationRetrievalEvaluator\n",
53
+ "from sentence_transformers import SentenceTransformer\n",
54
+ "from pathlib import Path"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 5,
60
  "id": "12527049-a5cb-423c-8de5-099aee970c85",
61
  "metadata": {},
62
  "outputs": [],
 
66
  },
67
  {
68
  "cell_type": "code",
69
+ "execution_count": 6,
70
  "id": "abde5e6c-3474-460c-9fac-4f3352c38b53",
71
  "metadata": {},
72
+ "outputs": [
73
+ {
74
+ "name": "stdout",
75
+ "output_type": "stream",
76
+ "text": [
77
+ "0.9.39\n"
78
+ ]
79
+ }
80
+ ],
81
  "source": [
82
  "import llama_index\n",
83
  "print(llama_index.__version__)"
 
93
  },
94
  {
95
  "cell_type": "code",
96
+ "execution_count": 7,
97
  "id": "978cf71f-1ce7-4598-92fe-18fe22ca37c6",
98
  "metadata": {},
99
  "outputs": [],
 
115
  },
116
  {
117
  "cell_type": "code",
118
+ "execution_count": 8,
119
  "id": "26f614c8-eb45-4cc1-b067-2c7299587982",
120
  "metadata": {},
121
  "outputs": [],
 
148
  },
149
  {
150
  "cell_type": "code",
151
+ "execution_count": 9,
152
  "id": "84cc4308-8ac4-4eba-9478-b81d5b645c48",
153
  "metadata": {},
154
  "outputs": [],
 
184
  },
185
  {
186
  "cell_type": "code",
187
+ "execution_count": 11,
188
  "id": "8f17c832-e9ae-477b-8bf7-a9c8410f1ed8",
189
  "metadata": {},
190
  "outputs": [],
 
193
  " train_dataset,\n",
194
  " model_id=\"BAAI/bge-small-en-v1.5\",\n",
195
  " model_output_path=\"test_model\",\n",
196
+ " batch_size=5,\n",
197
+ " val_dataset=val_dataset\n",
198
  ")"
199
  ]
200
  },
201
  {
202
  "cell_type": "code",
203
+ "execution_count": 12,
204
  "id": "a6498d0b-da9a-4f7f-8c85-c9bf4d772c72",
205
  "metadata": {},
206
+ "outputs": [
207
+ {
208
+ "data": {
209
+ "application/vnd.jupyter.widget-view+json": {
210
+ "model_id": "e80f94e7c7a84014b3cbf270dde3fcaf",
211
+ "version_major": 2,
212
+ "version_minor": 0
213
+ },
214
+ "text/plain": [
215
+ "Epoch: 0%| | 0/2 [00:00<?, ?it/s]"
216
+ ]
217
+ },
218
+ "metadata": {},
219
+ "output_type": "display_data"
220
+ },
221
+ {
222
+ "data": {
223
+ "application/vnd.jupyter.widget-view+json": {
224
+ "model_id": "d02eb3c3b1454494a566557e8b73174f",
225
+ "version_major": 2,
226
+ "version_minor": 0
227
+ },
228
+ "text/plain": [
229
+ "Iteration: 0%| | 0/183 [00:00<?, ?it/s]"
230
+ ]
231
+ },
232
+ "metadata": {},
233
+ "output_type": "display_data"
234
+ },
235
+ {
236
+ "data": {
237
+ "application/vnd.jupyter.widget-view+json": {
238
+ "model_id": "0d73a19c286e43afa7c12cfb5fb49d34",
239
+ "version_major": 2,
240
+ "version_minor": 0
241
+ },
242
+ "text/plain": [
243
+ "Iteration: 0%| | 0/183 [00:00<?, ?it/s]"
244
+ ]
245
+ },
246
+ "metadata": {},
247
+ "output_type": "display_data"
248
+ }
249
+ ],
250
  "source": [
251
  "finetune_engine.finetune()"
252
  ]
253
  },
254
  {
255
  "cell_type": "code",
256
+ "execution_count": 13,
257
  "id": "e057b405-aa0e-4e78-91e0-9bf40f01c1a9",
258
  "metadata": {},
259
  "outputs": [],
 
263
  },
264
  {
265
  "cell_type": "code",
266
+ "execution_count": 14,
267
  "id": "72d9f97a-0902-4e65-8459-b34613e419f6",
268
  "metadata": {},
269
+ "outputs": [
270
+ {
271
+ "data": {
272
+ "text/plain": [
273
+ "HuggingFaceEmbedding(model_name='test_model', embed_batch_size=10, callback_manager=<llama_index.callbacks.base.CallbackManager object at 0x3c7fadca0>, tokenizer_name='test_model', max_length=512, pooling=<Pooling.CLS: 'cls'>, normalize=True, query_instruction=None, text_instruction=None, cache_folder=None)"
274
+ ]
275
+ },
276
+ "execution_count": 14,
277
+ "metadata": {},
278
+ "output_type": "execute_result"
279
+ }
280
+ ],
281
  "source": [
282
  "embed_model"
283
  ]
 
289
  "metadata": {},
290
  "outputs": [],
291
  "source": []
292
+ },
293
+ {
294
+ "cell_type": "code",
295
+ "execution_count": null,
296
+ "id": "dad7589f-4855-4432-b710-01aff9c134ee",
297
+ "metadata": {},
298
+ "outputs": [],
299
+ "source": []
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": 15,
304
+ "id": "ac4a1a5b-974d-452e-8507-0950c962f9b2",
305
+ "metadata": {},
306
+ "outputs": [],
307
+ "source": [
308
+ "def evaluate(\n",
309
+ " dataset,\n",
310
+ " embed_model,\n",
311
+ " top_k=5,\n",
312
+ " verbose=False,\n",
313
+ "):\n",
314
+ " corpus = dataset.corpus\n",
315
+ " queries = dataset.queries\n",
316
+ " relevant_docs = dataset.relevant_docs\n",
317
+ "\n",
318
+ " service_context = ServiceContext.from_defaults(embed_model=embed_model)\n",
319
+ " nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()]\n",
320
+ " index = VectorStoreIndex(\n",
321
+ " nodes, service_context=service_context, show_progress=True\n",
322
+ " )\n",
323
+ " retriever = index.as_retriever(similarity_top_k=top_k)\n",
324
+ "\n",
325
+ " eval_results = []\n",
326
+ " for query_id, query in tqdm(queries.items()):\n",
327
+ " retrieved_nodes = retriever.retrieve(query)\n",
328
+ " retrieved_ids = [node.node.node_id for node in retrieved_nodes]\n",
329
+ " expected_id = relevant_docs[query_id][0]\n",
330
+ " is_hit = expected_id in retrieved_ids # assume 1 relevant doc\n",
331
+ "\n",
332
+ " eval_result = {\n",
333
+ " \"is_hit\": is_hit,\n",
334
+ " \"retrieved\": retrieved_ids,\n",
335
+ " \"expected\": expected_id,\n",
336
+ " \"query\": query_id,\n",
337
+ " }\n",
338
+ " eval_results.append(eval_result)\n",
339
+ " return eval_results"
340
+ ]
341
+ },
342
+ {
343
+ "cell_type": "code",
344
+ "execution_count": 16,
345
+ "id": "a53cf893-ce9f-4d9d-ad4a-e9e17fb058d3",
346
+ "metadata": {},
347
+ "outputs": [],
348
+ "source": [
349
+ "def evaluate_st(\n",
350
+ " dataset,\n",
351
+ " model_id,\n",
352
+ " name,\n",
353
+ "):\n",
354
+ " corpus = dataset.corpus\n",
355
+ " queries = dataset.queries\n",
356
+ " relevant_docs = dataset.relevant_docs\n",
357
+ "\n",
358
+ " evaluator = InformationRetrievalEvaluator(\n",
359
+ " queries, corpus, relevant_docs, name=name\n",
360
+ " )\n",
361
+ " model = SentenceTransformer(model_id)\n",
362
+ " output_path = \"results/\"\n",
363
+ " Path(output_path).mkdir(exist_ok=True, parents=True)\n",
364
+ " return evaluator(model, output_path=output_path)"
365
+ ]
366
+ },
367
+ {
368
+ "cell_type": "code",
369
+ "execution_count": null,
370
+ "id": "703f9350-f7ab-43cc-abdf-055323ef67dd",
371
+ "metadata": {},
372
+ "outputs": [],
373
+ "source": []
374
+ },
375
+ {
376
+ "cell_type": "code",
377
+ "execution_count": null,
378
+ "id": "57d66621-49e6-4a8a-9ef2-83b2b33e33d7",
379
+ "metadata": {},
380
+ "outputs": [],
381
+ "source": []
382
+ },
383
+ {
384
+ "cell_type": "markdown",
385
+ "id": "b43ad08e-e96d-412b-9a88-14fe3af85b3d",
386
+ "metadata": {},
387
+ "source": [
388
+ "### Using OpenAI Ada embedding"
389
+ ]
390
+ },
391
+ {
392
+ "cell_type": "code",
393
+ "execution_count": 21,
394
+ "id": "91f057aa-4b59-48ea-b3d5-23012a4d487f",
395
+ "metadata": {},
396
+ "outputs": [
397
+ {
398
+ "data": {
399
+ "application/vnd.jupyter.widget-view+json": {
400
+ "model_id": "f4bf05fbe14c4c379c0b3e1912b84d36",
401
+ "version_major": 2,
402
+ "version_minor": 0
403
+ },
404
+ "text/plain": [
405
+ "Generating embeddings: 0%| | 0/100 [00:00<?, ?it/s]"
406
+ ]
407
+ },
408
+ "metadata": {},
409
+ "output_type": "display_data"
410
+ },
411
+ {
412
+ "name": "stderr",
413
+ "output_type": "stream",
414
+ "text": [
415
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
416
+ "To disable this warning, you can either:\n",
417
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
418
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
419
+ ]
420
+ },
421
+ {
422
+ "data": {
423
+ "application/vnd.jupyter.widget-view+json": {
424
+ "model_id": "4f365d1cab004fe897949e2a3928c457",
425
+ "version_major": 2,
426
+ "version_minor": 0
427
+ },
428
+ "text/plain": [
429
+ " 0%| | 0/200 [00:00<?, ?it/s]"
430
+ ]
431
+ },
432
+ "metadata": {},
433
+ "output_type": "display_data"
434
+ }
435
+ ],
436
+ "source": [
437
+ "ada = OpenAIEmbedding()\n",
438
+ "ada_val_results = evaluate(val_dataset, ada)"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "code",
443
+ "execution_count": 22,
444
+ "id": "5d2f59c6-75d3-4970-bac3-dfe0eef00efe",
445
+ "metadata": {},
446
+ "outputs": [],
447
+ "source": [
448
+ "df_ada = pd.DataFrame(ada_val_results)"
449
+ ]
450
+ },
451
+ {
452
+ "cell_type": "code",
453
+ "execution_count": 24,
454
+ "id": "7a697cd8-6f39-4d5b-84f4-f08cf58adc4a",
455
+ "metadata": {},
456
+ "outputs": [
457
+ {
458
+ "data": {
459
+ "text/html": [
460
+ "<div>\n",
461
+ "<style scoped>\n",
462
+ " .dataframe tbody tr th:only-of-type {\n",
463
+ " vertical-align: middle;\n",
464
+ " }\n",
465
+ "\n",
466
+ " .dataframe tbody tr th {\n",
467
+ " vertical-align: top;\n",
468
+ " }\n",
469
+ "\n",
470
+ " .dataframe thead th {\n",
471
+ " text-align: right;\n",
472
+ " }\n",
473
+ "</style>\n",
474
+ "<table border=\"1\" class=\"dataframe\">\n",
475
+ " <thead>\n",
476
+ " <tr style=\"text-align: right;\">\n",
477
+ " <th></th>\n",
478
+ " <th>is_hit</th>\n",
479
+ " <th>retrieved</th>\n",
480
+ " <th>expected</th>\n",
481
+ " <th>query</th>\n",
482
+ " </tr>\n",
483
+ " </thead>\n",
484
+ " <tbody>\n",
485
+ " <tr>\n",
486
+ " <th>0</th>\n",
487
+ " <td>False</td>\n",
488
+ " <td>[5b9cd986-33dc-46f1-abae-e4e1dc9e3629, c3c1804...</td>\n",
489
+ " <td>6a756f03-638d-480d-8222-1a6bf3790e3c</td>\n",
490
+ " <td>011d84b2-0c26-4c5c-89d1-2a85498f30e0</td>\n",
491
+ " </tr>\n",
492
+ " <tr>\n",
493
+ " <th>1</th>\n",
494
+ " <td>True</td>\n",
495
+ " <td>[6a756f03-638d-480d-8222-1a6bf3790e3c, c3c1804...</td>\n",
496
+ " <td>6a756f03-638d-480d-8222-1a6bf3790e3c</td>\n",
497
+ " <td>70c5ddd7-eb86-4a41-af70-a23d2392f48d</td>\n",
498
+ " </tr>\n",
499
+ " <tr>\n",
500
+ " <th>2</th>\n",
501
+ " <td>True</td>\n",
502
+ " <td>[c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824...</td>\n",
503
+ " <td>c83dbd8a-7e62-445e-8c12-a8ad604ff65e</td>\n",
504
+ " <td>a8f4290a-1281-4272-aab9-bf089954a45e</td>\n",
505
+ " </tr>\n",
506
+ " <tr>\n",
507
+ " <th>3</th>\n",
508
+ " <td>True</td>\n",
509
+ " <td>[c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824...</td>\n",
510
+ " <td>c83dbd8a-7e62-445e-8c12-a8ad604ff65e</td>\n",
511
+ " <td>c1ef991a-1cc6-4dbf-b179-2df688c84301</td>\n",
512
+ " </tr>\n",
513
+ " <tr>\n",
514
+ " <th>4</th>\n",
515
+ " <td>True</td>\n",
516
+ " <td>[21778248-2ed9-4147-bdb0-a60337a1a599, c83dbd8...</td>\n",
517
+ " <td>21778248-2ed9-4147-bdb0-a60337a1a599</td>\n",
518
+ " <td>1ce25e78-c1e1-487e-9455-9418baa0b60c</td>\n",
519
+ " </tr>\n",
520
+ " </tbody>\n",
521
+ "</table>\n",
522
+ "</div>"
523
+ ],
524
+ "text/plain": [
525
+ " is_hit retrieved \\\n",
526
+ "0 False [5b9cd986-33dc-46f1-abae-e4e1dc9e3629, c3c1804... \n",
527
+ "1 True [6a756f03-638d-480d-8222-1a6bf3790e3c, c3c1804... \n",
528
+ "2 True [c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824... \n",
529
+ "3 True [c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824... \n",
530
+ "4 True [21778248-2ed9-4147-bdb0-a60337a1a599, c83dbd8... \n",
531
+ "\n",
532
+ " expected query \n",
533
+ "0 6a756f03-638d-480d-8222-1a6bf3790e3c 011d84b2-0c26-4c5c-89d1-2a85498f30e0 \n",
534
+ "1 6a756f03-638d-480d-8222-1a6bf3790e3c 70c5ddd7-eb86-4a41-af70-a23d2392f48d \n",
535
+ "2 c83dbd8a-7e62-445e-8c12-a8ad604ff65e a8f4290a-1281-4272-aab9-bf089954a45e \n",
536
+ "3 c83dbd8a-7e62-445e-8c12-a8ad604ff65e c1ef991a-1cc6-4dbf-b179-2df688c84301 \n",
537
+ "4 21778248-2ed9-4147-bdb0-a60337a1a599 1ce25e78-c1e1-487e-9455-9418baa0b60c "
538
+ ]
539
+ },
540
+ "execution_count": 24,
541
+ "metadata": {},
542
+ "output_type": "execute_result"
543
+ }
544
+ ],
545
+ "source": [
546
+ "df_ada[:5]"
547
+ ]
548
+ },
549
+ {
550
+ "cell_type": "code",
551
+ "execution_count": 27,
552
+ "id": "3f7186fb-f392-4531-8959-25161e3905e4",
553
+ "metadata": {},
554
+ "outputs": [
555
+ {
556
+ "data": {
557
+ "text/plain": [
558
+ "(0.955, 200)"
559
+ ]
560
+ },
561
+ "execution_count": 27,
562
+ "metadata": {},
563
+ "output_type": "execute_result"
564
+ }
565
+ ],
566
+ "source": [
567
+ "hit_rate_ada = df_ada[\"is_hit\"].mean()\n",
568
+ "hit_rate_ada, len(df_ada)"
569
+ ]
570
+ },
571
+ {
572
+ "cell_type": "code",
573
+ "execution_count": null,
574
+ "id": "d044399a-e55b-40b7-a09d-6fb838383bfa",
575
+ "metadata": {},
576
+ "outputs": [],
577
+ "source": []
578
+ },
579
+ {
580
+ "cell_type": "markdown",
581
+ "id": "66746f3e-638a-432c-a38d-7cb99d2093f7",
582
+ "metadata": {},
583
+ "source": [
584
+ "### Using BAAI bge-small model without fine-tuning"
585
+ ]
586
+ },
587
+ {
588
+ "cell_type": "code",
589
+ "execution_count": 26,
590
+ "id": "b2905831-0eb9-4ea7-a0b9-5db286b0965e",
591
+ "metadata": {},
592
+ "outputs": [
593
+ {
594
+ "data": {
595
+ "application/vnd.jupyter.widget-view+json": {
596
+ "model_id": "784a67a3d51a400cad53c52bb16121fc",
597
+ "version_major": 2,
598
+ "version_minor": 0
599
+ },
600
+ "text/plain": [
601
+ "config.json: 0%| | 0.00/743 [00:00<?, ?B/s]"
602
+ ]
603
+ },
604
+ "metadata": {},
605
+ "output_type": "display_data"
606
+ },
607
+ {
608
+ "data": {
609
+ "application/vnd.jupyter.widget-view+json": {
610
+ "model_id": "1c0edb74b4154cb49931180def479320",
611
+ "version_major": 2,
612
+ "version_minor": 0
613
+ },
614
+ "text/plain": [
615
+ "model.safetensors: 0%| | 0.00/133M [00:00<?, ?B/s]"
616
+ ]
617
+ },
618
+ "metadata": {},
619
+ "output_type": "display_data"
620
+ },
621
+ {
622
+ "data": {
623
+ "application/vnd.jupyter.widget-view+json": {
624
+ "model_id": "af9cb2f4d3934e9a991969f0083fa495",
625
+ "version_major": 2,
626
+ "version_minor": 0
627
+ },
628
+ "text/plain": [
629
+ "tokenizer_config.json: 0%| | 0.00/366 [00:00<?, ?B/s]"
630
+ ]
631
+ },
632
+ "metadata": {},
633
+ "output_type": "display_data"
634
+ },
635
+ {
636
+ "data": {
637
+ "application/vnd.jupyter.widget-view+json": {
638
+ "model_id": "2370d77040d94ffb9a4d8ca2f45faa97",
639
+ "version_major": 2,
640
+ "version_minor": 0
641
+ },
642
+ "text/plain": [
643
+ "vocab.txt: 0%| | 0.00/232k [00:00<?, ?B/s]"
644
+ ]
645
+ },
646
+ "metadata": {},
647
+ "output_type": "display_data"
648
+ },
649
+ {
650
+ "data": {
651
+ "application/vnd.jupyter.widget-view+json": {
652
+ "model_id": "0b7c293a142d4eaf91673c17222d232a",
653
+ "version_major": 2,
654
+ "version_minor": 0
655
+ },
656
+ "text/plain": [
657
+ "tokenizer.json: 0%| | 0.00/711k [00:00<?, ?B/s]"
658
+ ]
659
+ },
660
+ "metadata": {},
661
+ "output_type": "display_data"
662
+ },
663
+ {
664
+ "data": {
665
+ "application/vnd.jupyter.widget-view+json": {
666
+ "model_id": "7fcb86d759084084a8e41aec12738e19",
667
+ "version_major": 2,
668
+ "version_minor": 0
669
+ },
670
+ "text/plain": [
671
+ "special_tokens_map.json: 0%| | 0.00/125 [00:00<?, ?B/s]"
672
+ ]
673
+ },
674
+ "metadata": {},
675
+ "output_type": "display_data"
676
+ },
677
+ {
678
+ "data": {
679
+ "application/vnd.jupyter.widget-view+json": {
680
+ "model_id": "ab4d747b58f74fdb86481b7f936bf0c4",
681
+ "version_major": 2,
682
+ "version_minor": 0
683
+ },
684
+ "text/plain": [
685
+ "Generating embeddings: 0%| | 0/100 [00:00<?, ?it/s]"
686
+ ]
687
+ },
688
+ "metadata": {},
689
+ "output_type": "display_data"
690
+ },
691
+ {
692
+ "data": {
693
+ "application/vnd.jupyter.widget-view+json": {
694
+ "model_id": "baa0bb9ae0da4dfc86c20308477415fa",
695
+ "version_major": 2,
696
+ "version_minor": 0
697
+ },
698
+ "text/plain": [
699
+ " 0%| | 0/200 [00:00<?, ?it/s]"
700
+ ]
701
+ },
702
+ "metadata": {},
703
+ "output_type": "display_data"
704
+ }
705
+ ],
706
+ "source": [
707
+ "bge = \"local:BAAI/bge-small-en-v1.5\"\n",
708
+ "bge_val_results = evaluate(val_dataset, bge)"
709
+ ]
710
+ },
711
+ {
712
+ "cell_type": "code",
713
+ "execution_count": 28,
714
+ "id": "4e66270d-d3f6-429e-9e48-e8062866aa02",
715
+ "metadata": {},
716
+ "outputs": [],
717
+ "source": [
718
+ "df_bge = pd.DataFrame(bge_val_results)"
719
+ ]
720
+ },
721
+ {
722
+ "cell_type": "code",
723
+ "execution_count": 29,
724
+ "id": "698c1eb7-eba4-4383-98aa-931fc4ad56a4",
725
+ "metadata": {},
726
+ "outputs": [
727
+ {
728
+ "data": {
729
+ "text/html": [
730
+ "<div>\n",
731
+ "<style scoped>\n",
732
+ " .dataframe tbody tr th:only-of-type {\n",
733
+ " vertical-align: middle;\n",
734
+ " }\n",
735
+ "\n",
736
+ " .dataframe tbody tr th {\n",
737
+ " vertical-align: top;\n",
738
+ " }\n",
739
+ "\n",
740
+ " .dataframe thead th {\n",
741
+ " text-align: right;\n",
742
+ " }\n",
743
+ "</style>\n",
744
+ "<table border=\"1\" class=\"dataframe\">\n",
745
+ " <thead>\n",
746
+ " <tr style=\"text-align: right;\">\n",
747
+ " <th></th>\n",
748
+ " <th>is_hit</th>\n",
749
+ " <th>retrieved</th>\n",
750
+ " <th>expected</th>\n",
751
+ " <th>query</th>\n",
752
+ " </tr>\n",
753
+ " </thead>\n",
754
+ " <tbody>\n",
755
+ " <tr>\n",
756
+ " <th>0</th>\n",
757
+ " <td>False</td>\n",
758
+ " <td>[69a5696d-0c0e-482a-b6a9-f7b87f19945f, fa650c7...</td>\n",
759
+ " <td>6a756f03-638d-480d-8222-1a6bf3790e3c</td>\n",
760
+ " <td>011d84b2-0c26-4c5c-89d1-2a85498f30e0</td>\n",
761
+ " </tr>\n",
762
+ " <tr>\n",
763
+ " <th>1</th>\n",
764
+ " <td>True</td>\n",
765
+ " <td>[6a756f03-638d-480d-8222-1a6bf3790e3c, d89a649...</td>\n",
766
+ " <td>6a756f03-638d-480d-8222-1a6bf3790e3c</td>\n",
767
+ " <td>70c5ddd7-eb86-4a41-af70-a23d2392f48d</td>\n",
768
+ " </tr>\n",
769
+ " <tr>\n",
770
+ " <th>2</th>\n",
771
+ " <td>True</td>\n",
772
+ " <td>[c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824...</td>\n",
773
+ " <td>c83dbd8a-7e62-445e-8c12-a8ad604ff65e</td>\n",
774
+ " <td>a8f4290a-1281-4272-aab9-bf089954a45e</td>\n",
775
+ " </tr>\n",
776
+ " <tr>\n",
777
+ " <th>3</th>\n",
778
+ " <td>True</td>\n",
779
+ " <td>[c83dbd8a-7e62-445e-8c12-a8ad604ff65e, ad2e3eb...</td>\n",
780
+ " <td>c83dbd8a-7e62-445e-8c12-a8ad604ff65e</td>\n",
781
+ " <td>c1ef991a-1cc6-4dbf-b179-2df688c84301</td>\n",
782
+ " </tr>\n",
783
+ " <tr>\n",
784
+ " <th>4</th>\n",
785
+ " <td>True</td>\n",
786
+ " <td>[21778248-2ed9-4147-bdb0-a60337a1a599, c83dbd8...</td>\n",
787
+ " <td>21778248-2ed9-4147-bdb0-a60337a1a599</td>\n",
788
+ " <td>1ce25e78-c1e1-487e-9455-9418baa0b60c</td>\n",
789
+ " </tr>\n",
790
+ " </tbody>\n",
791
+ "</table>\n",
792
+ "</div>"
793
+ ],
794
+ "text/plain": [
795
+ " is_hit retrieved \\\n",
796
+ "0 False [69a5696d-0c0e-482a-b6a9-f7b87f19945f, fa650c7... \n",
797
+ "1 True [6a756f03-638d-480d-8222-1a6bf3790e3c, d89a649... \n",
798
+ "2 True [c83dbd8a-7e62-445e-8c12-a8ad604ff65e, 2177824... \n",
799
+ "3 True [c83dbd8a-7e62-445e-8c12-a8ad604ff65e, ad2e3eb... \n",
800
+ "4 True [21778248-2ed9-4147-bdb0-a60337a1a599, c83dbd8... \n",
801
+ "\n",
802
+ " expected query \n",
803
+ "0 6a756f03-638d-480d-8222-1a6bf3790e3c 011d84b2-0c26-4c5c-89d1-2a85498f30e0 \n",
804
+ "1 6a756f03-638d-480d-8222-1a6bf3790e3c 70c5ddd7-eb86-4a41-af70-a23d2392f48d \n",
805
+ "2 c83dbd8a-7e62-445e-8c12-a8ad604ff65e a8f4290a-1281-4272-aab9-bf089954a45e \n",
806
+ "3 c83dbd8a-7e62-445e-8c12-a8ad604ff65e c1ef991a-1cc6-4dbf-b179-2df688c84301 \n",
807
+ "4 21778248-2ed9-4147-bdb0-a60337a1a599 1ce25e78-c1e1-487e-9455-9418baa0b60c "
808
+ ]
809
+ },
810
+ "execution_count": 29,
811
+ "metadata": {},
812
+ "output_type": "execute_result"
813
+ }
814
+ ],
815
+ "source": [
816
+ "df_bge[:5]"
817
+ ]
818
+ },
819
+ {
820
+ "cell_type": "code",
821
+ "execution_count": 30,
822
+ "id": "9b1cb546-4605-4c48-bf4e-df812db97f13",
823
+ "metadata": {},
824
+ "outputs": [
825
+ {
826
+ "data": {
827
+ "text/plain": [
828
+ "(0.915, 200)"
829
+ ]
830
+ },
831
+ "execution_count": 30,
832
+ "metadata": {},
833
+ "output_type": "execute_result"
834
+ }
835
+ ],
836
+ "source": [
837
+ "hit_rate_bge = df_bge[\"is_hit\"].mean()\n",
838
+ "hit_rate_bge, len(df_bge)"
839
+ ]
840
+ },
841
+ {
842
+ "cell_type": "code",
843
+ "execution_count": null,
844
+ "id": "7dd69ad1-2153-4df0-93f7-807fc289d3fd",
845
+ "metadata": {},
846
+ "outputs": [],
847
+ "source": []
848
+ },
849
+ {
850
+ "cell_type": "code",
851
+ "execution_count": 31,
852
+ "id": "1b12ca3d-6ca2-41f6-9ddb-b12b9354ca83",
853
+ "metadata": {},
854
+ "outputs": [
855
+ {
856
+ "data": {
857
+ "text/plain": [
858
+ "0.7955697668171072"
859
+ ]
860
+ },
861
+ "execution_count": 31,
862
+ "metadata": {},
863
+ "output_type": "execute_result"
864
+ }
865
+ ],
866
+ "source": [
867
+ "evaluate_st(val_dataset, \"BAAI/bge-small-en-v1.5\", name=\"bge\")"
868
+ ]
869
+ },
870
+ {
871
+ "cell_type": "code",
872
+ "execution_count": null,
873
+ "id": "6023382b-0ff5-4d60-aeac-ad523153f943",
874
+ "metadata": {},
875
+ "outputs": [],
876
+ "source": []
877
+ },
878
+ {
879
+ "cell_type": "code",
880
+ "execution_count": null,
881
+ "id": "adf35a2a-3bb7-4251-9521-f35346a7c6e6",
882
+ "metadata": {},
883
+ "outputs": [],
884
+ "source": []
885
+ },
886
+ {
887
+ "cell_type": "markdown",
888
+ "id": "b3d290c2-784f-4c41-a258-e11d2c5117e7",
889
+ "metadata": {},
890
+ "source": [
891
+ "### Using BAAI bge-small model with `fine-tuning`"
892
+ ]
893
+ },
894
+ {
895
+ "cell_type": "code",
896
+ "execution_count": 32,
897
+ "id": "bd42b288-1f1f-41aa-9fd4-1ae4b1df462b",
898
+ "metadata": {},
899
+ "outputs": [
900
+ {
901
+ "data": {
902
+ "application/vnd.jupyter.widget-view+json": {
903
+ "model_id": "47dbb97a78c04f7f8fc1264c1013b5ea",
904
+ "version_major": 2,
905
+ "version_minor": 0
906
+ },
907
+ "text/plain": [
908
+ "Generating embeddings: 0%| | 0/100 [00:00<?, ?it/s]"
909
+ ]
910
+ },
911
+ "metadata": {},
912
+ "output_type": "display_data"
913
+ },
914
+ {
915
+ "data": {
916
+ "application/vnd.jupyter.widget-view+json": {
917
+ "model_id": "31c9e93debe34cc790bf32e579134a1a",
918
+ "version_major": 2,
919
+ "version_minor": 0
920
+ },
921
+ "text/plain": [
922
+ " 0%| | 0/200 [00:00<?, ?it/s]"
923
+ ]
924
+ },
925
+ "metadata": {},
926
+ "output_type": "display_data"
927
+ }
928
+ ],
929
+ "source": [
930
+ "finetuned = \"local:test_model\"\n",
931
+ "val_results_finetuned = evaluate(val_dataset, finetuned)"
932
+ ]
933
+ },
934
+ {
935
+ "cell_type": "code",
936
+ "execution_count": 33,
937
+ "id": "b1d7112d-b1b8-47db-8a4b-6c024ef99dd6",
938
+ "metadata": {},
939
+ "outputs": [],
940
+ "source": [
941
+ "df_finetuned = pd.DataFrame(val_results_finetuned)"
942
+ ]
943
+ },
944
+ {
945
+ "cell_type": "code",
946
+ "execution_count": 34,
947
+ "id": "62a4dd29-0631-4c5b-88e1-be43d48e1043",
948
+ "metadata": {},
949
+ "outputs": [
950
+ {
951
+ "data": {
952
+ "text/plain": [
953
+ "0.97"
954
+ ]
955
+ },
956
+ "execution_count": 34,
957
+ "metadata": {},
958
+ "output_type": "execute_result"
959
+ }
960
+ ],
961
+ "source": [
962
+ "hit_rate_finetuned = df_finetuned[\"is_hit\"].mean()\n",
963
+ "hit_rate_finetuned"
964
+ ]
965
+ },
966
+ {
967
+ "cell_type": "code",
968
+ "execution_count": 35,
969
+ "id": "4332594b-c861-40fb-a58b-ba36717d0519",
970
+ "metadata": {},
971
+ "outputs": [
972
+ {
973
+ "data": {
974
+ "text/plain": [
975
+ "0.8573385846534823"
976
+ ]
977
+ },
978
+ "execution_count": 35,
979
+ "metadata": {},
980
+ "output_type": "execute_result"
981
+ }
982
+ ],
983
+ "source": [
984
+ "evaluate_st(val_dataset, \"test_model\", name=\"finetuned\")"
985
+ ]
986
+ },
987
+ {
988
+ "cell_type": "code",
989
+ "execution_count": null,
990
+ "id": "b0003812-84a2-4ebd-9372-07bf874a486b",
991
+ "metadata": {},
992
+ "outputs": [],
993
+ "source": []
994
+ },
995
+ {
996
+ "cell_type": "markdown",
997
+ "id": "ae7eb6ff-181b-42c8-975c-ca3320158698",
998
+ "metadata": {},
999
+ "source": [
1000
+ "### Summary"
1001
+ ]
1002
+ },
1003
+ {
1004
+ "cell_type": "code",
1005
+ "execution_count": 36,
1006
+ "id": "3ca46cff-b186-463a-847d-a86c310268ec",
1007
+ "metadata": {},
1008
+ "outputs": [],
1009
+ "source": [
1010
+ "df_ada[\"model\"] = \"ada\"\n",
1011
+ "df_bge[\"model\"] = \"bge\"\n",
1012
+ "df_finetuned[\"model\"] = \"fine_tuned\""
1013
+ ]
1014
+ },
1015
+ {
1016
+ "cell_type": "code",
1017
+ "execution_count": 37,
1018
+ "id": "d1d3053e-2395-48a0-af59-fd27180e1e7b",
1019
+ "metadata": {},
1020
+ "outputs": [
1021
+ {
1022
+ "data": {
1023
+ "text/html": [
1024
+ "<div>\n",
1025
+ "<style scoped>\n",
1026
+ " .dataframe tbody tr th:only-of-type {\n",
1027
+ " vertical-align: middle;\n",
1028
+ " }\n",
1029
+ "\n",
1030
+ " .dataframe tbody tr th {\n",
1031
+ " vertical-align: top;\n",
1032
+ " }\n",
1033
+ "\n",
1034
+ " .dataframe thead th {\n",
1035
+ " text-align: right;\n",
1036
+ " }\n",
1037
+ "</style>\n",
1038
+ "<table border=\"1\" class=\"dataframe\">\n",
1039
+ " <thead>\n",
1040
+ " <tr style=\"text-align: right;\">\n",
1041
+ " <th></th>\n",
1042
+ " <th>is_hit</th>\n",
1043
+ " </tr>\n",
1044
+ " <tr>\n",
1045
+ " <th>model</th>\n",
1046
+ " <th></th>\n",
1047
+ " </tr>\n",
1048
+ " </thead>\n",
1049
+ " <tbody>\n",
1050
+ " <tr>\n",
1051
+ " <th>ada</th>\n",
1052
+ " <td>0.955</td>\n",
1053
+ " </tr>\n",
1054
+ " <tr>\n",
1055
+ " <th>bge</th>\n",
1056
+ " <td>0.915</td>\n",
1057
+ " </tr>\n",
1058
+ " <tr>\n",
1059
+ " <th>fine_tuned</th>\n",
1060
+ " <td>0.970</td>\n",
1061
+ " </tr>\n",
1062
+ " </tbody>\n",
1063
+ "</table>\n",
1064
+ "</div>"
1065
+ ],
1066
+ "text/plain": [
1067
+ " is_hit\n",
1068
+ "model \n",
1069
+ "ada 0.955\n",
1070
+ "bge 0.915\n",
1071
+ "fine_tuned 0.970"
1072
+ ]
1073
+ },
1074
+ "execution_count": 37,
1075
+ "metadata": {},
1076
+ "output_type": "execute_result"
1077
+ }
1078
+ ],
1079
+ "source": [
1080
+ "df_all = pd.concat([df_ada, df_bge, df_finetuned])\n",
1081
+ "df_all.groupby(\"model\").mean(\"is_hit\")"
1082
+ ]
1083
+ },
1084
+ {
1085
+ "cell_type": "code",
1086
+ "execution_count": null,
1087
+ "id": "72575c28-a221-4967-8f04-9579dcefa8f8",
1088
+ "metadata": {},
1089
+ "outputs": [],
1090
+ "source": []
1091
+ },
1092
+ {
1093
+ "cell_type": "code",
1094
+ "execution_count": 38,
1095
+ "id": "032cac38-c856-4aeb-9bbb-6d70ed53c614",
1096
+ "metadata": {},
1097
+ "outputs": [],
1098
+ "source": [
1099
+ "df_st_bge = pd.read_csv(\n",
1100
+ " \"results/Information-Retrieval_evaluation_bge_results.csv\"\n",
1101
+ ")\n",
1102
+ "df_st_finetuned = pd.read_csv(\n",
1103
+ " \"results/Information-Retrieval_evaluation_finetuned_results.csv\"\n",
1104
+ ")"
1105
+ ]
1106
+ },
1107
+ {
1108
+ "cell_type": "code",
1109
+ "execution_count": null,
1110
+ "id": "a509f239-8b28-4d0a-9101-c8de91c7943b",
1111
+ "metadata": {},
1112
+ "outputs": [],
1113
+ "source": []
1114
+ },
1115
+ {
1116
+ "cell_type": "code",
1117
+ "execution_count": 39,
1118
+ "id": "d2975262-c486-4a9a-a61f-ea535203a0f3",
1119
+ "metadata": {},
1120
+ "outputs": [
1121
+ {
1122
+ "data": {
1123
+ "text/html": [
1124
+ "<div>\n",
1125
+ "<style scoped>\n",
1126
+ " .dataframe tbody tr th:only-of-type {\n",
1127
+ " vertical-align: middle;\n",
1128
+ " }\n",
1129
+ "\n",
1130
+ " .dataframe tbody tr th {\n",
1131
+ " vertical-align: top;\n",
1132
+ " }\n",
1133
+ "\n",
1134
+ " .dataframe thead th {\n",
1135
+ " text-align: right;\n",
1136
+ " }\n",
1137
+ "</style>\n",
1138
+ "<table border=\"1\" class=\"dataframe\">\n",
1139
+ " <thead>\n",
1140
+ " <tr style=\"text-align: right;\">\n",
1141
+ " <th></th>\n",
1142
+ " <th>epoch</th>\n",
1143
+ " <th>steps</th>\n",
1144
+ " <th>cos_sim-Accuracy@1</th>\n",
1145
+ " <th>cos_sim-Accuracy@3</th>\n",
1146
+ " <th>cos_sim-Accuracy@5</th>\n",
1147
+ " <th>cos_sim-Accuracy@10</th>\n",
1148
+ " <th>cos_sim-Precision@1</th>\n",
1149
+ " <th>cos_sim-Recall@1</th>\n",
1150
+ " <th>cos_sim-Precision@3</th>\n",
1151
+ " <th>cos_sim-Recall@3</th>\n",
1152
+ " <th>...</th>\n",
1153
+ " <th>dot_score-Recall@1</th>\n",
1154
+ " <th>dot_score-Precision@3</th>\n",
1155
+ " <th>dot_score-Recall@3</th>\n",
1156
+ " <th>dot_score-Precision@5</th>\n",
1157
+ " <th>dot_score-Recall@5</th>\n",
1158
+ " <th>dot_score-Precision@10</th>\n",
1159
+ " <th>dot_score-Recall@10</th>\n",
1160
+ " <th>dot_score-MRR@10</th>\n",
1161
+ " <th>dot_score-NDCG@10</th>\n",
1162
+ " <th>dot_score-MAP@100</th>\n",
1163
+ " </tr>\n",
1164
+ " <tr>\n",
1165
+ " <th>model</th>\n",
1166
+ " <th></th>\n",
1167
+ " <th></th>\n",
1168
+ " <th></th>\n",
1169
+ " <th></th>\n",
1170
+ " <th></th>\n",
1171
+ " <th></th>\n",
1172
+ " <th></th>\n",
1173
+ " <th></th>\n",
1174
+ " <th></th>\n",
1175
+ " <th></th>\n",
1176
+ " <th></th>\n",
1177
+ " <th></th>\n",
1178
+ " <th></th>\n",
1179
+ " <th></th>\n",
1180
+ " <th></th>\n",
1181
+ " <th></th>\n",
1182
+ " <th></th>\n",
1183
+ " <th></th>\n",
1184
+ " <th></th>\n",
1185
+ " <th></th>\n",
1186
+ " <th></th>\n",
1187
+ " </tr>\n",
1188
+ " </thead>\n",
1189
+ " <tbody>\n",
1190
+ " <tr>\n",
1191
+ " <th>bge</th>\n",
1192
+ " <td>-1</td>\n",
1193
+ " <td>-1</td>\n",
1194
+ " <td>0.705</td>\n",
1195
+ " <td>0.865</td>\n",
1196
+ " <td>0.92</td>\n",
1197
+ " <td>0.96</td>\n",
1198
+ " <td>0.705</td>\n",
1199
+ " <td>0.705</td>\n",
1200
+ " <td>0.288333</td>\n",
1201
+ " <td>0.865</td>\n",
1202
+ " <td>...</td>\n",
1203
+ " <td>0.705</td>\n",
1204
+ " <td>0.288333</td>\n",
1205
+ " <td>0.865</td>\n",
1206
+ " <td>0.184</td>\n",
1207
+ " <td>0.92</td>\n",
1208
+ " <td>0.096</td>\n",
1209
+ " <td>0.96</td>\n",
1210
+ " <td>0.792935</td>\n",
1211
+ " <td>0.833595</td>\n",
1212
+ " <td>0.795570</td>\n",
1213
+ " </tr>\n",
1214
+ " <tr>\n",
1215
+ " <th>fine_tuned</th>\n",
1216
+ " <td>-1</td>\n",
1217
+ " <td>-1</td>\n",
1218
+ " <td>0.790</td>\n",
1219
+ " <td>0.900</td>\n",
1220
+ " <td>0.97</td>\n",
1221
+ " <td>0.98</td>\n",
1222
+ " <td>0.790</td>\n",
1223
+ " <td>0.790</td>\n",
1224
+ " <td>0.300000</td>\n",
1225
+ " <td>0.900</td>\n",
1226
+ " <td>...</td>\n",
1227
+ " <td>0.790</td>\n",
1228
+ " <td>0.300000</td>\n",
1229
+ " <td>0.900</td>\n",
1230
+ " <td>0.194</td>\n",
1231
+ " <td>0.97</td>\n",
1232
+ " <td>0.098</td>\n",
1233
+ " <td>0.98</td>\n",
1234
+ " <td>0.856264</td>\n",
1235
+ " <td>0.886738</td>\n",
1236
+ " <td>0.857339</td>\n",
1237
+ " </tr>\n",
1238
+ " </tbody>\n",
1239
+ "</table>\n",
1240
+ "<p>2 rows × 32 columns</p>\n",
1241
+ "</div>"
1242
+ ],
1243
+ "text/plain": [
1244
+ " epoch steps cos_sim-Accuracy@1 cos_sim-Accuracy@3 \\\n",
1245
+ "model \n",
1246
+ "bge -1 -1 0.705 0.865 \n",
1247
+ "fine_tuned -1 -1 0.790 0.900 \n",
1248
+ "\n",
1249
+ " cos_sim-Accuracy@5 cos_sim-Accuracy@10 cos_sim-Precision@1 \\\n",
1250
+ "model \n",
1251
+ "bge 0.92 0.96 0.705 \n",
1252
+ "fine_tuned 0.97 0.98 0.790 \n",
1253
+ "\n",
1254
+ " cos_sim-Recall@1 cos_sim-Precision@3 cos_sim-Recall@3 ... \\\n",
1255
+ "model ... \n",
1256
+ "bge 0.705 0.288333 0.865 ... \n",
1257
+ "fine_tuned 0.790 0.300000 0.900 ... \n",
1258
+ "\n",
1259
+ " dot_score-Recall@1 dot_score-Precision@3 dot_score-Recall@3 \\\n",
1260
+ "model \n",
1261
+ "bge 0.705 0.288333 0.865 \n",
1262
+ "fine_tuned 0.790 0.300000 0.900 \n",
1263
+ "\n",
1264
+ " dot_score-Precision@5 dot_score-Recall@5 dot_score-Precision@10 \\\n",
1265
+ "model \n",
1266
+ "bge 0.184 0.92 0.096 \n",
1267
+ "fine_tuned 0.194 0.97 0.098 \n",
1268
+ "\n",
1269
+ " dot_score-Recall@10 dot_score-MRR@10 dot_score-NDCG@10 \\\n",
1270
+ "model \n",
1271
+ "bge 0.96 0.792935 0.833595 \n",
1272
+ "fine_tuned 0.98 0.856264 0.886738 \n",
1273
+ "\n",
1274
+ " dot_score-MAP@100 \n",
1275
+ "model \n",
1276
+ "bge 0.795570 \n",
1277
+ "fine_tuned 0.857339 \n",
1278
+ "\n",
1279
+ "[2 rows x 32 columns]"
1280
+ ]
1281
+ },
1282
+ "execution_count": 39,
1283
+ "metadata": {},
1284
+ "output_type": "execute_result"
1285
+ }
1286
+ ],
1287
+ "source": [
1288
+ "df_st_bge[\"model\"] = \"bge\"\n",
1289
+ "df_st_finetuned[\"model\"] = \"fine_tuned\"\n",
1290
+ "df_st_all = pd.concat([df_st_bge, df_st_finetuned])\n",
1291
+ "df_st_all = df_st_all.set_index(\"model\")\n",
1292
+ "df_st_all"
1293
+ ]
1294
+ },
1295
+ {
1296
+ "cell_type": "code",
1297
+ "execution_count": null,
1298
+ "id": "6ed2321b-6618-4a2b-9b1c-028425e91b84",
1299
+ "metadata": {},
1300
+ "outputs": [],
1301
+ "source": []
1302
  }
1303
  ],
1304
  "metadata": {