dragonkue committed on
Commit
66903f6
·
verified ·
1 Parent(s): bc7058c

Upload cross_encoder_eval.ipynb

Browse files
Files changed (1) hide show
  1. cross_encoder_eval.ipynb +236 -0
cross_encoder_eval.ipynb ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "!pip install numpy pandas FlagEmbedding scikit-learn"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "import pandas as pd\n",
17
+ "import numpy as np\n",
18
+ "from sklearn.metrics import precision_score, recall_score, f1_score\n",
19
+ "from FlagEmbedding import FlagReranker\n",
20
+ "import json"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": [
29
+ "model_path = '...'\n",
30
+ "qd_df = pd.read_parquet('AutoRAG-example-korean-embedding-benchmark/data/qa_v4.parquet')\n",
31
+ "qd_df['generation_gt'].apply(lambda x : len(x)).describe()\n",
32
+ "qd_df['retrieval_gt'].apply(lambda x : len(x[0])).describe()\n",
33
+ "qd_df['retrieval_gt'] = qd_df['retrieval_gt'].apply(lambda x : x[0][0])\n",
34
+ "\n",
35
+ "corpus_df = pd.read_parquet('AutoRAG-example-korean-embedding-benchmark/data/ocr_corpus_v3.parquet')\n",
36
+ "corpus_id = {}\n",
37
+ "for idx, row in corpus_df.iterrows():\n",
38
+ " corpus_id[row[0]] = row[1]\n"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": 2,
44
+ "metadata": {},
45
+ "outputs": [
46
+ {
47
+ "name": "stderr",
48
+ "output_type": "stream",
49
+ "text": [
50
+ "/tmp/ipykernel_3861538/48936308.py:10: FutureWarning: Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`\n",
51
+ " corpus_id[row[0]] = row[1]\n",
52
+ "/tmp/ipykernel_3861538/48936308.py:18: FutureWarning: Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`\n",
53
+ " query_id[row[0]] = row[1]\n"
54
+ ]
55
+ }
56
+ ],
57
+ "source": [
58
+ "qd_df = qd_df[['qid','query','generation_gt','retrieval_gt']]\n",
59
+ "\n",
60
+ "query_id = {}\n",
61
+ "for idx, row in qd_df.iterrows():\n",
62
+ " query_id[row[0]] = row[1]\n",
63
+ "\n",
64
+ "qrel = qd_df[['qid','retrieval_gt']]\n",
65
+ "qrel_id = {}\n",
66
+ "for idx, row in qrel.iterrows():\n",
67
+ " q_id = row.iloc[0]\n",
68
+ " relevant_copus_id = row.iloc[1]\n",
69
+ " if q_id not in qrel_id:\n",
70
+ " qrel_id[q_id] = set()\n",
71
+ " qrel_id[q_id].add(relevant_copus_id)\n",
72
+ "\n",
73
+ "corpus_df = corpus_df[['doc_id','contents']]\n",
74
+ "\n",
75
+ "valid_dict = {}\n",
76
+ "valid_dict['qrel'] =qrel_id"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": 3,
82
+ "metadata": {},
83
+ "outputs": [
84
+ {
85
+ "data": {
86
+ "text/plain": [
87
+ "doc_id commerce - B2BDigComm.pdf - 1\n",
88
+ "contents Adobe\\n디지털 커머스 시대,\\nB2B 비즈니스 생존 전략\\nB2B 비즈니스를 ...\n",
89
+ "Name: 0, dtype: object"
90
+ ]
91
+ },
92
+ "execution_count": 3,
93
+ "metadata": {},
94
+ "output_type": "execute_result"
95
+ }
96
+ ],
97
+ "source": [
98
+ "corpus_df.iloc[0]"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": 4,
104
+ "metadata": {},
105
+ "outputs": [
106
+ {
107
+ "data": {
108
+ "text/plain": [
109
+ "Index(['qid', 'query', 'generation_gt', 'retrieval_gt'], dtype='object')"
110
+ ]
111
+ },
112
+ "execution_count": 4,
113
+ "metadata": {},
114
+ "output_type": "execute_result"
115
+ }
116
+ ],
117
+ "source": [
118
+ "qd_df.columns"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": 7,
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "corpus_df = corpus_df.reset_index(drop=True)\n",
128
+ "qd_df = qd_df.reset_index(drop=True)"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "metadata": {},
135
+ "outputs": [],
136
+ "source": [
137
+ "def calculate_accuracy(ranks_list, valid_dict, qd_df, k_values=[1, 3, 5]):\n",
138
+ " accuracies = {k: 0 for k in k_values}\n",
139
+ " total_queries = len(qd_df)\n",
140
+ " \n",
141
+ " for i in range(total_queries):\n",
142
+ " search_idx = ranks_list[i]\n",
143
+ " true_doc_idx = corpus_df[corpus_df['doc_id'] == list(valid_dict['qrel'][qd_df.loc[i, 'qid']])[0]].index[0]\n",
144
+ " \n",
145
+ " for k in k_values:\n",
146
+ " top_k_preds = search_idx[:k]\n",
147
+ " if true_doc_idx in top_k_preds:\n",
148
+ " accuracies[k] += 1\n",
149
+ " \n",
150
+ " return {k: accuracies[k] / total_queries for k in k_values}\n",
151
+ "\n",
152
+ "def calculate_f1_recall_precision(ranks_list, valid_dict, qd_df, k_values=[1, 3, 5]):\n",
153
+ " f1_scores = {k: 0 for k in k_values}\n",
154
+ " recall_scores = {k: 0 for k in k_values}\n",
155
+ " precision_scores = {k: 0 for k in k_values}\n",
156
+ " \n",
157
+ " total_queries = len(qd_df)\n",
158
+ " \n",
159
+ " for i in range(total_queries):\n",
160
+ " search_idx = ranks_list[i]\n",
161
+ " true_doc_idx = corpus_df[corpus_df['doc_id'] == list(valid_dict['qrel'][qd_df.loc[i, 'qid']])[0]].index[0]\n",
162
+ " \n",
163
+ " for k in k_values:\n",
164
+ " top_k_preds = search_idx[:k]\n",
165
+ " y_true = [1 if idx == true_doc_idx else 0 for idx in top_k_preds]\n",
166
+ " y_pred = [1] * len(top_k_preds)\n",
167
+ " \n",
168
+ " # Precision, Recall, F1\n",
169
+ " precision_scores[k] += precision_score(y_true, y_pred)\n",
170
+ " recall_scores[k] += recall_score(y_true, y_pred)\n",
171
+ " f1_scores[k] += f1_score(y_true, y_pred)\n",
172
+ " \n",
173
+ " return {k: f1_scores[k] / total_queries for k in k_values}, \\\n",
174
+ " {k: recall_scores[k] / total_queries for k in k_values}, \\\n",
175
+ " {k: precision_scores[k] / total_queries for k in k_values}\n",
176
+ "\n",
177
+ "\n",
178
+ "def evaluate_model(corpus_df, qd_df, valid_dict, reranker):\n",
179
+ " scores_list = []\n",
180
+ " ranks_list = []\n",
181
+ " \n",
182
+ " for c, query in enumerate(qd_df['query'], start=1):\n",
183
+ " corpus_df['query'] = query\n",
184
+ " pair_df = corpus_df[['query', 'contents']]\n",
185
+ " scores = reranker.compute_score(pair_df.values.tolist(), normalize=True)\n",
186
+ " scores = np.array(scores)\n",
187
+ " \n",
188
+ " sorted_idxs = np.argsort(-scores)\n",
189
+ " scores_list.append(scores[sorted_idxs])\n",
190
+ " ranks_list.append(sorted_idxs)\n",
191
+ " print(f'{c}/{len(qd_df)}')\n",
192
+ "\n",
193
+ " k_values = [1, 3, 5, 10]\n",
194
+ " accuracies = calculate_accuracy(ranks_list, valid_dict, qd_df, k_values=k_values)\n",
195
+ " f1_scores, recalls, precisions = calculate_f1_recall_precision(ranks_list, valid_dict, qd_df, k_values=k_values)\n",
196
+ " \n",
197
+ " return accuracies, f1_scores, recalls, precisions\n",
198
+ "\n",
199
+ "\n",
200
+ "# 모델 평가 및 결과 저장\n",
201
+ "reranker = FlagReranker(model_path, use_fp16=True)\n",
202
+ "\n",
203
+ "accuracies, f1_scores, recalls, precisions = evaluate_model(\n",
204
+ " corpus_df.copy(), qd_df, valid_dict, reranker)\n",
205
+ "\n",
206
+ "print(f'Model: {model_path}')\n",
207
+ "for k in [1, 3, 5, 10]:\n",
208
+ " print(f'Accuracy@{k}: {accuracies[k]:.4f}')\n",
209
+ " print(f'F1@{k}: {f1_scores[k]:.4f}')\n",
210
+ " print(f'Recall@{k}: {recalls[k]:.4f}')\n",
211
+ " print(f'Precision@{k}: {precisions[k]:.4f}')\n"
212
+ ]
213
+ }
214
+ ],
215
+ "metadata": {
216
+ "kernelspec": {
217
+ "display_name": "sbert3",
218
+ "language": "python",
219
+ "name": "python3"
220
+ },
221
+ "language_info": {
222
+ "codemirror_mode": {
223
+ "name": "ipython",
224
+ "version": 3
225
+ },
226
+ "file_extension": ".py",
227
+ "mimetype": "text/x-python",
228
+ "name": "python",
229
+ "nbconvert_exporter": "python",
230
+ "pygments_lexer": "ipython3",
231
+ "version": "3.9.19"
232
+ }
233
+ },
234
+ "nbformat": 4,
235
+ "nbformat_minor": 2
236
+ }