kenken999 commited on
Commit
acac1f6
·
1 Parent(s): ac6fdb9
Files changed (1) hide show
  1. polls/test.ipynb +427 -0
polls/test.ipynb ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "Requirement already satisfied: datasets in /usr/local/lib/python3.10/site-packages (2.19.2)\n",
13
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/site-packages (from datasets) (1.26.4)\n",
14
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/site-packages (from datasets) (6.0.1)\n",
15
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/site-packages (from datasets) (3.4.1)\n",
16
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/site-packages (from datasets) (3.9.5)\n",
17
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/site-packages (from datasets) (0.3.8)\n",
18
+ "Requirement already satisfied: fsspec[http]<=2024.3.1,>=2023.1.0 in /usr/local/lib/python3.10/site-packages (from datasets) (2024.3.1)\n",
19
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/site-packages (from datasets) (24.0)\n",
20
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/site-packages (from datasets) (3.14.0)\n",
21
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.10/site-packages (from datasets) (2.2.2)\n",
22
+ "Requirement already satisfied: requests>=2.32.1 in /usr/local/lib/python3.10/site-packages (from datasets) (2.32.3)\n",
23
+ "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/site-packages (from datasets) (4.66.4)\n",
24
+ "Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.10/site-packages (from datasets) (16.1.0)\n",
25
+ "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/site-packages (from datasets) (0.6)\n",
26
+ "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/site-packages (from datasets) (0.70.16)\n",
27
+ "Requirement already satisfied: huggingface-hub>=0.21.2 in /usr/local/lib/python3.10/site-packages (from datasets) (0.23.3)\n",
28
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/site-packages (from aiohttp->datasets) (1.3.1)\n",
29
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/site-packages (from aiohttp->datasets) (23.2.0)\n",
30
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/site-packages (from aiohttp->datasets) (1.4.1)\n",
31
+ "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/site-packages (from aiohttp->datasets) (4.0.3)\n",
32
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/site-packages (from aiohttp->datasets) (6.0.5)\n",
33
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/site-packages (from aiohttp->datasets) (1.9.4)\n",
34
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/site-packages (from huggingface-hub>=0.21.2->datasets) (4.10.0)\n",
35
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/site-packages (from requests>=2.32.1->datasets) (2024.2.2)\n",
36
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/site-packages (from requests>=2.32.1->datasets) (3.6)\n",
37
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/site-packages (from requests>=2.32.1->datasets) (3.3.2)\n",
38
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/site-packages (from requests>=2.32.1->datasets) (2.2.1)\n",
39
+ "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/site-packages (from pandas->datasets) (2024.1)\n",
40
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/site-packages (from pandas->datasets) (2024.1)\n",
41
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/site-packages (from pandas->datasets) (2.9.0.post0)\n",
42
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
43
+ "\n",
44
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
45
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
46
+ ]
47
+ }
48
+ ],
49
+ "source": [
50
+ "!pip install datasets"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": 2,
56
+ "metadata": {},
57
+ "outputs": [
58
+ {
59
+ "name": "stderr",
60
+ "output_type": "stream",
61
+ "text": [
62
+ "/usr/local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
63
+ " from .autonotebook import tqdm as notebook_tqdm\n"
64
+ ]
65
+ }
66
+ ],
67
+ "source": [
68
+ "import pandas as pd\n",
69
+ "from datasets import Dataset\n",
70
+ "\n",
71
+ "# QAペアのデータセットを作成\n",
72
+ "data = {\n",
73
+ " \"question\": [\"What is the capital of France?\", \"Who wrote 1984?\", \"What is the largest planet in our solar system?\"],\n",
74
+ " \"answer\": [\"Paris\", \"George Orwell\", \"Jupiter\"]\n",
75
+ "}\n",
76
+ "\n",
77
+ "df = pd.DataFrame(data)\n",
78
+ "dataset = Dataset.from_pandas(df)\n"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": 3,
84
+ "metadata": {},
85
+ "outputs": [
86
+ {
87
+ "name": "stderr",
88
+ "output_type": "stream",
89
+ "text": [
90
+ "/usr/local/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
91
+ " warnings.warn(\n",
92
+ "Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']\n",
93
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
94
+ ]
95
+ }
96
+ ],
97
+ "source": [
98
+ "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, Trainer, TrainingArguments\n",
99
+ "\n",
100
+ "model_name = \"distilbert-base-uncased\"\n",
101
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
102
+ "model = AutoModelForQuestionAnswering.from_pretrained(model_name)\n"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": 5,
108
+ "metadata": {},
109
+ "outputs": [
110
+ {
111
+ "name": "stderr",
112
+ "output_type": "stream",
113
+ "text": [
114
+ "Map: 100%|██████████| 3/3 [00:00<00:00, 576.25 examples/s]\n"
115
+ ]
116
+ },
117
+ {
118
+ "ename": "ValueError",
119
+ "evalue": "The model did not return a loss from the inputs, only the following keys: start_logits,end_logits. For reference, the inputs it received are input_ids,attention_mask.",
120
+ "output_type": "error",
121
+ "traceback": [
122
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
123
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
124
+ "\u001b[1;32m/home/user/app/polls/test.ipynb Cell 4\u001b[0m line \u001b[0;36m2\n\u001b[1;32m <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#W3sdnNjb2RlLXJlbW90ZQ%3D%3D?line=9'>10</a>\u001b[0m training_args \u001b[39m=\u001b[39m TrainingArguments(\n\u001b[1;32m <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#W3sdnNjb2RlLXJlbW90ZQ%3D%3D?line=10'>11</a>\u001b[0m output_dir\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m./results\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#W3sdnNjb2RlLXJlbW90ZQ%3D%3D?line=11'>12</a>\u001b[0m evaluation_strategy\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mepoch\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#W3sdnNjb2RlLXJlbW90ZQ%3D%3D?line=15'>16</a>\u001b[0m weight_decay\u001b[39m=\u001b[39m\u001b[39m0.01\u001b[39m,\n\u001b[1;32m <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#W3sdnNjb2RlLXJlbW90ZQ%3D%3D?line=16'>17</a>\u001b[0m )\n\u001b[1;32m <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#W3sdnNjb2RlLXJlbW90ZQ%3D%3D?line=18'>19</a>\u001b[0m trainer \u001b[39m=\u001b[39m Trainer(\n\u001b[1;32m <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#W3sdnNjb2RlLXJlbW90ZQ%3D%3D?line=19'>20</a>\u001b[0m model\u001b[39m=\u001b[39mmodel,\n\u001b[1;32m <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#W3sdnNjb2RlLXJlbW90ZQ%3D%3D?line=20'>21</a>\u001b[0m args\u001b[39m=\u001b[39mtraining_args,\n\u001b[1;32m <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#W3sdnNjb2RlLXJlbW90ZQ%3D%3D?line=21'>22</a>\u001b[0m train_dataset\u001b[39m=\u001b[39mtokenized_dataset,\n\u001b[1;32m <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#W3sdnNjb2RlLXJlbW90ZQ%3D%3D?line=22'>23</a>\u001b[0m eval_dataset\u001b[39m=\u001b[39mtokenized_dataset,\n\u001b[1;32m <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#W3sdnNjb2RlLXJlbW90ZQ%3D%3D?line=23'>24</a>\u001b[0m )\n\u001b[0;32m---> <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#W3sdnNjb2RlLXJlbW90ZQ%3D%3D?line=25'>26</a>\u001b[0m trainer\u001b[39m.\u001b[39;49mtrain()\n",
125
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/transformers/trainer.py:1885\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1883\u001b[0m hf_hub_utils\u001b[39m.\u001b[39menable_progress_bars()\n\u001b[1;32m 1884\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1885\u001b[0m \u001b[39mreturn\u001b[39;00m inner_training_loop(\n\u001b[1;32m 1886\u001b[0m args\u001b[39m=\u001b[39;49margs,\n\u001b[1;32m 1887\u001b[0m resume_from_checkpoint\u001b[39m=\u001b[39;49mresume_from_checkpoint,\n\u001b[1;32m 1888\u001b[0m trial\u001b[39m=\u001b[39;49mtrial,\n\u001b[1;32m 1889\u001b[0m ignore_keys_for_eval\u001b[39m=\u001b[39;49mignore_keys_for_eval,\n\u001b[1;32m 1890\u001b[0m )\n",
126
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/transformers/trainer.py:2216\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2213\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcontrol \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcallback_handler\u001b[39m.\u001b[39mon_step_begin(args, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcontrol)\n\u001b[1;32m 2215\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39maccelerator\u001b[39m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 2216\u001b[0m tr_loss_step \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtraining_step(model, inputs)\n\u001b[1;32m 2218\u001b[0m \u001b[39mif\u001b[39;00m (\n\u001b[1;32m 2219\u001b[0m args\u001b[39m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 2220\u001b[0m \u001b[39mand\u001b[39;00m \u001b[39mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m 2221\u001b[0m \u001b[39mand\u001b[39;00m (torch\u001b[39m.\u001b[39misnan(tr_loss_step) \u001b[39mor\u001b[39;00m torch\u001b[39m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 2222\u001b[0m ):\n\u001b[1;32m 2223\u001b[0m \u001b[39m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 2224\u001b[0m tr_loss \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m tr_loss \u001b[39m/\u001b[39m (\u001b[39m1\u001b[39m \u001b[39m+\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstate\u001b[39m.\u001b[39mglobal_step \u001b[39m-\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_globalstep_last_logged)\n",
127
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/transformers/trainer.py:3238\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 3235\u001b[0m \u001b[39mreturn\u001b[39;00m loss_mb\u001b[39m.\u001b[39mreduce_mean()\u001b[39m.\u001b[39mdetach()\u001b[39m.\u001b[39mto(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39margs\u001b[39m.\u001b[39mdevice)\n\u001b[1;32m 3237\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcompute_loss_context_manager():\n\u001b[0;32m-> 3238\u001b[0m loss \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcompute_loss(model, inputs)\n\u001b[1;32m 3240\u001b[0m \u001b[39mdel\u001b[39;00m inputs\n\u001b[1;32m 3241\u001b[0m torch\u001b[39m.\u001b[39mcuda\u001b[39m.\u001b[39mempty_cache()\n",
128
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/transformers/trainer.py:3282\u001b[0m, in \u001b[0;36mTrainer.compute_loss\u001b[0;34m(self, model, inputs, return_outputs)\u001b[0m\n\u001b[1;32m 3280\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 3281\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(outputs, \u001b[39mdict\u001b[39m) \u001b[39mand\u001b[39;00m \u001b[39m\"\u001b[39m\u001b[39mloss\u001b[39m\u001b[39m\"\u001b[39m \u001b[39mnot\u001b[39;00m \u001b[39min\u001b[39;00m outputs:\n\u001b[0;32m-> 3282\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 3283\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mThe model did not return a loss from the inputs, only the following keys: \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 3284\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m,\u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mjoin(outputs\u001b[39m.\u001b[39mkeys())\u001b[39m}\u001b[39;00m\u001b[39m. For reference, the inputs it received are \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m,\u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mjoin(inputs\u001b[39m.\u001b[39mkeys())\u001b[39m}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 3285\u001b[0m )\n\u001b[1;32m 3286\u001b[0m \u001b[39m# We don't use .loss here since the model may return tuples instead of ModelOutput.\u001b[39;00m\n\u001b[1;32m 3287\u001b[0m loss \u001b[39m=\u001b[39m outputs[\u001b[39m\"\u001b[39m\u001b[39mloss\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(outputs, \u001b[39mdict\u001b[39m) \u001b[39melse\u001b[39;00m outputs[\u001b[39m0\u001b[39m]\n",
129
+ "\u001b[0;31mValueError\u001b[0m: The model did not return a loss from the inputs, only the following keys: start_logits,end_logits. For reference, the inputs it received are input_ids,attention_mask."
130
+ ]
131
+ }
132
+ ],
133
+ "source": [
134
+ "def preprocess_function(examples):\n",
135
+ " questions = examples[\"question\"]\n",
136
+ " answers = examples[\"answer\"]\n",
137
+ " inputs = tokenizer(questions, truncation=True, padding=True)\n",
138
+ " inputs[\"labels\"] = tokenizer(answers, truncation=True, padding=True)[\"input_ids\"]\n",
139
+ " return inputs\n",
140
+ "\n",
141
+ "tokenized_dataset = dataset.map(preprocess_function, batched=True)\n",
142
+ "\n",
143
+ "training_args = TrainingArguments(\n",
144
+ " output_dir=\"./results\",\n",
145
+ " evaluation_strategy=\"epoch\",\n",
146
+ " learning_rate=2e-5,\n",
147
+ " per_device_train_batch_size=2,\n",
148
+ " num_train_epochs=3,\n",
149
+ " weight_decay=0.01,\n",
150
+ ")\n",
151
+ "\n",
152
+ "trainer = Trainer(\n",
153
+ " model=model,\n",
154
+ " args=training_args,\n",
155
+ " train_dataset=tokenized_dataset,\n",
156
+ " eval_dataset=tokenized_dataset,\n",
157
+ ")\n",
158
+ "\n",
159
+ "trainer.train()\n"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": 6,
165
+ "metadata": {},
166
+ "outputs": [
167
+ {
168
+ "name": "stderr",
169
+ "output_type": "stream",
170
+ "text": [
171
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
172
+ "To disable this warning, you can either:\n",
173
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
174
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
175
+ ]
176
+ },
177
+ {
178
+ "name": "stdout",
179
+ "output_type": "stream",
180
+ "text": [
181
+ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/site-packages (4.41.2)\n",
182
+ "Requirement already satisfied: datasets in /usr/local/lib/python3.10/site-packages (2.19.2)\n",
183
+ "Collecting faiss-cpu\n",
184
+ " Downloading faiss_cpu-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (27.0 MB)\n",
185
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m27.0/27.0 MB\u001b[0m \u001b[31m64.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
186
+ "\u001b[?25hRequirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/site-packages (from transformers) (6.0.1)\n",
187
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/site-packages (from transformers) (2024.5.15)\n",
188
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/site-packages (from transformers) (2.32.3)\n",
189
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/site-packages (from transformers) (3.14.0)\n",
190
+ "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/site-packages (from transformers) (0.19.1)\n",
191
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.23.0 in /usr/local/lib/python3.10/site-packages (from transformers) (0.23.3)\n",
192
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/site-packages (from transformers) (4.66.4)\n",
193
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/site-packages (from transformers) (1.26.4)\n",
194
+ "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/site-packages (from transformers) (0.4.3)\n",
195
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/site-packages (from transformers) (24.0)\n",
196
+ "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/site-packages (from datasets) (0.70.16)\n",
197
+ "Requirement already satisfied: fsspec[http]<=2024.3.1,>=2023.1.0 in /usr/local/lib/python3.10/site-packages (from datasets) (2024.3.1)\n",
198
+ "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/site-packages (from datasets) (0.6)\n",
199
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.10/site-packages (from datasets) (2.2.2)\n",
200
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/site-packages (from datasets) (3.4.1)\n",
201
+ "Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.10/site-packages (from datasets) (16.1.0)\n",
202
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/site-packages (from datasets) (0.3.8)\n",
203
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/site-packages (from datasets) (3.9.5)\n",
204
+ "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/site-packages (from aiohttp->datasets) (4.0.3)\n",
205
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/site-packages (from aiohttp->datasets) (1.4.1)\n",
206
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/site-packages (from aiohttp->datasets) (23.2.0)\n",
207
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/site-packages (from aiohttp->datasets) (1.3.1)\n",
208
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/site-packages (from aiohttp->datasets) (1.9.4)\n",
209
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/site-packages (from aiohttp->datasets) (6.0.5)\n",
210
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.23.0->transformers) (4.10.0)\n",
211
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/site-packages (from requests->transformers) (3.3.2)\n",
212
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/site-packages (from requests->transformers) (3.6)\n",
213
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/site-packages (from requests->transformers) (2.2.1)\n",
214
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/site-packages (from requests->transformers) (2024.2.2)\n",
215
+ "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/site-packages (from pandas->datasets) (2024.1)\n",
216
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/site-packages (from pandas->datasets) (2024.1)\n",
217
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/site-packages (from pandas->datasets) (2.9.0.post0)\n",
218
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
219
+ "Installing collected packages: faiss-cpu\n",
220
+ "Successfully installed faiss-cpu-1.8.0\n",
221
+ "\n",
222
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
223
+ "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
224
+ ]
225
+ }
226
+ ],
227
+ "source": [
228
+ "!pip install transformers datasets faiss-cpu\n"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": 8,
234
+ "metadata": {},
235
+ "outputs": [
236
+ {
237
+ "ename": "ValueError",
238
+ "evalue": "Loading wiki_dpr requires you to execute the dataset script in that repo on your local machine. Make sure you have read the code there to avoid malicious use, then set the option `trust_remote_code=True` to remove this error.",
239
+ "output_type": "error",
240
+ "traceback": [
241
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
242
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
243
+ "\u001b[1;32m/home/user/app/polls/test.ipynb Cell 6\u001b[0m line \u001b[0;36m4\n\u001b[1;32m <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#W5sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mdatasets\u001b[39;00m \u001b[39mimport\u001b[39;00m load_dataset\n\u001b[1;32m <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#W5sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2'>3</a>\u001b[0m \u001b[39m# データセットのロード\u001b[39;00m\n\u001b[0;32m----> <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#W5sdnNjb2RlLXJlbW90ZQ%3D%3D?line=3'>4</a>\u001b[0m dataset \u001b[39m=\u001b[39m load_dataset(\u001b[39m'\u001b[39;49m\u001b[39mwiki_dpr\u001b[39;49m\u001b[39m'\u001b[39;49m, \u001b[39m'\u001b[39;49m\u001b[39mpsgs_w100\u001b[39;49m\u001b[39m'\u001b[39;49m)\n",
244
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/load.py:2592\u001b[0m, in \u001b[0;36mload_dataset\u001b[0;34m(path, name, data_dir, data_files, split, cache_dir, features, download_config, download_mode, verification_mode, ignore_verifications, keep_in_memory, save_infos, revision, token, use_auth_token, task, streaming, num_proc, storage_options, trust_remote_code, **config_kwargs)\u001b[0m\n\u001b[1;32m 2587\u001b[0m verification_mode \u001b[39m=\u001b[39m VerificationMode(\n\u001b[1;32m 2588\u001b[0m (verification_mode \u001b[39mor\u001b[39;00m VerificationMode\u001b[39m.\u001b[39mBASIC_CHECKS) \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m save_infos \u001b[39melse\u001b[39;00m VerificationMode\u001b[39m.\u001b[39mALL_CHECKS\n\u001b[1;32m 2589\u001b[0m )\n\u001b[1;32m 2591\u001b[0m \u001b[39m# Create a dataset builder\u001b[39;00m\n\u001b[0;32m-> 2592\u001b[0m builder_instance \u001b[39m=\u001b[39m load_dataset_builder(\n\u001b[1;32m 2593\u001b[0m path\u001b[39m=\u001b[39;49mpath,\n\u001b[1;32m 2594\u001b[0m name\u001b[39m=\u001b[39;49mname,\n\u001b[1;32m 2595\u001b[0m data_dir\u001b[39m=\u001b[39;49mdata_dir,\n\u001b[1;32m 2596\u001b[0m data_files\u001b[39m=\u001b[39;49mdata_files,\n\u001b[1;32m 2597\u001b[0m cache_dir\u001b[39m=\u001b[39;49mcache_dir,\n\u001b[1;32m 2598\u001b[0m features\u001b[39m=\u001b[39;49mfeatures,\n\u001b[1;32m 2599\u001b[0m download_config\u001b[39m=\u001b[39;49mdownload_config,\n\u001b[1;32m 2600\u001b[0m download_mode\u001b[39m=\u001b[39;49mdownload_mode,\n\u001b[1;32m 2601\u001b[0m revision\u001b[39m=\u001b[39;49mrevision,\n\u001b[1;32m 2602\u001b[0m token\u001b[39m=\u001b[39;49mtoken,\n\u001b[1;32m 2603\u001b[0m storage_options\u001b[39m=\u001b[39;49mstorage_options,\n\u001b[1;32m 2604\u001b[0m trust_remote_code\u001b[39m=\u001b[39;49mtrust_remote_code,\n\u001b[1;32m 2605\u001b[0m _require_default_config_name\u001b[39m=\u001b[39;49mname \u001b[39mis\u001b[39;49;00m \u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 2606\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mconfig_kwargs,\n\u001b[1;32m 2607\u001b[0m )\n\u001b[1;32m 2609\u001b[0m \u001b[39m# Return iterable dataset in case of streaming\u001b[39;00m\n\u001b[1;32m 2610\u001b[0m \u001b[39mif\u001b[39;00m streaming:\n",
245
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/load.py:2264\u001b[0m, in \u001b[0;36mload_dataset_builder\u001b[0;34m(path, name, data_dir, data_files, cache_dir, features, download_config, download_mode, revision, token, use_auth_token, storage_options, trust_remote_code, _require_default_config_name, **config_kwargs)\u001b[0m\n\u001b[1;32m 2262\u001b[0m download_config \u001b[39m=\u001b[39m download_config\u001b[39m.\u001b[39mcopy() \u001b[39mif\u001b[39;00m download_config \u001b[39melse\u001b[39;00m DownloadConfig()\n\u001b[1;32m 2263\u001b[0m download_config\u001b[39m.\u001b[39mstorage_options\u001b[39m.\u001b[39mupdate(storage_options)\n\u001b[0;32m-> 2264\u001b[0m dataset_module \u001b[39m=\u001b[39m dataset_module_factory(\n\u001b[1;32m 2265\u001b[0m path,\n\u001b[1;32m 2266\u001b[0m revision\u001b[39m=\u001b[39;49mrevision,\n\u001b[1;32m 2267\u001b[0m download_config\u001b[39m=\u001b[39;49mdownload_config,\n\u001b[1;32m 2268\u001b[0m download_mode\u001b[39m=\u001b[39;49mdownload_mode,\n\u001b[1;32m 2269\u001b[0m data_dir\u001b[39m=\u001b[39;49mdata_dir,\n\u001b[1;32m 2270\u001b[0m data_files\u001b[39m=\u001b[39;49mdata_files,\n\u001b[1;32m 2271\u001b[0m cache_dir\u001b[39m=\u001b[39;49mcache_dir,\n\u001b[1;32m 2272\u001b[0m trust_remote_code\u001b[39m=\u001b[39;49mtrust_remote_code,\n\u001b[1;32m 2273\u001b[0m _require_default_config_name\u001b[39m=\u001b[39;49m_require_default_config_name,\n\u001b[1;32m 2274\u001b[0m _require_custom_configs\u001b[39m=\u001b[39;49m\u001b[39mbool\u001b[39;49m(config_kwargs),\n\u001b[1;32m 2275\u001b[0m )\n\u001b[1;32m 2276\u001b[0m \u001b[39m# Get dataset builder class from the processing script\u001b[39;00m\n\u001b[1;32m 2277\u001b[0m builder_kwargs \u001b[39m=\u001b[39m dataset_module\u001b[39m.\u001b[39mbuilder_kwargs\n",
246
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/load.py:1915\u001b[0m, in \u001b[0;36mdataset_module_factory\u001b[0;34m(path, revision, download_config, download_mode, dynamic_modules_path, data_dir, data_files, cache_dir, trust_remote_code, _require_default_config_name, _require_custom_configs, **download_kwargs)\u001b[0m\n\u001b[1;32m 1910\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(e1, \u001b[39mFileNotFoundError\u001b[39;00m):\n\u001b[1;32m 1911\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mFileNotFoundError\u001b[39;00m(\n\u001b[1;32m 1912\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mCouldn\u001b[39m\u001b[39m'\u001b[39m\u001b[39mt find a dataset script at \u001b[39m\u001b[39m{\u001b[39;00mrelative_to_absolute_path(combined_path)\u001b[39m}\u001b[39;00m\u001b[39m or any data file in the same directory. \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 1913\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mCouldn\u001b[39m\u001b[39m'\u001b[39m\u001b[39mt find \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mpath\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m on the Hugging Face Hub either: \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mtype\u001b[39m(e1)\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m: \u001b[39m\u001b[39m{\u001b[39;00me1\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[1;32m 1914\u001b[0m ) \u001b[39mfrom\u001b[39;00m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m-> 1915\u001b[0m \u001b[39mraise\u001b[39;00m e1 \u001b[39mfrom\u001b[39;00m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 1916\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 1917\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mFileNotFoundError\u001b[39;00m(\n\u001b[1;32m 1918\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mCouldn\u001b[39m\u001b[39m'\u001b[39m\u001b[39mt find a dataset script at \u001b[39m\u001b[39m{\u001b[39;00mrelative_to_absolute_path(combined_path)\u001b[39m}\u001b[39;00m\u001b[39m or any data file in the same directory.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 1919\u001b[0m )\n",
247
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/load.py:1888\u001b[0m, in \u001b[0;36mdataset_module_factory\u001b[0;34m(path, revision, download_config, download_mode, dynamic_modules_path, data_dir, data_files, cache_dir, trust_remote_code, _require_default_config_name, _require_custom_configs, **download_kwargs)\u001b[0m\n\u001b[1;32m 1879\u001b[0m \u001b[39mpass\u001b[39;00m\n\u001b[1;32m 1880\u001b[0m \u001b[39m# Otherwise we must use the dataset script if the user trusts it\u001b[39;00m\n\u001b[1;32m 1881\u001b[0m \u001b[39mreturn\u001b[39;00m HubDatasetModuleFactoryWithScript(\n\u001b[1;32m 1882\u001b[0m path,\n\u001b[1;32m 1883\u001b[0m revision\u001b[39m=\u001b[39;49mrevision,\n\u001b[1;32m 1884\u001b[0m download_config\u001b[39m=\u001b[39;49mdownload_config,\n\u001b[1;32m 1885\u001b[0m download_mode\u001b[39m=\u001b[39;49mdownload_mode,\n\u001b[1;32m 1886\u001b[0m dynamic_modules_path\u001b[39m=\u001b[39;49mdynamic_modules_path,\n\u001b[1;32m 1887\u001b[0m trust_remote_code\u001b[39m=\u001b[39;49mtrust_remote_code,\n\u001b[0;32m-> 1888\u001b[0m )\u001b[39m.\u001b[39;49mget_module()\n\u001b[1;32m 1889\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 1890\u001b[0m \u001b[39mreturn\u001b[39;00m HubDatasetModuleFactoryWithoutScript(\n\u001b[1;32m 1891\u001b[0m path,\n\u001b[1;32m 1892\u001b[0m revision\u001b[39m=\u001b[39mrevision,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1896\u001b[0m download_mode\u001b[39m=\u001b[39mdownload_mode,\n\u001b[1;32m 1897\u001b[0m )\u001b[39m.\u001b[39mget_module()\n",
248
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/load.py:1537\u001b[0m, in \u001b[0;36mHubDatasetModuleFactoryWithScript.get_module\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1526\u001b[0m _create_importable_file(\n\u001b[1;32m 1527\u001b[0m local_path\u001b[39m=\u001b[39mlocal_path,\n\u001b[1;32m 1528\u001b[0m local_imports\u001b[39m=\u001b[39mlocal_imports,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1534\u001b[0m download_mode\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdownload_mode,\n\u001b[1;32m 1535\u001b[0m )\n\u001b[1;32m 1536\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 1537\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 1538\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mLoading \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mname\u001b[39m}\u001b[39;00m\u001b[39m requires you to execute the dataset script in that\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 1539\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m repo on your local machine. Make sure you have read the code there to avoid malicious use, then\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 1540\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m set the option `trust_remote_code=True` to remove this error.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 1541\u001b[0m )\n\u001b[1;32m 1542\u001b[0m module_path, \u001b[39mhash\u001b[39m \u001b[39m=\u001b[39m _load_importable_file(\n\u001b[1;32m 1543\u001b[0m dynamic_modules_path\u001b[39m=\u001b[39mdynamic_modules_path,\n\u001b[1;32m 1544\u001b[0m module_namespace\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mdatasets\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 1545\u001b[0m subdirectory_name\u001b[39m=\u001b[39m\u001b[39mhash\u001b[39m,\n\u001b[1;32m 1546\u001b[0m name\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mname,\n\u001b[1;32m 1547\u001b[0m )\n\u001b[1;32m 1548\u001b[0m \u001b[39m# make the new module to be noticed by the import system\u001b[39;00m\n",
249
+ "\u001b[0;31mValueError\u001b[0m: Loading wiki_dpr requires you to execute the dataset script in that repo on your local machine. Make sure you have read the code there to avoid malicious use, then set the option `trust_remote_code=True` to remove this error."
250
+ ]
251
+ }
252
+ ],
253
+ "source": [
254
+ "from datasets import load_dataset\n",
255
+ "\n",
256
+ "# データセットのロード\n",
257
+ "dataset = load_dataset('wiki_dpr', 'psgs_w100')\n"
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "code",
262
+ "execution_count": 9,
263
+ "metadata": {},
264
+ "outputs": [
265
+ {
266
+ "ename": "ValueError",
267
+ "evalue": "BuilderConfig 'psgs_w100' not found. Available: ['psgs_w100.nq.exact', 'psgs_w100.nq.compressed', 'psgs_w100.nq.no_index', 'psgs_w100.multiset.exact', 'psgs_w100.multiset.compressed', 'psgs_w100.multiset.no_index', 'psgs_w100.nq.exact.no_embeddings', 'psgs_w100.nq.compressed.no_embeddings', 'psgs_w100.nq.no_index.no_embeddings', 'psgs_w100.multiset.exact.no_embeddings', 'psgs_w100.multiset.compressed.no_embeddings', 'psgs_w100.multiset.no_index.no_embeddings']",
268
+ "output_type": "error",
269
+ "traceback": [
270
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
271
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
272
+ "\u001b[1;32m/home/user/app/polls/test.ipynb Cell 7\u001b[0m line \u001b[0;36m4\n\u001b[1;32m <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#W6sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mdatasets\u001b[39;00m \u001b[39mimport\u001b[39;00m load_dataset\n\u001b[1;32m <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#W6sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2'>3</a>\u001b[0m \u001b[39m# データセットのロード\u001b[39;00m\n\u001b[0;32m----> <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#W6sdnNjb2RlLXJlbW90ZQ%3D%3D?line=3'>4</a>\u001b[0m dataset \u001b[39m=\u001b[39m load_dataset(\u001b[39m'\u001b[39;49m\u001b[39mwiki_dpr\u001b[39;49m\u001b[39m'\u001b[39;49m, \u001b[39m'\u001b[39;49m\u001b[39mpsgs_w100\u001b[39;49m\u001b[39m'\u001b[39;49m, trust_remote_code\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n",
273
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/load.py:2592\u001b[0m, in \u001b[0;36mload_dataset\u001b[0;34m(path, name, data_dir, data_files, split, cache_dir, features, download_config, download_mode, verification_mode, ignore_verifications, keep_in_memory, save_infos, revision, token, use_auth_token, task, streaming, num_proc, storage_options, trust_remote_code, **config_kwargs)\u001b[0m\n\u001b[1;32m 2587\u001b[0m verification_mode \u001b[39m=\u001b[39m VerificationMode(\n\u001b[1;32m 2588\u001b[0m (verification_mode \u001b[39mor\u001b[39;00m VerificationMode\u001b[39m.\u001b[39mBASIC_CHECKS) \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m save_infos \u001b[39melse\u001b[39;00m VerificationMode\u001b[39m.\u001b[39mALL_CHECKS\n\u001b[1;32m 2589\u001b[0m )\n\u001b[1;32m 2591\u001b[0m \u001b[39m# Create a dataset builder\u001b[39;00m\n\u001b[0;32m-> 2592\u001b[0m builder_instance \u001b[39m=\u001b[39m load_dataset_builder(\n\u001b[1;32m 2593\u001b[0m path\u001b[39m=\u001b[39;49mpath,\n\u001b[1;32m 2594\u001b[0m name\u001b[39m=\u001b[39;49mname,\n\u001b[1;32m 2595\u001b[0m data_dir\u001b[39m=\u001b[39;49mdata_dir,\n\u001b[1;32m 2596\u001b[0m data_files\u001b[39m=\u001b[39;49mdata_files,\n\u001b[1;32m 2597\u001b[0m cache_dir\u001b[39m=\u001b[39;49mcache_dir,\n\u001b[1;32m 2598\u001b[0m features\u001b[39m=\u001b[39;49mfeatures,\n\u001b[1;32m 2599\u001b[0m download_config\u001b[39m=\u001b[39;49mdownload_config,\n\u001b[1;32m 2600\u001b[0m download_mode\u001b[39m=\u001b[39;49mdownload_mode,\n\u001b[1;32m 2601\u001b[0m revision\u001b[39m=\u001b[39;49mrevision,\n\u001b[1;32m 2602\u001b[0m token\u001b[39m=\u001b[39;49mtoken,\n\u001b[1;32m 2603\u001b[0m storage_options\u001b[39m=\u001b[39;49mstorage_options,\n\u001b[1;32m 2604\u001b[0m trust_remote_code\u001b[39m=\u001b[39;49mtrust_remote_code,\n\u001b[1;32m 2605\u001b[0m _require_default_config_name\u001b[39m=\u001b[39;49mname \u001b[39mis\u001b[39;49;00m \u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 2606\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mconfig_kwargs,\n\u001b[1;32m 2607\u001b[0m )\n\u001b[1;32m 2609\u001b[0m \u001b[39m# Return iterable dataset in case of streaming\u001b[39;00m\n\u001b[1;32m 2610\u001b[0m \u001b[39mif\u001b[39;00m streaming:\n",
274
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/load.py:2301\u001b[0m, in \u001b[0;36mload_dataset_builder\u001b[0;34m(path, name, data_dir, data_files, cache_dir, features, download_config, download_mode, revision, token, use_auth_token, storage_options, trust_remote_code, _require_default_config_name, **config_kwargs)\u001b[0m\n\u001b[1;32m 2299\u001b[0m builder_cls \u001b[39m=\u001b[39m get_dataset_builder_class(dataset_module, dataset_name\u001b[39m=\u001b[39mdataset_name)\n\u001b[1;32m 2300\u001b[0m \u001b[39m# Instantiate the dataset builder\u001b[39;00m\n\u001b[0;32m-> 2301\u001b[0m builder_instance: DatasetBuilder \u001b[39m=\u001b[39m builder_cls(\n\u001b[1;32m 2302\u001b[0m cache_dir\u001b[39m=\u001b[39;49mcache_dir,\n\u001b[1;32m 2303\u001b[0m dataset_name\u001b[39m=\u001b[39;49mdataset_name,\n\u001b[1;32m 2304\u001b[0m config_name\u001b[39m=\u001b[39;49mconfig_name,\n\u001b[1;32m 2305\u001b[0m data_dir\u001b[39m=\u001b[39;49mdata_dir,\n\u001b[1;32m 2306\u001b[0m data_files\u001b[39m=\u001b[39;49mdata_files,\n\u001b[1;32m 2307\u001b[0m \u001b[39mhash\u001b[39;49m\u001b[39m=\u001b[39;49mdataset_module\u001b[39m.\u001b[39;49mhash,\n\u001b[1;32m 2308\u001b[0m info\u001b[39m=\u001b[39;49minfo,\n\u001b[1;32m 2309\u001b[0m features\u001b[39m=\u001b[39;49mfeatures,\n\u001b[1;32m 2310\u001b[0m token\u001b[39m=\u001b[39;49mtoken,\n\u001b[1;32m 2311\u001b[0m storage_options\u001b[39m=\u001b[39;49mstorage_options,\n\u001b[1;32m 2312\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mbuilder_kwargs,\n\u001b[1;32m 2313\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mconfig_kwargs,\n\u001b[1;32m 2314\u001b[0m )\n\u001b[1;32m 2315\u001b[0m builder_instance\u001b[39m.\u001b[39m_use_legacy_cache_dir_if_possible(dataset_module)\n\u001b[1;32m 2317\u001b[0m \u001b[39mreturn\u001b[39;00m builder_instance\n",
275
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/builder.py:374\u001b[0m, in \u001b[0;36mDatasetBuilder.__init__\u001b[0;34m(self, cache_dir, dataset_name, config_name, hash, base_path, info, features, token, use_auth_token, repo_id, data_files, data_dir, storage_options, writer_batch_size, name, **config_kwargs)\u001b[0m\n\u001b[1;32m 372\u001b[0m config_kwargs[\u001b[39m\"\u001b[39m\u001b[39mdata_dir\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m data_dir\n\u001b[1;32m 373\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig_kwargs \u001b[39m=\u001b[39m config_kwargs\n\u001b[0;32m--> 374\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig_id \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_create_builder_config(\n\u001b[1;32m 375\u001b[0m config_name\u001b[39m=\u001b[39;49mconfig_name,\n\u001b[1;32m 376\u001b[0m custom_features\u001b[39m=\u001b[39;49mfeatures,\n\u001b[1;32m 377\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mconfig_kwargs,\n\u001b[1;32m 378\u001b[0m )\n\u001b[1;32m 380\u001b[0m \u001b[39m# prepare info: DatasetInfo are a standardized dataclass across all datasets\u001b[39;00m\n\u001b[1;32m 381\u001b[0m \u001b[39m# Prefill datasetinfo\u001b[39;00m\n\u001b[1;32m 382\u001b[0m \u001b[39mif\u001b[39;00m info \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 383\u001b[0m \u001b[39m# TODO FOR PACKAGED MODULES IT IMPORTS DATA FROM src/packaged_modules which doesn't make sense\u001b[39;00m\n",
276
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/builder.py:599\u001b[0m, in \u001b[0;36mDatasetBuilder._create_builder_config\u001b[0;34m(self, config_name, custom_features, **config_kwargs)\u001b[0m\n\u001b[1;32m 597\u001b[0m builder_config \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbuilder_configs\u001b[39m.\u001b[39mget(config_name)\n\u001b[1;32m 598\u001b[0m \u001b[39mif\u001b[39;00m builder_config \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mBUILDER_CONFIGS:\n\u001b[0;32m--> 599\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 600\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mBuilderConfig \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mconfig_name\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m not found. Available: \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mlist\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbuilder_configs\u001b[39m.\u001b[39mkeys())\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[1;32m 601\u001b[0m )\n\u001b[1;32m 603\u001b[0m \u001b[39m# if not using an existing config, then create a new config on the fly\u001b[39;00m\n\u001b[1;32m 604\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m builder_config:\n",
277
+ "\u001b[0;31mValueError\u001b[0m: BuilderConfig 'psgs_w100' not found. Available: ['psgs_w100.nq.exact', 'psgs_w100.nq.compressed', 'psgs_w100.nq.no_index', 'psgs_w100.multiset.exact', 'psgs_w100.multiset.compressed', 'psgs_w100.multiset.no_index', 'psgs_w100.nq.exact.no_embeddings', 'psgs_w100.nq.compressed.no_embeddings', 'psgs_w100.nq.no_index.no_embeddings', 'psgs_w100.multiset.exact.no_embeddings', 'psgs_w100.multiset.compressed.no_embeddings', 'psgs_w100.multiset.no_index.no_embeddings']"
278
+ ]
279
+ }
280
+ ],
281
+ "source": [
282
+ "from datasets import load_dataset\n",
283
+ "\n",
284
+ "# データセットのロード\n",
285
+ "dataset = load_dataset('wiki_dpr', 'psgs_w100', trust_remote_code=True)\n"
286
+ ]
287
+ },
288
+ {
289
+ "cell_type": "code",
290
+ "execution_count": 10,
291
+ "metadata": {},
292
+ "outputs": [
293
+ {
294
+ "name": "stderr",
295
+ "output_type": "stream",
296
+ "text": [
297
+ "Downloading data: 0%| | 0/157 [02:34<?, ?files/s]\n"
298
+ ]
299
+ },
300
+ {
301
+ "ename": "KeyboardInterrupt",
302
+ "evalue": "",
303
+ "output_type": "error",
304
+ "traceback": [
305
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
306
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
307
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/tqdm/contrib/concurrent.py:51\u001b[0m, in \u001b[0;36m_executor_map\u001b[0;34m(PoolExecutor, fn, *iterables, **tqdm_kwargs)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[39mwith\u001b[39;00m PoolExecutor(max_workers\u001b[39m=\u001b[39mmax_workers, initializer\u001b[39m=\u001b[39mtqdm_class\u001b[39m.\u001b[39mset_lock,\n\u001b[1;32m 50\u001b[0m initargs\u001b[39m=\u001b[39m(lk,)) \u001b[39mas\u001b[39;00m ex:\n\u001b[0;32m---> 51\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mlist\u001b[39;49m(tqdm_class(ex\u001b[39m.\u001b[39;49mmap(fn, \u001b[39m*\u001b[39;49miterables, chunksize\u001b[39m=\u001b[39;49mchunksize), \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs))\n",
308
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/tqdm/std.py:1181\u001b[0m, in \u001b[0;36mtqdm.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1180\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m-> 1181\u001b[0m \u001b[39mfor\u001b[39;00m obj \u001b[39min\u001b[39;00m iterable:\n\u001b[1;32m 1182\u001b[0m \u001b[39myield\u001b[39;00m obj\n",
309
+ "File \u001b[0;32m/usr/local/lib/python3.10/concurrent/futures/_base.py:621\u001b[0m, in \u001b[0;36mExecutor.map.<locals>.result_iterator\u001b[0;34m()\u001b[0m\n\u001b[1;32m 620\u001b[0m \u001b[39mif\u001b[39;00m timeout \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 621\u001b[0m \u001b[39myield\u001b[39;00m _result_or_cancel(fs\u001b[39m.\u001b[39;49mpop())\n\u001b[1;32m 622\u001b[0m \u001b[39melse\u001b[39;00m:\n",
310
+ "File \u001b[0;32m/usr/local/lib/python3.10/concurrent/futures/_base.py:319\u001b[0m, in \u001b[0;36m_result_or_cancel\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 318\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 319\u001b[0m \u001b[39mreturn\u001b[39;00m fut\u001b[39m.\u001b[39;49mresult(timeout)\n\u001b[1;32m 320\u001b[0m \u001b[39mfinally\u001b[39;00m:\n",
311
+ "File \u001b[0;32m/usr/local/lib/python3.10/concurrent/futures/_base.py:453\u001b[0m, in \u001b[0;36mFuture.result\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 451\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__get_result()\n\u001b[0;32m--> 453\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_condition\u001b[39m.\u001b[39;49mwait(timeout)\n\u001b[1;32m 455\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_state \u001b[39min\u001b[39;00m [CANCELLED, CANCELLED_AND_NOTIFIED]:\n",
312
+ "File \u001b[0;32m/usr/local/lib/python3.10/threading.py:320\u001b[0m, in \u001b[0;36mCondition.wait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 319\u001b[0m \u001b[39mif\u001b[39;00m timeout \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 320\u001b[0m waiter\u001b[39m.\u001b[39;49macquire()\n\u001b[1;32m 321\u001b[0m gotit \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n",
313
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: ",
314
+ "\nDuring handling of the above exception, another exception occurred:\n",
315
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
316
+ "\u001b[1;32m/home/user/app/polls/test.ipynb Cell 8\u001b[0m line \u001b[0;36m4\n\u001b[1;32m <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#X10sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mdatasets\u001b[39;00m \u001b[39mimport\u001b[39;00m load_dataset\n\u001b[1;32m <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#X10sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2'>3</a>\u001b[0m \u001b[39m# データセットのロード\u001b[39;00m\n\u001b[0;32m----> <a href='vscode-notebook-cell://kenken999-fastapi-django-main--1027.hf.space/home/user/app/polls/test.ipynb#X10sdnNjb2RlLXJlbW90ZQ%3D%3D?line=3'>4</a>\u001b[0m dataset \u001b[39m=\u001b[39m load_dataset(\u001b[39m'\u001b[39;49m\u001b[39mwiki_dpr\u001b[39;49m\u001b[39m'\u001b[39;49m, \u001b[39m'\u001b[39;49m\u001b[39mpsgs_w100.nq.exact\u001b[39;49m\u001b[39m'\u001b[39;49m, trust_remote_code\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n",
317
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/load.py:2614\u001b[0m, in \u001b[0;36mload_dataset\u001b[0;34m(path, name, data_dir, data_files, split, cache_dir, features, download_config, download_mode, verification_mode, ignore_verifications, keep_in_memory, save_infos, revision, token, use_auth_token, task, streaming, num_proc, storage_options, trust_remote_code, **config_kwargs)\u001b[0m\n\u001b[1;32m 2611\u001b[0m \u001b[39mreturn\u001b[39;00m builder_instance\u001b[39m.\u001b[39mas_streaming_dataset(split\u001b[39m=\u001b[39msplit)\n\u001b[1;32m 2613\u001b[0m \u001b[39m# Download and prepare data\u001b[39;00m\n\u001b[0;32m-> 2614\u001b[0m builder_instance\u001b[39m.\u001b[39;49mdownload_and_prepare(\n\u001b[1;32m 2615\u001b[0m download_config\u001b[39m=\u001b[39;49mdownload_config,\n\u001b[1;32m 2616\u001b[0m download_mode\u001b[39m=\u001b[39;49mdownload_mode,\n\u001b[1;32m 2617\u001b[0m verification_mode\u001b[39m=\u001b[39;49mverification_mode,\n\u001b[1;32m 2618\u001b[0m num_proc\u001b[39m=\u001b[39;49mnum_proc,\n\u001b[1;32m 2619\u001b[0m storage_options\u001b[39m=\u001b[39;49mstorage_options,\n\u001b[1;32m 2620\u001b[0m )\n\u001b[1;32m 2622\u001b[0m \u001b[39m# Build dataset for splits\u001b[39;00m\n\u001b[1;32m 2623\u001b[0m keep_in_memory \u001b[39m=\u001b[39m (\n\u001b[1;32m 2624\u001b[0m keep_in_memory \u001b[39mif\u001b[39;00m keep_in_memory \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m is_small_dataset(builder_instance\u001b[39m.\u001b[39minfo\u001b[39m.\u001b[39mdataset_size)\n\u001b[1;32m 2625\u001b[0m )\n",
318
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/builder.py:1027\u001b[0m, in \u001b[0;36mDatasetBuilder.download_and_prepare\u001b[0;34m(self, output_dir, download_config, download_mode, verification_mode, ignore_verifications, try_from_hf_gcs, dl_manager, base_path, use_auth_token, file_format, max_shard_size, num_proc, storage_options, **download_and_prepare_kwargs)\u001b[0m\n\u001b[1;32m 1025\u001b[0m \u001b[39mif\u001b[39;00m num_proc \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 1026\u001b[0m prepare_split_kwargs[\u001b[39m\"\u001b[39m\u001b[39mnum_proc\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m num_proc\n\u001b[0;32m-> 1027\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_download_and_prepare(\n\u001b[1;32m 1028\u001b[0m dl_manager\u001b[39m=\u001b[39;49mdl_manager,\n\u001b[1;32m 1029\u001b[0m verification_mode\u001b[39m=\u001b[39;49mverification_mode,\n\u001b[1;32m 1030\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mprepare_split_kwargs,\n\u001b[1;32m 1031\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mdownload_and_prepare_kwargs,\n\u001b[1;32m 1032\u001b[0m )\n\u001b[1;32m 1033\u001b[0m \u001b[39m# Sync info\u001b[39;00m\n\u001b[1;32m 1034\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39minfo\u001b[39m.\u001b[39mdataset_size \u001b[39m=\u001b[39m \u001b[39msum\u001b[39m(split\u001b[39m.\u001b[39mnum_bytes \u001b[39mfor\u001b[39;00m split \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39minfo\u001b[39m.\u001b[39msplits\u001b[39m.\u001b[39mvalues())\n",
319
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/builder.py:1100\u001b[0m, in \u001b[0;36mDatasetBuilder._download_and_prepare\u001b[0;34m(self, dl_manager, verification_mode, **prepare_split_kwargs)\u001b[0m\n\u001b[1;32m 1098\u001b[0m split_dict \u001b[39m=\u001b[39m SplitDict(dataset_name\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdataset_name)\n\u001b[1;32m 1099\u001b[0m split_generators_kwargs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_make_split_generators_kwargs(prepare_split_kwargs)\n\u001b[0;32m-> 1100\u001b[0m split_generators \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_split_generators(dl_manager, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49msplit_generators_kwargs)\n\u001b[1;32m 1102\u001b[0m \u001b[39m# Checksums verification\u001b[39;00m\n\u001b[1;32m 1103\u001b[0m \u001b[39mif\u001b[39;00m verification_mode \u001b[39m==\u001b[39m VerificationMode\u001b[39m.\u001b[39mALL_CHECKS \u001b[39mand\u001b[39;00m dl_manager\u001b[39m.\u001b[39mrecord_checksums:\n",
320
+ "File \u001b[0;32m~/.cache/huggingface/modules/datasets_modules/datasets/wiki_dpr/66fd9b80f51375c02cd9010050e781ed3e8f759e868f690c31b2686a7a0eeb5c/wiki_dpr.py:143\u001b[0m, in \u001b[0;36mWikiDpr._split_generators\u001b[0;34m(self, dl_manager)\u001b[0m\n\u001b[1;32m 141\u001b[0m data_dir \u001b[39m=\u001b[39m os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39m\"\u001b[39m\u001b[39mdata\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mwiki_split, data_dir)\n\u001b[1;32m 142\u001b[0m files \u001b[39m=\u001b[39m [os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(data_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mtrain-\u001b[39m\u001b[39m{\u001b[39;00mi\u001b[39m:\u001b[39;00m\u001b[39m05d\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m-of-\u001b[39m\u001b[39m{\u001b[39;00mnum_shards\u001b[39m:\u001b[39;00m\u001b[39m05d\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m.parquet\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(num_shards)]\n\u001b[0;32m--> 143\u001b[0m downloaded_files \u001b[39m=\u001b[39m dl_manager\u001b[39m.\u001b[39;49mdownload_and_extract(files)\n\u001b[1;32m 144\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[1;32m 145\u001b[0m datasets\u001b[39m.\u001b[39mSplitGenerator(name\u001b[39m=\u001b[39mdatasets\u001b[39m.\u001b[39mSplit\u001b[39m.\u001b[39mTRAIN, gen_kwargs\u001b[39m=\u001b[39m{\u001b[39m\"\u001b[39m\u001b[39mfiles\u001b[39m\u001b[39m\"\u001b[39m: downloaded_files}),\n\u001b[1;32m 146\u001b[0m ]\n",
321
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/download/download_manager.py:434\u001b[0m, in \u001b[0;36mDownloadManager.download_and_extract\u001b[0;34m(self, url_or_urls)\u001b[0m\n\u001b[1;32m 418\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdownload_and_extract\u001b[39m(\u001b[39mself\u001b[39m, url_or_urls):\n\u001b[1;32m 419\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Download and extract given `url_or_urls`.\u001b[39;00m\n\u001b[1;32m 420\u001b[0m \n\u001b[1;32m 421\u001b[0m \u001b[39m Is roughly equivalent to:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 432\u001b[0m \u001b[39m extracted_path(s): `str`, extracted paths of given URL(s).\u001b[39;00m\n\u001b[1;32m 433\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 434\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mextract(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdownload(url_or_urls))\n",
322
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/download/download_manager.py:257\u001b[0m, in \u001b[0;36mDownloadManager.download\u001b[0;34m(self, url_or_urls)\u001b[0m\n\u001b[1;32m 255\u001b[0m start_time \u001b[39m=\u001b[39m datetime\u001b[39m.\u001b[39mnow()\n\u001b[1;32m 256\u001b[0m \u001b[39mwith\u001b[39;00m stack_multiprocessing_download_progress_bars():\n\u001b[0;32m--> 257\u001b[0m downloaded_path_or_paths \u001b[39m=\u001b[39m map_nested(\n\u001b[1;32m 258\u001b[0m download_func,\n\u001b[1;32m 259\u001b[0m url_or_urls,\n\u001b[1;32m 260\u001b[0m map_tuple\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 261\u001b[0m num_proc\u001b[39m=\u001b[39;49mdownload_config\u001b[39m.\u001b[39;49mnum_proc,\n\u001b[1;32m 262\u001b[0m desc\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mDownloading data files\u001b[39;49m\u001b[39m\"\u001b[39;49m,\n\u001b[1;32m 263\u001b[0m batched\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 264\u001b[0m batch_size\u001b[39m=\u001b[39;49m\u001b[39m-\u001b[39;49m\u001b[39m1\u001b[39;49m,\n\u001b[1;32m 265\u001b[0m )\n\u001b[1;32m 266\u001b[0m duration \u001b[39m=\u001b[39m datetime\u001b[39m.\u001b[39mnow() \u001b[39m-\u001b[39m start_time\n\u001b[1;32m 267\u001b[0m logger\u001b[39m.\u001b[39minfo(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mDownloading took \u001b[39m\u001b[39m{\u001b[39;00mduration\u001b[39m.\u001b[39mtotal_seconds()\u001b[39m \u001b[39m\u001b[39m/\u001b[39m\u001b[39m/\u001b[39m\u001b[39m \u001b[39m\u001b[39m60\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m min\u001b[39m\u001b[39m\"\u001b[39m)\n",
323
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/utils/py_utils.py:511\u001b[0m, in \u001b[0;36mmap_nested\u001b[0;34m(function, data_struct, dict_only, map_list, map_tuple, map_numpy, num_proc, parallel_min_length, batched, batch_size, types, disable_tqdm, desc)\u001b[0m\n\u001b[1;32m 509\u001b[0m batch_size \u001b[39m=\u001b[39m \u001b[39mmax\u001b[39m(\u001b[39mlen\u001b[39m(iterable) \u001b[39m/\u001b[39m\u001b[39m/\u001b[39m num_proc \u001b[39m+\u001b[39m \u001b[39mint\u001b[39m(\u001b[39mlen\u001b[39m(iterable) \u001b[39m%\u001b[39m num_proc \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m), \u001b[39m1\u001b[39m)\n\u001b[1;32m 510\u001b[0m iterable \u001b[39m=\u001b[39m \u001b[39mlist\u001b[39m(iter_batched(iterable, batch_size))\n\u001b[0;32m--> 511\u001b[0m mapped \u001b[39m=\u001b[39m [\n\u001b[1;32m 512\u001b[0m _single_map_nested((function, obj, batched, batch_size, types, \u001b[39mNone\u001b[39;00m, \u001b[39mTrue\u001b[39;00m, \u001b[39mNone\u001b[39;00m))\n\u001b[1;32m 513\u001b[0m \u001b[39mfor\u001b[39;00m obj \u001b[39min\u001b[39;00m hf_tqdm(iterable, disable\u001b[39m=\u001b[39mdisable_tqdm, desc\u001b[39m=\u001b[39mdesc)\n\u001b[1;32m 514\u001b[0m ]\n\u001b[1;32m 515\u001b[0m \u001b[39mif\u001b[39;00m batched:\n\u001b[1;32m 516\u001b[0m mapped \u001b[39m=\u001b[39m [mapped_item \u001b[39mfor\u001b[39;00m mapped_batch \u001b[39min\u001b[39;00m mapped \u001b[39mfor\u001b[39;00m mapped_item \u001b[39min\u001b[39;00m mapped_batch]\n",
324
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/utils/py_utils.py:512\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 509\u001b[0m batch_size \u001b[39m=\u001b[39m \u001b[39mmax\u001b[39m(\u001b[39mlen\u001b[39m(iterable) \u001b[39m/\u001b[39m\u001b[39m/\u001b[39m num_proc \u001b[39m+\u001b[39m \u001b[39mint\u001b[39m(\u001b[39mlen\u001b[39m(iterable) \u001b[39m%\u001b[39m num_proc \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m), \u001b[39m1\u001b[39m)\n\u001b[1;32m 510\u001b[0m iterable \u001b[39m=\u001b[39m \u001b[39mlist\u001b[39m(iter_batched(iterable, batch_size))\n\u001b[1;32m 511\u001b[0m mapped \u001b[39m=\u001b[39m [\n\u001b[0;32m--> 512\u001b[0m _single_map_nested((function, obj, batched, batch_size, types, \u001b[39mNone\u001b[39;49;00m, \u001b[39mTrue\u001b[39;49;00m, \u001b[39mNone\u001b[39;49;00m))\n\u001b[1;32m 513\u001b[0m \u001b[39mfor\u001b[39;00m obj \u001b[39min\u001b[39;00m hf_tqdm(iterable, disable\u001b[39m=\u001b[39mdisable_tqdm, desc\u001b[39m=\u001b[39mdesc)\n\u001b[1;32m 514\u001b[0m ]\n\u001b[1;32m 515\u001b[0m \u001b[39mif\u001b[39;00m batched:\n\u001b[1;32m 516\u001b[0m mapped \u001b[39m=\u001b[39m [mapped_item \u001b[39mfor\u001b[39;00m mapped_batch \u001b[39min\u001b[39;00m mapped \u001b[39mfor\u001b[39;00m mapped_item \u001b[39min\u001b[39;00m mapped_batch]\n",
325
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/utils/py_utils.py:380\u001b[0m, in \u001b[0;36m_single_map_nested\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 373\u001b[0m \u001b[39mreturn\u001b[39;00m function(data_struct)\n\u001b[1;32m 374\u001b[0m \u001b[39mif\u001b[39;00m (\n\u001b[1;32m 375\u001b[0m batched\n\u001b[1;32m 376\u001b[0m \u001b[39mand\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(data_struct, \u001b[39mdict\u001b[39m)\n\u001b[1;32m 377\u001b[0m \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(data_struct, types)\n\u001b[1;32m 378\u001b[0m \u001b[39mand\u001b[39;00m \u001b[39mall\u001b[39m(\u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(v, (\u001b[39mdict\u001b[39m, types)) \u001b[39mfor\u001b[39;00m v \u001b[39min\u001b[39;00m data_struct)\n\u001b[1;32m 379\u001b[0m ):\n\u001b[0;32m--> 380\u001b[0m \u001b[39mreturn\u001b[39;00m [mapped_item \u001b[39mfor\u001b[39;00m batch \u001b[39min\u001b[39;00m iter_batched(data_struct, batch_size) \u001b[39mfor\u001b[39;00m mapped_item \u001b[39min\u001b[39;00m function(batch)]\n\u001b[1;32m 382\u001b[0m \u001b[39m# Reduce logging to keep things readable in multiprocessing with tqdm\u001b[39;00m\n\u001b[1;32m 383\u001b[0m \u001b[39mif\u001b[39;00m rank \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m logging\u001b[39m.\u001b[39mget_verbosity() \u001b[39m<\u001b[39m logging\u001b[39m.\u001b[39mWARNING:\n",
326
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/utils/py_utils.py:380\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 373\u001b[0m \u001b[39mreturn\u001b[39;00m function(data_struct)\n\u001b[1;32m 374\u001b[0m \u001b[39mif\u001b[39;00m (\n\u001b[1;32m 375\u001b[0m batched\n\u001b[1;32m 376\u001b[0m \u001b[39mand\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(data_struct, \u001b[39mdict\u001b[39m)\n\u001b[1;32m 377\u001b[0m \u001b[39mand\u001b[39;00m \u001b[39misinstance\u001b[39m(data_struct, types)\n\u001b[1;32m 378\u001b[0m \u001b[39mand\u001b[39;00m \u001b[39mall\u001b[39m(\u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(v, (\u001b[39mdict\u001b[39m, types)) \u001b[39mfor\u001b[39;00m v \u001b[39min\u001b[39;00m data_struct)\n\u001b[1;32m 379\u001b[0m ):\n\u001b[0;32m--> 380\u001b[0m \u001b[39mreturn\u001b[39;00m [mapped_item \u001b[39mfor\u001b[39;00m batch \u001b[39min\u001b[39;00m iter_batched(data_struct, batch_size) \u001b[39mfor\u001b[39;00m mapped_item \u001b[39min\u001b[39;00m function(batch)]\n\u001b[1;32m 382\u001b[0m \u001b[39m# Reduce logging to keep things readable in multiprocessing with tqdm\u001b[39;00m\n\u001b[1;32m 383\u001b[0m \u001b[39mif\u001b[39;00m rank \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m logging\u001b[39m.\u001b[39mget_verbosity() \u001b[39m<\u001b[39m logging\u001b[39m.\u001b[39mWARNING:\n",
327
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/datasets/download/download_manager.py:300\u001b[0m, in \u001b[0;36mDownloadManager._download_batched\u001b[0;34m(self, url_or_filenames, download_config)\u001b[0m\n\u001b[1;32m 295\u001b[0m \u001b[39mpass\u001b[39;00m\n\u001b[1;32m 296\u001b[0m max_workers \u001b[39m=\u001b[39m (\n\u001b[1;32m 297\u001b[0m config\u001b[39m.\u001b[39mHF_DATASETS_MULTITHREADING_MAX_WORKERS \u001b[39mif\u001b[39;00m size \u001b[39m<\u001b[39m (\u001b[39m20\u001b[39m \u001b[39m<<\u001b[39m \u001b[39m20\u001b[39m) \u001b[39melse\u001b[39;00m \u001b[39m1\u001b[39m\n\u001b[1;32m 298\u001b[0m ) \u001b[39m# enable multithreading if files are small\u001b[39;00m\n\u001b[0;32m--> 300\u001b[0m \u001b[39mreturn\u001b[39;00m thread_map(\n\u001b[1;32m 301\u001b[0m download_func,\n\u001b[1;32m 302\u001b[0m url_or_filenames,\n\u001b[1;32m 303\u001b[0m desc\u001b[39m=\u001b[39;49mdownload_config\u001b[39m.\u001b[39;49mdownload_desc \u001b[39mor\u001b[39;49;00m \u001b[39m\"\u001b[39;49m\u001b[39mDownloading\u001b[39;49m\u001b[39m\"\u001b[39;49m,\n\u001b[1;32m 304\u001b[0m unit\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mfiles\u001b[39;49m\u001b[39m\"\u001b[39;49m,\n\u001b[1;32m 305\u001b[0m position\u001b[39m=\u001b[39;49mmultiprocessing\u001b[39m.\u001b[39;49mcurrent_process()\u001b[39m.\u001b[39;49m_identity[\u001b[39m-\u001b[39;49m\u001b[39m1\u001b[39;49m] \u001b[39m# contains the ranks of subprocesses\u001b[39;49;00m\n\u001b[1;32m 306\u001b[0m \u001b[39mif\u001b[39;49;00m os\u001b[39m.\u001b[39;49menviron\u001b[39m.\u001b[39;49mget(\u001b[39m\"\u001b[39;49m\u001b[39mHF_DATASETS_STACK_MULTIPROCESSING_DOWNLOAD_PROGRESS_BARS\u001b[39;49m\u001b[39m\"\u001b[39;49m) \u001b[39m==\u001b[39;49m \u001b[39m\"\u001b[39;49m\u001b[39m1\u001b[39;49m\u001b[39m\"\u001b[39;49m\n\u001b[1;32m 307\u001b[0m \u001b[39mand\u001b[39;49;00m multiprocessing\u001b[39m.\u001b[39;49mcurrent_process()\u001b[39m.\u001b[39;49m_identity\n\u001b[1;32m 308\u001b[0m \u001b[39melse\u001b[39;49;00m \u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 309\u001b[0m max_workers\u001b[39m=\u001b[39;49mmax_workers,\n\u001b[1;32m 310\u001b[0m tqdm_class\u001b[39m=\u001b[39;49mtqdm,\n\u001b[1;32m 311\u001b[0m )\n\u001b[1;32m 312\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 313\u001b[0m \u001b[39mreturn\u001b[39;00m [\n\u001b[1;32m 314\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_download_single(url_or_filename, download_config\u001b[39m=\u001b[39mdownload_config)\n\u001b[1;32m 315\u001b[0m \u001b[39mfor\u001b[39;00m url_or_filename \u001b[39min\u001b[39;00m url_or_filenames\n\u001b[1;32m 316\u001b[0m ]\n",
328
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/tqdm/contrib/concurrent.py:69\u001b[0m, in \u001b[0;36mthread_map\u001b[0;34m(fn, *iterables, **tqdm_kwargs)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 56\u001b[0m \u001b[39mEquivalent of `list(map(fn, *iterables))`\u001b[39;00m\n\u001b[1;32m 57\u001b[0m \u001b[39mdriven by `concurrent.futures.ThreadPoolExecutor`.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[39m [default: max(32, cpu_count() + 4)].\u001b[39;00m\n\u001b[1;32m 67\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 68\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mconcurrent\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mfutures\u001b[39;00m \u001b[39mimport\u001b[39;00m ThreadPoolExecutor\n\u001b[0;32m---> 69\u001b[0m \u001b[39mreturn\u001b[39;00m _executor_map(ThreadPoolExecutor, fn, \u001b[39m*\u001b[39;49miterables, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mtqdm_kwargs)\n",
329
+ "File \u001b[0;32m/usr/local/lib/python3.10/site-packages/tqdm/contrib/concurrent.py:49\u001b[0m, in \u001b[0;36m_executor_map\u001b[0;34m(PoolExecutor, fn, *iterables, **tqdm_kwargs)\u001b[0m\n\u001b[1;32m 46\u001b[0m lock_name \u001b[39m=\u001b[39m kwargs\u001b[39m.\u001b[39mpop(\u001b[39m\"\u001b[39m\u001b[39mlock_name\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 47\u001b[0m \u001b[39mwith\u001b[39;00m ensure_lock(tqdm_class, lock_name\u001b[39m=\u001b[39mlock_name) \u001b[39mas\u001b[39;00m lk:\n\u001b[1;32m 48\u001b[0m \u001b[39m# share lock in case workers are already using `tqdm`\u001b[39;00m\n\u001b[0;32m---> 49\u001b[0m \u001b[39mwith\u001b[39;00m PoolExecutor(max_workers\u001b[39m=\u001b[39mmax_workers, initializer\u001b[39m=\u001b[39mtqdm_class\u001b[39m.\u001b[39mset_lock,\n\u001b[1;32m 50\u001b[0m initargs\u001b[39m=\u001b[39m(lk,)) \u001b[39mas\u001b[39;00m ex:\n\u001b[1;32m 51\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mlist\u001b[39m(tqdm_class(ex\u001b[39m.\u001b[39mmap(fn, \u001b[39m*\u001b[39miterables, chunksize\u001b[39m=\u001b[39mchunksize), \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs))\n",
330
+ "File \u001b[0;32m/usr/local/lib/python3.10/concurrent/futures/_base.py:649\u001b[0m, in \u001b[0;36mExecutor.__exit__\u001b[0;34m(self, exc_type, exc_val, exc_tb)\u001b[0m\n\u001b[1;32m 648\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__exit__\u001b[39m(\u001b[39mself\u001b[39m, exc_type, exc_val, exc_tb):\n\u001b[0;32m--> 649\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mshutdown(wait\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n\u001b[1;32m 650\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mFalse\u001b[39;00m\n",
331
+ "File \u001b[0;32m/usr/local/lib/python3.10/concurrent/futures/thread.py:235\u001b[0m, in \u001b[0;36mThreadPoolExecutor.shutdown\u001b[0;34m(self, wait, cancel_futures)\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[39mif\u001b[39;00m wait:\n\u001b[1;32m 234\u001b[0m \u001b[39mfor\u001b[39;00m t \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_threads:\n\u001b[0;32m--> 235\u001b[0m t\u001b[39m.\u001b[39;49mjoin()\n",
332
+ "File \u001b[0;32m/usr/local/lib/python3.10/threading.py:1096\u001b[0m, in \u001b[0;36mThread.join\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 1093\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mcannot join current thread\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 1095\u001b[0m \u001b[39mif\u001b[39;00m timeout \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m-> 1096\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_wait_for_tstate_lock()\n\u001b[1;32m 1097\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 1098\u001b[0m \u001b[39m# the behavior of a negative timeout isn't documented, but\u001b[39;00m\n\u001b[1;32m 1099\u001b[0m \u001b[39m# historically .join(timeout=x) for x<0 has acted as if timeout=0\u001b[39;00m\n\u001b[1;32m 1100\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_wait_for_tstate_lock(timeout\u001b[39m=\u001b[39m\u001b[39mmax\u001b[39m(timeout, \u001b[39m0\u001b[39m))\n",
333
+ "File \u001b[0;32m/usr/local/lib/python3.10/threading.py:1116\u001b[0m, in \u001b[0;36mThread._wait_for_tstate_lock\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 1113\u001b[0m \u001b[39mreturn\u001b[39;00m\n\u001b[1;32m 1115\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m-> 1116\u001b[0m \u001b[39mif\u001b[39;00m lock\u001b[39m.\u001b[39;49macquire(block, timeout):\n\u001b[1;32m 1117\u001b[0m lock\u001b[39m.\u001b[39mrelease()\n\u001b[1;32m 1118\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_stop()\n",
334
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
335
+ ]
336
+ }
337
+ ],
338
+ "source": [
339
+ "from datasets import load_dataset\n",
340
+ "\n",
341
+ "# データセットのロード\n",
342
+ "dataset = load_dataset('wiki_dpr', 'psgs_w100.nq.exact', trust_remote_code=True)\n"
343
+ ]
344
+ },
345
+ {
346
+ "cell_type": "code",
347
+ "execution_count": 11,
348
+ "metadata": {},
349
+ "outputs": [
350
+ {
351
+ "name": "stderr",
352
+ "output_type": "stream",
353
+ "text": [
354
+ "Downloading data: 90%|█████████ | 142/157 [23:35<01:01, 4.07s/files]"
355
+ ]
356
+ }
357
+ ],
358
+ "source": [
359
+ "from datasets import load_dataset\n",
360
+ "from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration\n",
361
+ "import faiss\n",
362
+ "import numpy as np\n",
363
+ "\n",
364
+ "# データセットのロード\n",
365
+ "dataset = load_dataset('wiki_dpr', 'psgs_w100.nq.exact', trust_remote_code=True)\n",
366
+ "\n",
367
+ "# モデルとトークナイザーのロード\n",
368
+ "model_name = \"facebook/rag-sequence-nq\"\n",
369
+ "tokenizer = RagTokenizer.from_pretrained(model_name)\n",
370
+ "retriever = RagRetriever.from_pretrained(model_name, use_dummy_dataset=True)\n",
371
+ "model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)\n",
372
+ "\n",
373
+ "# ドキュメントのエンコーディング\n",
374
+ "def embed_passages(passages):\n",
375
+ " embeddings = []\n",
376
+ " for passage in passages:\n",
377
+ " inputs = tokenizer(passage, return_tensors=\"pt\", truncation=True, padding=\"max_length\", max_length=256)\n",
378
+ " outputs = model.retriever.question_encoder(**inputs)\n",
379
+ " embeddings.append(outputs.pooler_output.detach().numpy())\n",
380
+ " return np.vstack(embeddings)\n",
381
+ "\n",
382
+ "# ドキュメントのエンベッド\n",
383
+ "passages = dataset['train']['text'][:1000] # デモのため、最初の1000ドキュメントのみを使用\n",
384
+ "passage_embeddings = embed_passages(passages)\n",
385
+ "\n",
386
+ "# Faissインデックスの作成\n",
387
+ "index = faiss.IndexFlatL2(passage_embeddings.shape[1])\n",
388
+ "index.add(passage_embeddings)\n",
389
+ "faiss.write_index(index, \"faiss_index\")\n",
390
+ "\n",
391
+ "# Faissインデックスをモデルに読み込む\n",
392
+ "retriever.index = faiss.read_index(\"faiss_index\")\n",
393
+ "\n",
394
+ "# 質問のトークナイズ\n",
395
+ "question = \"What is the capital of France?\"\n",
396
+ "inputs = tokenizer(question, return_tensors=\"pt\")\n",
397
+ "\n",
398
+ "# 回答の生成\n",
399
+ "outputs = model.generate(inputs[\"input_ids\"])\n",
400
+ "answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]\n",
401
+ "\n",
402
+ "print(answer)\n"
403
+ ]
404
+ }
405
+ ],
406
+ "metadata": {
407
+ "kernelspec": {
408
+ "display_name": "Python 3",
409
+ "language": "python",
410
+ "name": "python3"
411
+ },
412
+ "language_info": {
413
+ "codemirror_mode": {
414
+ "name": "ipython",
415
+ "version": 3
416
+ },
417
+ "file_extension": ".py",
418
+ "mimetype": "text/x-python",
419
+ "name": "python",
420
+ "nbconvert_exporter": "python",
421
+ "pygments_lexer": "ipython3",
422
+ "version": "3.10.13"
423
+ }
424
+ },
425
+ "nbformat": 4,
426
+ "nbformat_minor": 2
427
+ }