hassaanik commited on
Commit
e159241
·
verified ·
1 Parent(s): 597d443

Upload 2 files

Browse files
Notebooks/Couselling Chat.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Notebooks/Medication Chat.ipynb ADDED
@@ -0,0 +1,1348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "markdown",
21
+ "source": [
22
+ "### Data Preparation"
23
+ ],
24
+ "metadata": {
25
+ "id": "ga8c1nhja4Qy"
26
+ }
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "source": [
31
+ "!pip install opendatasets"
32
+ ],
33
+ "metadata": {
34
+ "colab": {
35
+ "base_uri": "https://localhost:8080/"
36
+ },
37
+ "id": "O7NczD5abI6o",
38
+ "outputId": "422faa21-1ee0-4582-9315-4c2b01f4518d"
39
+ },
40
+ "execution_count": 1,
41
+ "outputs": [
42
+ {
43
+ "output_type": "stream",
44
+ "name": "stdout",
45
+ "text": [
46
+ "Collecting opendatasets\n",
47
+ " Downloading opendatasets-0.1.22-py3-none-any.whl.metadata (9.2 kB)\n",
48
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from opendatasets) (4.66.5)\n",
49
+ "Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (from opendatasets) (1.6.17)\n",
50
+ "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from opendatasets) (8.1.7)\n",
51
+ "Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (1.16.0)\n",
52
+ "Requirement already satisfied: certifi>=2023.7.22 in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2024.8.30)\n",
53
+ "Requirement already satisfied: python-dateutil in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2.8.2)\n",
54
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2.32.3)\n",
55
+ "Requirement already satisfied: python-slugify in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (8.0.4)\n",
56
+ "Requirement already satisfied: urllib3 in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2.0.7)\n",
57
+ "Requirement already satisfied: bleach in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (6.1.0)\n",
58
+ "Requirement already satisfied: webencodings in /usr/local/lib/python3.10/dist-packages (from bleach->kaggle->opendatasets) (0.5.1)\n",
59
+ "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify->kaggle->opendatasets) (1.3)\n",
60
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle->opendatasets) (3.3.2)\n",
61
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle->opendatasets) (3.10)\n",
62
+ "Downloading opendatasets-0.1.22-py3-none-any.whl (15 kB)\n",
63
+ "Installing collected packages: opendatasets\n",
64
+ "Successfully installed opendatasets-0.1.22\n"
65
+ ]
66
+ }
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "source": [
72
+ "import opendatasets as od\n",
73
+ "od.download('https://www.kaggle.com/datasets/hassaanidrees/medinfo?select=MedInfo2019-QA-Medications.xlsx')"
74
+ ],
75
+ "metadata": {
76
+ "colab": {
77
+ "base_uri": "https://localhost:8080/"
78
+ },
79
+ "id": "7QSxa8cRbIug",
80
+ "outputId": "088ef3d5-b3fc-4860-8928-bb872ff83ab5"
81
+ },
82
+ "execution_count": 2,
83
+ "outputs": [
84
+ {
85
+ "output_type": "stream",
86
+ "name": "stdout",
87
+ "text": [
88
+ "Dataset URL: https://www.kaggle.com/datasets/hassaanidrees/medinfo\n",
89
+ "Downloading medinfo.zip to ./medinfo\n"
90
+ ]
91
+ },
92
+ {
93
+ "output_type": "stream",
94
+ "name": "stderr",
95
+ "text": [
96
+ "100%|██████████| 159k/159k [00:00<00:00, 480kB/s]"
97
+ ]
98
+ },
99
+ {
100
+ "output_type": "stream",
101
+ "name": "stdout",
102
+ "text": [
103
+ "\n"
104
+ ]
105
+ },
106
+ {
107
+ "output_type": "stream",
108
+ "name": "stderr",
109
+ "text": [
110
+ "\n"
111
+ ]
112
+ }
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "source": [
118
+ "# Import pandas for data analysis\n",
119
+ "import pandas as pd\n",
120
+ "df = pd.read_excel(\"/content/medinfo/MedInfo2019-QA-Medications.xlsx\")\n",
121
+ "df = df[['Question','Answer']]"
122
+ ],
123
+ "metadata": {
124
+ "id": "sooD64r3bIDJ"
125
+ },
126
+ "execution_count": 3,
127
+ "outputs": []
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "source": [
132
+ "df.head() #show first five rows"
133
+ ],
134
+ "metadata": {
135
+ "colab": {
136
+ "base_uri": "https://localhost:8080/",
137
+ "height": 206
138
+ },
139
+ "id": "eRneQPLAqAJL",
140
+ "outputId": "d1772f7e-8edd-4687-9c1a-c3102e86138e"
141
+ },
142
+ "execution_count": null,
143
+ "outputs": [
144
+ {
145
+ "output_type": "execute_result",
146
+ "data": {
147
+ "text/plain": [
148
+ " Question \\\n",
149
+ "0 how does rivatigmine and otc sleep medicine in... \n",
150
+ "1 how does valium affect the brain \n",
151
+ "2 what is morphine \n",
152
+ "3 what are the milligrams for oxycodone e \n",
153
+ "4 81% aspirin contain resin and shellac in it. ? \n",
154
+ "\n",
155
+ " Answer \n",
156
+ "0 tell your doctor and pharmacist what prescript... \n",
157
+ "1 Diazepam is a benzodiazepine that exerts anxio... \n",
158
+ "2 Morphine is a pain medication of the opiate fa... \n",
159
+ "3 … 10 mg … 20 mg … 40 mg … 80 mg ... \n",
160
+ "4 Inactive Ingredients Ingredient Name "
161
+ ],
162
+ "text/html": [
163
+ "\n",
164
+ " <div id=\"df-d79eadfb-a1cc-4af0-87f3-9921298edcfe\" class=\"colab-df-container\">\n",
165
+ " <div>\n",
166
+ "<style scoped>\n",
167
+ " .dataframe tbody tr th:only-of-type {\n",
168
+ " vertical-align: middle;\n",
169
+ " }\n",
170
+ "\n",
171
+ " .dataframe tbody tr th {\n",
172
+ " vertical-align: top;\n",
173
+ " }\n",
174
+ "\n",
175
+ " .dataframe thead th {\n",
176
+ " text-align: right;\n",
177
+ " }\n",
178
+ "</style>\n",
179
+ "<table border=\"1\" class=\"dataframe\">\n",
180
+ " <thead>\n",
181
+ " <tr style=\"text-align: right;\">\n",
182
+ " <th></th>\n",
183
+ " <th>Question</th>\n",
184
+ " <th>Answer</th>\n",
185
+ " </tr>\n",
186
+ " </thead>\n",
187
+ " <tbody>\n",
188
+ " <tr>\n",
189
+ " <th>0</th>\n",
190
+ " <td>how does rivatigmine and otc sleep medicine in...</td>\n",
191
+ " <td>tell your doctor and pharmacist what prescript...</td>\n",
192
+ " </tr>\n",
193
+ " <tr>\n",
194
+ " <th>1</th>\n",
195
+ " <td>how does valium affect the brain</td>\n",
196
+ " <td>Diazepam is a benzodiazepine that exerts anxio...</td>\n",
197
+ " </tr>\n",
198
+ " <tr>\n",
199
+ " <th>2</th>\n",
200
+ " <td>what is morphine</td>\n",
201
+ " <td>Morphine is a pain medication of the opiate fa...</td>\n",
202
+ " </tr>\n",
203
+ " <tr>\n",
204
+ " <th>3</th>\n",
205
+ " <td>what are the milligrams for oxycodone e</td>\n",
206
+ " <td>… 10 mg … 20 mg … 40 mg … 80 mg ...</td>\n",
207
+ " </tr>\n",
208
+ " <tr>\n",
209
+ " <th>4</th>\n",
210
+ " <td>81% aspirin contain resin and shellac in it. ?</td>\n",
211
+ " <td>Inactive Ingredients Ingredient Name</td>\n",
212
+ " </tr>\n",
213
+ " </tbody>\n",
214
+ "</table>\n",
215
+ "</div>\n",
216
+ " <div class=\"colab-df-buttons\">\n",
217
+ "\n",
218
+ " <div class=\"colab-df-container\">\n",
219
+ " <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-d79eadfb-a1cc-4af0-87f3-9921298edcfe')\"\n",
220
+ " title=\"Convert this dataframe to an interactive table.\"\n",
221
+ " style=\"display:none;\">\n",
222
+ "\n",
223
+ " <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
224
+ " <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
225
+ " </svg>\n",
226
+ " </button>\n",
227
+ "\n",
228
+ " <style>\n",
229
+ " .colab-df-container {\n",
230
+ " display:flex;\n",
231
+ " gap: 12px;\n",
232
+ " }\n",
233
+ "\n",
234
+ " .colab-df-convert {\n",
235
+ " background-color: #E8F0FE;\n",
236
+ " border: none;\n",
237
+ " border-radius: 50%;\n",
238
+ " cursor: pointer;\n",
239
+ " display: none;\n",
240
+ " fill: #1967D2;\n",
241
+ " height: 32px;\n",
242
+ " padding: 0 0 0 0;\n",
243
+ " width: 32px;\n",
244
+ " }\n",
245
+ "\n",
246
+ " .colab-df-convert:hover {\n",
247
+ " background-color: #E2EBFA;\n",
248
+ " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
249
+ " fill: #174EA6;\n",
250
+ " }\n",
251
+ "\n",
252
+ " .colab-df-buttons div {\n",
253
+ " margin-bottom: 4px;\n",
254
+ " }\n",
255
+ "\n",
256
+ " [theme=dark] .colab-df-convert {\n",
257
+ " background-color: #3B4455;\n",
258
+ " fill: #D2E3FC;\n",
259
+ " }\n",
260
+ "\n",
261
+ " [theme=dark] .colab-df-convert:hover {\n",
262
+ " background-color: #434B5C;\n",
263
+ " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
264
+ " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
265
+ " fill: #FFFFFF;\n",
266
+ " }\n",
267
+ " </style>\n",
268
+ "\n",
269
+ " <script>\n",
270
+ " const buttonEl =\n",
271
+ " document.querySelector('#df-d79eadfb-a1cc-4af0-87f3-9921298edcfe button.colab-df-convert');\n",
272
+ " buttonEl.style.display =\n",
273
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
274
+ "\n",
275
+ " async function convertToInteractive(key) {\n",
276
+ " const element = document.querySelector('#df-d79eadfb-a1cc-4af0-87f3-9921298edcfe');\n",
277
+ " const dataTable =\n",
278
+ " await google.colab.kernel.invokeFunction('convertToInteractive',\n",
279
+ " [key], {});\n",
280
+ " if (!dataTable) return;\n",
281
+ "\n",
282
+ " const docLinkHtml = 'Like what you see? Visit the ' +\n",
283
+ " '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
284
+ " + ' to learn more about interactive tables.';\n",
285
+ " element.innerHTML = '';\n",
286
+ " dataTable['output_type'] = 'display_data';\n",
287
+ " await google.colab.output.renderOutput(dataTable, element);\n",
288
+ " const docLink = document.createElement('div');\n",
289
+ " docLink.innerHTML = docLinkHtml;\n",
290
+ " element.appendChild(docLink);\n",
291
+ " }\n",
292
+ " </script>\n",
293
+ " </div>\n",
294
+ "\n",
295
+ "\n",
296
+ "<div id=\"df-862caeb9-15bf-47b1-b083-8b9307722b80\">\n",
297
+ " <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-862caeb9-15bf-47b1-b083-8b9307722b80')\"\n",
298
+ " title=\"Suggest charts\"\n",
299
+ " style=\"display:none;\">\n",
300
+ "\n",
301
+ "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
302
+ " width=\"24px\">\n",
303
+ " <g>\n",
304
+ " <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
305
+ " </g>\n",
306
+ "</svg>\n",
307
+ " </button>\n",
308
+ "\n",
309
+ "<style>\n",
310
+ " .colab-df-quickchart {\n",
311
+ " --bg-color: #E8F0FE;\n",
312
+ " --fill-color: #1967D2;\n",
313
+ " --hover-bg-color: #E2EBFA;\n",
314
+ " --hover-fill-color: #174EA6;\n",
315
+ " --disabled-fill-color: #AAA;\n",
316
+ " --disabled-bg-color: #DDD;\n",
317
+ " }\n",
318
+ "\n",
319
+ " [theme=dark] .colab-df-quickchart {\n",
320
+ " --bg-color: #3B4455;\n",
321
+ " --fill-color: #D2E3FC;\n",
322
+ " --hover-bg-color: #434B5C;\n",
323
+ " --hover-fill-color: #FFFFFF;\n",
324
+ " --disabled-bg-color: #3B4455;\n",
325
+ " --disabled-fill-color: #666;\n",
326
+ " }\n",
327
+ "\n",
328
+ " .colab-df-quickchart {\n",
329
+ " background-color: var(--bg-color);\n",
330
+ " border: none;\n",
331
+ " border-radius: 50%;\n",
332
+ " cursor: pointer;\n",
333
+ " display: none;\n",
334
+ " fill: var(--fill-color);\n",
335
+ " height: 32px;\n",
336
+ " padding: 0;\n",
337
+ " width: 32px;\n",
338
+ " }\n",
339
+ "\n",
340
+ " .colab-df-quickchart:hover {\n",
341
+ " background-color: var(--hover-bg-color);\n",
342
+ " box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
343
+ " fill: var(--button-hover-fill-color);\n",
344
+ " }\n",
345
+ "\n",
346
+ " .colab-df-quickchart-complete:disabled,\n",
347
+ " .colab-df-quickchart-complete:disabled:hover {\n",
348
+ " background-color: var(--disabled-bg-color);\n",
349
+ " fill: var(--disabled-fill-color);\n",
350
+ " box-shadow: none;\n",
351
+ " }\n",
352
+ "\n",
353
+ " .colab-df-spinner {\n",
354
+ " border: 2px solid var(--fill-color);\n",
355
+ " border-color: transparent;\n",
356
+ " border-bottom-color: var(--fill-color);\n",
357
+ " animation:\n",
358
+ " spin 1s steps(1) infinite;\n",
359
+ " }\n",
360
+ "\n",
361
+ " @keyframes spin {\n",
362
+ " 0% {\n",
363
+ " border-color: transparent;\n",
364
+ " border-bottom-color: var(--fill-color);\n",
365
+ " border-left-color: var(--fill-color);\n",
366
+ " }\n",
367
+ " 20% {\n",
368
+ " border-color: transparent;\n",
369
+ " border-left-color: var(--fill-color);\n",
370
+ " border-top-color: var(--fill-color);\n",
371
+ " }\n",
372
+ " 30% {\n",
373
+ " border-color: transparent;\n",
374
+ " border-left-color: var(--fill-color);\n",
375
+ " border-top-color: var(--fill-color);\n",
376
+ " border-right-color: var(--fill-color);\n",
377
+ " }\n",
378
+ " 40% {\n",
379
+ " border-color: transparent;\n",
380
+ " border-right-color: var(--fill-color);\n",
381
+ " border-top-color: var(--fill-color);\n",
382
+ " }\n",
383
+ " 60% {\n",
384
+ " border-color: transparent;\n",
385
+ " border-right-color: var(--fill-color);\n",
386
+ " }\n",
387
+ " 80% {\n",
388
+ " border-color: transparent;\n",
389
+ " border-right-color: var(--fill-color);\n",
390
+ " border-bottom-color: var(--fill-color);\n",
391
+ " }\n",
392
+ " 90% {\n",
393
+ " border-color: transparent;\n",
394
+ " border-bottom-color: var(--fill-color);\n",
395
+ " }\n",
396
+ " }\n",
397
+ "</style>\n",
398
+ "\n",
399
+ " <script>\n",
400
+ " async function quickchart(key) {\n",
401
+ " const quickchartButtonEl =\n",
402
+ " document.querySelector('#' + key + ' button');\n",
403
+ " quickchartButtonEl.disabled = true; // To prevent multiple clicks.\n",
404
+ " quickchartButtonEl.classList.add('colab-df-spinner');\n",
405
+ " try {\n",
406
+ " const charts = await google.colab.kernel.invokeFunction(\n",
407
+ " 'suggestCharts', [key], {});\n",
408
+ " } catch (error) {\n",
409
+ " console.error('Error during call to suggestCharts:', error);\n",
410
+ " }\n",
411
+ " quickchartButtonEl.classList.remove('colab-df-spinner');\n",
412
+ " quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
413
+ " }\n",
414
+ " (() => {\n",
415
+ " let quickchartButtonEl =\n",
416
+ " document.querySelector('#df-862caeb9-15bf-47b1-b083-8b9307722b80 button');\n",
417
+ " quickchartButtonEl.style.display =\n",
418
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
419
+ " })();\n",
420
+ " </script>\n",
421
+ "</div>\n",
422
+ "\n",
423
+ " </div>\n",
424
+ " </div>\n"
425
+ ],
426
+ "application/vnd.google.colaboratory.intrinsic+json": {
427
+ "type": "dataframe",
428
+ "variable_name": "df",
429
+ "summary": "{\n \"name\": \"df\",\n \"rows\": 690,\n \"fields\": [\n {\n \"column\": \"Question\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 651,\n \"samples\": [\n \"how is marijuana used\",\n \"tudorza pressair is what schedule drug\",\n \"how long does ecstasy or mda leave your body\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Answer\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 652,\n \"samples\": [\n \"Marijuana is best known as a drug that people smoke or eat to get high. It is derived from the plant Cannabis sativa. Possession of marijuana is illegal under federal law. Medical marijuana refers to using marijuana to treat certain medical conditions. In the United States, about half of the states have legalized marijuana for medical use.\",\n \"Color - GRAY, Shape - CAPSULE (biconvex), Score - no score, Size - 12mm, Imprint Code - m10\",\n \"Quantity: 60; Per Unit: $4.68 \\u2013 $15.91; Price: $280.99 \\u2013 $954.47\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
430
+ }
431
+ },
432
+ "metadata": {},
433
+ "execution_count": 4
434
+ }
435
+ ]
436
+ },
437
+ {
438
+ "cell_type": "code",
439
+ "source": [
440
+ "df.Question[0]"
441
+ ],
442
+ "metadata": {
443
+ "colab": {
444
+ "base_uri": "https://localhost:8080/",
445
+ "height": 36
446
+ },
447
+ "id": "4SEkJJHwqBwo",
448
+ "outputId": "7aeec0ad-b51a-44fa-f2e1-5a93b61246d5"
449
+ },
450
+ "execution_count": null,
451
+ "outputs": [
452
+ {
453
+ "output_type": "execute_result",
454
+ "data": {
455
+ "text/plain": [
456
+ "'how does rivatigmine and otc sleep medicine interact'"
457
+ ],
458
+ "application/vnd.google.colaboratory.intrinsic+json": {
459
+ "type": "string"
460
+ }
461
+ },
462
+ "metadata": {},
463
+ "execution_count": 5
464
+ }
465
+ ]
466
+ },
467
+ {
468
+ "cell_type": "code",
469
+ "source": [
470
+ "df.Answer[0]"
471
+ ],
472
+ "metadata": {
473
+ "colab": {
474
+ "base_uri": "https://localhost:8080/",
475
+ "height": 105
476
+ },
477
+ "id": "qTllg8a-qGXW",
478
+ "outputId": "a6b8bca7-135e-4e26-e0ff-a2a1424bc45c"
479
+ },
480
+ "execution_count": null,
481
+ "outputs": [
482
+ {
483
+ "output_type": "execute_result",
484
+ "data": {
485
+ "text/plain": [
486
+ "\"tell your doctor and pharmacist what prescription and nonprescription medications, vitamins, nutritional supplements, and herbal products you are taking or plan to take. Be sure to mention any of the following: antihistamines; aspirin and other nonsteroidal anti-inflammatory medications (NSAIDs) such as ibuprofen (Advil, Motrin) and naproxen (Aleve, Naprosyn); bethanechol (Duvoid, Urecholine); ipratropium (Atrovent, in Combivent, DuoNeb); and medications for Alzheimer's disease, glaucoma, irritable bowel disease, motion sickness, ulcers, or urinary problems. Your doctor may need to change the doses of your medications or monitor you carefully for side effects.\""
487
+ ],
488
+ "application/vnd.google.colaboratory.intrinsic+json": {
489
+ "type": "string"
490
+ }
491
+ },
492
+ "metadata": {},
493
+ "execution_count": 6
494
+ }
495
+ ]
496
+ },
497
+ {
498
+ "cell_type": "code",
499
+ "source": [
500
+ "df.shape # 690 rows | 2 cols"
501
+ ],
502
+ "metadata": {
503
+ "colab": {
504
+ "base_uri": "https://localhost:8080/"
505
+ },
506
+ "id": "xs_qECG1qIW5",
507
+ "outputId": "678a409c-9164-48f4-803e-501d3dff3c96"
508
+ },
509
+ "execution_count": null,
510
+ "outputs": [
511
+ {
512
+ "output_type": "execute_result",
513
+ "data": {
514
+ "text/plain": [
515
+ "(690, 2)"
516
+ ]
517
+ },
518
+ "metadata": {},
519
+ "execution_count": 7
520
+ }
521
+ ]
522
+ },
523
+ {
524
+ "cell_type": "code",
525
+ "source": [
526
+ "!pip install cleantext"
527
+ ],
528
+ "metadata": {
529
+ "colab": {
530
+ "base_uri": "https://localhost:8080/"
531
+ },
532
+ "id": "LPvkkbdbrNp-",
533
+ "outputId": "938e6a8d-fb4b-4112-9a0e-3139146e56eb"
534
+ },
535
+ "execution_count": null,
536
+ "outputs": [
537
+ {
538
+ "output_type": "stream",
539
+ "name": "stdout",
540
+ "text": [
541
+ "Collecting cleantext\n",
542
+ " Downloading cleantext-1.1.4-py3-none-any.whl.metadata (3.5 kB)\n",
543
+ "Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (from cleantext) (3.8.1)\n",
544
+ "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from nltk->cleantext) (8.1.7)\n",
545
+ "Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from nltk->cleantext) (1.4.2)\n",
546
+ "Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.10/dist-packages (from nltk->cleantext) (2024.5.15)\n",
547
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from nltk->cleantext) (4.66.5)\n",
548
+ "Downloading cleantext-1.1.4-py3-none-any.whl (4.9 kB)\n",
549
+ "Installing collected packages: cleantext\n",
550
+ "Successfully installed cleantext-1.1.4\n"
551
+ ]
552
+ }
553
+ ]
554
+ },
555
+ {
556
+ "cell_type": "code",
557
+ "source": [
558
+ "import cleantext\n",
559
+ "\n",
560
+ "# Function to clean text data by removing unwanted characters and formatting\n",
561
+ "def clean(textdata):\n",
562
+ " cleaned_text = []\n",
563
+ " for i in textdata:\n",
564
+ " cleaned_text.append(cleantext.clean(str(i), extra_spaces=True, lowercase=True, stopwords=False, stemming=False, numbers=True, punct=True, clean_all = True))\n",
565
+ "\n",
566
+ " return cleaned_text"
567
+ ],
568
+ "metadata": {
569
+ "id": "dws3d49Lqv1b"
570
+ },
571
+ "execution_count": null,
572
+ "outputs": []
573
+ },
574
+ {
575
+ "cell_type": "code",
576
+ "source": [
577
+ "# Apply the clean function to the questions and answers columns\n",
578
+ "\n",
579
+ "df.Question = list(clean(df.Question))\n",
580
+ "df.Answer = list(clean(df.Answer))"
581
+ ],
582
+ "metadata": {
583
+ "id": "H1ia-jFqrIsG"
584
+ },
585
+ "execution_count": null,
586
+ "outputs": []
587
+ },
588
+ {
589
+ "cell_type": "code",
590
+ "source": [
591
+ "# Save the cleaned data into a new CSV file & save\n",
592
+ "df.to_csv(\"cleaned_med_QA_data.csv\", index=False)"
593
+ ],
594
+ "metadata": {
595
+ "id": "HcB15JQirImk"
596
+ },
597
+ "execution_count": null,
598
+ "outputs": []
599
+ },
600
+ {
601
+ "cell_type": "markdown",
602
+ "source": [
603
+ "### GPT-2 Model"
604
+ ],
605
+ "metadata": {
606
+ "id": "zw5mkpmueML4"
607
+ }
608
+ },
609
+ {
610
+ "cell_type": "code",
611
+ "source": [
612
+ "!pip install datasets"
613
+ ],
614
+ "metadata": {
615
+ "colab": {
616
+ "base_uri": "https://localhost:8080/",
617
+ "height": 1000
618
+ },
619
+ "id": "QhgGKgZ-rYAY",
620
+ "outputId": "f2334a48-2745-42b5-f5fd-929ca58e1ed6",
621
+ "collapsed": true
622
+ },
623
+ "execution_count": null,
624
+ "outputs": [
625
+ {
626
+ "output_type": "stream",
627
+ "name": "stdout",
628
+ "text": [
629
+ "Collecting datasets\n",
630
+ " Downloading datasets-3.0.0-py3-none-any.whl.metadata (19 kB)\n",
631
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.16.0)\n",
632
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\n",
633
+ "Collecting pyarrow>=15.0.0 (from datasets)\n",
634
+ " Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)\n",
635
+ "Collecting dill<0.3.9,>=0.3.0 (from datasets)\n",
636
+ " Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n",
637
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.1.4)\n",
638
+ "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)\n",
639
+ "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.5)\n",
640
+ "Collecting xxhash (from datasets)\n",
641
+ " Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n",
642
+ "Collecting multiprocess (from datasets)\n",
643
+ " Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\n",
644
+ "Requirement already satisfied: fsspec<=2024.6.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.6.1,>=2023.1.0->datasets) (2024.6.1)\n",
645
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.10.5)\n",
646
+ "Requirement already satisfied: huggingface-hub>=0.22.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.24.6)\n",
647
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.1)\n",
648
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\n",
649
+ "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.0)\n",
650
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
651
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n",
652
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n",
653
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n",
654
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.11.1)\n",
655
+ "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
656
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.22.0->datasets) (4.12.2)\n",
657
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.3.2)\n",
658
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.8)\n",
659
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.0.7)\n",
660
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.8.30)\n",
661
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
662
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n",
663
+ "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n",
664
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
665
+ "Downloading datasets-3.0.0-py3-none-any.whl (474 kB)\n",
666
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m474.3/474.3 kB\u001b[0m \u001b[31m32.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
667
+ "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n",
668
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m11.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
669
+ "\u001b[?25hDownloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)\n",
670
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m39.9/39.9 MB\u001b[0m \u001b[31m19.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
671
+ "\u001b[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n",
672
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m14.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
673
+ "\u001b[?25hDownloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n",
674
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m20.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
675
+ "\u001b[?25hInstalling collected packages: xxhash, pyarrow, dill, multiprocess, datasets\n",
676
+ " Attempting uninstall: pyarrow\n",
677
+ " Found existing installation: pyarrow 14.0.2\n",
678
+ " Uninstalling pyarrow-14.0.2:\n",
679
+ " Successfully uninstalled pyarrow-14.0.2\n",
680
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
681
+ "cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 17.0.0 which is incompatible.\n",
682
+ "ibis-framework 8.0.0 requires pyarrow<16,>=2, but you have pyarrow 17.0.0 which is incompatible.\u001b[0m\u001b[31m\n",
683
+ "\u001b[0mSuccessfully installed datasets-3.0.0 dill-0.3.8 multiprocess-0.70.16 pyarrow-17.0.0 xxhash-3.5.0\n"
684
+ ]
685
+ },
686
+ {
687
+ "output_type": "display_data",
688
+ "data": {
689
+ "application/vnd.colab-display-data+json": {
690
+ "pip_warning": {
691
+ "packages": [
692
+ "pyarrow"
693
+ ]
694
+ },
695
+ "id": "a6cd6efad93b4c4cb5a29a91b023de8a"
696
+ }
697
+ },
698
+ "metadata": {}
699
+ }
700
+ ]
701
+ },
702
+ {
703
+ "cell_type": "code",
704
+ "source": [
705
+ "from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments\n",
706
+ "import torch\n",
707
+ "from datasets import load_dataset\n",
708
+ "\n",
709
+ "# Load the GPT-2 model and tokenizer\n",
710
+ "tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n",
711
+ "model = GPT2LMHeadModel.from_pretrained('gpt2')"
712
+ ],
713
+ "metadata": {
714
+ "colab": {
715
+ "base_uri": "https://localhost:8080/"
716
+ },
717
+ "id": "xgGgvCqerk-1",
718
+ "outputId": "e338ee7f-c898-41c4-b1f6-036f115d3735"
719
+ },
720
+ "execution_count": null,
721
+ "outputs": [
722
+ {
723
+ "output_type": "stream",
724
+ "name": "stderr",
725
+ "text": [
726
+ "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
727
+ " warnings.warn(\n"
728
+ ]
729
+ }
730
+ ]
731
+ },
732
+ {
733
+ "cell_type": "code",
734
+ "source": [
735
+ "# Set the padding token for the tokenizer to be the end-of-sequence token\n",
736
+ "tokenizer.pad_token = tokenizer.eos_token\n",
737
+ "\n",
738
+ "# Maximum sequence length that GPT-2 can handle\n",
739
+ "max_length = tokenizer.model_max_length\n",
740
+ "print(max_length)"
741
+ ],
742
+ "metadata": {
743
+ "colab": {
744
+ "base_uri": "https://localhost:8080/"
745
+ },
746
+ "id": "EeiMYkpCrp62",
747
+ "outputId": "e8b0118b-1694-4d9e-d666-e791b083f63f"
748
+ },
749
+ "execution_count": null,
750
+ "outputs": [
751
+ {
752
+ "output_type": "stream",
753
+ "name": "stdout",
754
+ "text": [
755
+ "1024\n"
756
+ ]
757
+ }
758
+ ]
759
+ },
760
+ {
761
+ "cell_type": "code",
762
+ "source": [
763
+ "# Load the cleaned QA dataset as a training set using the 'datasets' library\n",
764
+ "dataset = load_dataset('csv', data_files={'train': 'cleaned_med_QA_data.csv'}, split='train')"
765
+ ],
766
+ "metadata": {
767
+ "id": "MW5Ad0exrry3"
768
+ },
769
+ "execution_count": null,
770
+ "outputs": []
771
+ },
772
+ {
773
+ "cell_type": "code",
774
+ "source": [
775
+ "#Function to tokenize questions and answers and prepare them for the model\n",
776
+ "def tokenize_function(examples):\n",
777
+ " '''1. Combine each question and answer into a single input string\n",
778
+ " 2. Tokenize the combined text using the GPT-2 tokenizer\n",
779
+ " 3. Set the labels to be the same as the input_ids (shifted to predict the next word)\n",
780
+ " 4. Return the tokenized output. '''\n",
781
+ "\n",
782
+ " combined_text = [str(q) + \" \" + str(a) for q, a in zip(examples['Question'], examples['Answer'])]\n",
783
+ " tokenized_output = tokenizer(combined_text, padding='max_length', truncation=True, max_length=128)\n",
784
+ "\n",
785
+ " # Set the labels to be the same as the input_ids (shifted to predict the next word)\n",
786
+ " tokenized_output['labels'] = tokenized_output['input_ids'].copy()\n",
787
+ "\n",
788
+ " return tokenized_output\n",
789
+ "\n",
790
+ "# Tokenize the entire dataset\n",
791
+ "tokenized_dataset = dataset.map(tokenize_function, batched=True)"
792
+ ],
793
+ "metadata": {
794
+ "id": "99rfOROKr-M0"
795
+ },
796
+ "execution_count": null,
797
+ "outputs": []
798
+ },
799
+ {
800
+ "cell_type": "code",
801
+ "source": [
802
+ "# Define training arguments for the GPT-2 model\n",
803
+ "training_args = TrainingArguments(\n",
804
+ " output_dir='./results', # Directory to save model outputs\n",
805
+ " num_train_epochs=20, # Train for 50 epochs\n",
806
+ " per_device_train_batch_size=16, # Batch size during training\n",
807
+ " per_device_eval_batch_size=32, # Batch size during evaluation\n",
808
+ " warmup_steps=500, # Warmup steps for learning rate scheduler\n",
809
+ " weight_decay=0.01, # Weight decay for regularization\n",
810
+ " logging_dir='./logs', # Directory for saving logs\n",
811
+ " logging_steps=10, # Log every 10 steps\n",
812
+ " save_steps=1000, # Save model checkpoints every 1000 steps\n",
813
+ ")\n",
814
+ "\n",
815
+ "# Trainer class to handle training process\n",
816
+ "trainer = Trainer(\n",
817
+ " model=model,\n",
818
+ " args=training_args,\n",
819
+ " train_dataset=tokenized_dataset,\n",
820
+ " tokenizer=tokenizer,\n",
821
+ ")\n",
822
+ "\n",
823
+ "# Train the model\n",
824
+ "trainer.train()"
825
+ ],
826
+ "metadata": {
827
+ "colab": {
828
+ "base_uri": "https://localhost:8080/",
829
+ "height": 1000
830
+ },
831
+ "id": "TQGJ16yJsCBc",
832
+ "outputId": "ec5b1ae4-83c1-4117-95fe-3aae63fc0f75",
833
+ "collapsed": true
834
+ },
835
+ "execution_count": null,
836
+ "outputs": [
837
+ {
838
+ "output_type": "display_data",
839
+ "data": {
840
+ "text/plain": [
841
+ "<IPython.core.display.HTML object>"
842
+ ],
843
+ "text/html": [
844
+ "\n",
845
+ " <div>\n",
846
+ " \n",
847
+ " <progress value='880' max='880' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
848
+ " [880/880 08:45, Epoch 20/20]\n",
849
+ " </div>\n",
850
+ " <table border=\"1\" class=\"dataframe\">\n",
851
+ " <thead>\n",
852
+ " <tr style=\"text-align: left;\">\n",
853
+ " <th>Step</th>\n",
854
+ " <th>Training Loss</th>\n",
855
+ " </tr>\n",
856
+ " </thead>\n",
857
+ " <tbody>\n",
858
+ " <tr>\n",
859
+ " <td>10</td>\n",
860
+ " <td>5.891800</td>\n",
861
+ " </tr>\n",
862
+ " <tr>\n",
863
+ " <td>20</td>\n",
864
+ " <td>5.497900</td>\n",
865
+ " </tr>\n",
866
+ " <tr>\n",
867
+ " <td>30</td>\n",
868
+ " <td>4.671300</td>\n",
869
+ " </tr>\n",
870
+ " <tr>\n",
871
+ " <td>40</td>\n",
872
+ " <td>3.751500</td>\n",
873
+ " </tr>\n",
874
+ " <tr>\n",
875
+ " <td>50</td>\n",
876
+ " <td>3.016000</td>\n",
877
+ " </tr>\n",
878
+ " <tr>\n",
879
+ " <td>60</td>\n",
880
+ " <td>2.633300</td>\n",
881
+ " </tr>\n",
882
+ " <tr>\n",
883
+ " <td>70</td>\n",
884
+ " <td>2.360800</td>\n",
885
+ " </tr>\n",
886
+ " <tr>\n",
887
+ " <td>80</td>\n",
888
+ " <td>2.079000</td>\n",
889
+ " </tr>\n",
890
+ " <tr>\n",
891
+ " <td>90</td>\n",
892
+ " <td>2.145600</td>\n",
893
+ " </tr>\n",
894
+ " <tr>\n",
895
+ " <td>100</td>\n",
896
+ " <td>2.150100</td>\n",
897
+ " </tr>\n",
898
+ " <tr>\n",
899
+ " <td>110</td>\n",
900
+ " <td>2.069300</td>\n",
901
+ " </tr>\n",
902
+ " <tr>\n",
903
+ " <td>120</td>\n",
904
+ " <td>2.000300</td>\n",
905
+ " </tr>\n",
906
+ " <tr>\n",
907
+ " <td>130</td>\n",
908
+ " <td>1.919900</td>\n",
909
+ " </tr>\n",
910
+ " <tr>\n",
911
+ " <td>140</td>\n",
912
+ " <td>1.954000</td>\n",
913
+ " </tr>\n",
914
+ " <tr>\n",
915
+ " <td>150</td>\n",
916
+ " <td>1.928500</td>\n",
917
+ " </tr>\n",
918
+ " <tr>\n",
919
+ " <td>160</td>\n",
920
+ " <td>1.832900</td>\n",
921
+ " </tr>\n",
922
+ " <tr>\n",
923
+ " <td>170</td>\n",
924
+ " <td>1.921300</td>\n",
925
+ " </tr>\n",
926
+ " <tr>\n",
927
+ " <td>180</td>\n",
928
+ " <td>2.043500</td>\n",
929
+ " </tr>\n",
930
+ " <tr>\n",
931
+ " <td>190</td>\n",
932
+ " <td>1.827400</td>\n",
933
+ " </tr>\n",
934
+ " <tr>\n",
935
+ " <td>200</td>\n",
936
+ " <td>1.687700</td>\n",
937
+ " </tr>\n",
938
+ " <tr>\n",
939
+ " <td>210</td>\n",
940
+ " <td>1.782400</td>\n",
941
+ " </tr>\n",
942
+ " <tr>\n",
943
+ " <td>220</td>\n",
944
+ " <td>1.959600</td>\n",
945
+ " </tr>\n",
946
+ " <tr>\n",
947
+ " <td>230</td>\n",
948
+ " <td>1.810500</td>\n",
949
+ " </tr>\n",
950
+ " <tr>\n",
951
+ " <td>240</td>\n",
952
+ " <td>1.706800</td>\n",
953
+ " </tr>\n",
954
+ " <tr>\n",
955
+ " <td>250</td>\n",
956
+ " <td>1.662200</td>\n",
957
+ " </tr>\n",
958
+ " <tr>\n",
959
+ " <td>260</td>\n",
960
+ " <td>1.783900</td>\n",
961
+ " </tr>\n",
962
+ " <tr>\n",
963
+ " <td>270</td>\n",
964
+ " <td>1.567300</td>\n",
965
+ " </tr>\n",
966
+ " <tr>\n",
967
+ " <td>280</td>\n",
968
+ " <td>1.695100</td>\n",
969
+ " </tr>\n",
970
+ " <tr>\n",
971
+ " <td>290</td>\n",
972
+ " <td>1.681800</td>\n",
973
+ " </tr>\n",
974
+ " <tr>\n",
975
+ " <td>300</td>\n",
976
+ " <td>1.657400</td>\n",
977
+ " </tr>\n",
978
+ " <tr>\n",
979
+ " <td>310</td>\n",
980
+ " <td>1.684000</td>\n",
981
+ " </tr>\n",
982
+ " <tr>\n",
983
+ " <td>320</td>\n",
984
+ " <td>1.494700</td>\n",
985
+ " </tr>\n",
986
+ " <tr>\n",
987
+ " <td>330</td>\n",
988
+ " <td>1.556800</td>\n",
989
+ " </tr>\n",
990
+ " <tr>\n",
991
+ " <td>340</td>\n",
992
+ " <td>1.648300</td>\n",
993
+ " </tr>\n",
994
+ " <tr>\n",
995
+ " <td>350</td>\n",
996
+ " <td>1.529300</td>\n",
997
+ " </tr>\n",
998
+ " <tr>\n",
999
+ " <td>360</td>\n",
1000
+ " <td>1.421200</td>\n",
1001
+ " </tr>\n",
1002
+ " <tr>\n",
1003
+ " <td>370</td>\n",
1004
+ " <td>1.483900</td>\n",
1005
+ " </tr>\n",
1006
+ " <tr>\n",
1007
+ " <td>380</td>\n",
1008
+ " <td>1.588400</td>\n",
1009
+ " </tr>\n",
1010
+ " <tr>\n",
1011
+ " <td>390</td>\n",
1012
+ " <td>1.442200</td>\n",
1013
+ " </tr>\n",
1014
+ " <tr>\n",
1015
+ " <td>400</td>\n",
1016
+ " <td>1.524600</td>\n",
1017
+ " </tr>\n",
1018
+ " <tr>\n",
1019
+ " <td>410</td>\n",
1020
+ " <td>1.469100</td>\n",
1021
+ " </tr>\n",
1022
+ " <tr>\n",
1023
+ " <td>420</td>\n",
1024
+ " <td>1.412900</td>\n",
1025
+ " </tr>\n",
1026
+ " <tr>\n",
1027
+ " <td>430</td>\n",
1028
+ " <td>1.388300</td>\n",
1029
+ " </tr>\n",
1030
+ " <tr>\n",
1031
+ " <td>440</td>\n",
1032
+ " <td>1.414400</td>\n",
1033
+ " </tr>\n",
1034
+ " <tr>\n",
1035
+ " <td>450</td>\n",
1036
+ " <td>1.368200</td>\n",
1037
+ " </tr>\n",
1038
+ " <tr>\n",
1039
+ " <td>460</td>\n",
1040
+ " <td>1.374900</td>\n",
1041
+ " </tr>\n",
1042
+ " <tr>\n",
1043
+ " <td>470</td>\n",
1044
+ " <td>1.336500</td>\n",
1045
+ " </tr>\n",
1046
+ " <tr>\n",
1047
+ " <td>480</td>\n",
1048
+ " <td>1.294900</td>\n",
1049
+ " </tr>\n",
1050
+ " <tr>\n",
1051
+ " <td>490</td>\n",
1052
+ " <td>1.231700</td>\n",
1053
+ " </tr>\n",
1054
+ " <tr>\n",
1055
+ " <td>500</td>\n",
1056
+ " <td>1.287600</td>\n",
1057
+ " </tr>\n",
1058
+ " <tr>\n",
1059
+ " <td>510</td>\n",
1060
+ " <td>1.248500</td>\n",
1061
+ " </tr>\n",
1062
+ " <tr>\n",
1063
+ " <td>520</td>\n",
1064
+ " <td>1.220700</td>\n",
1065
+ " </tr>\n",
1066
+ " <tr>\n",
1067
+ " <td>530</td>\n",
1068
+ " <td>1.335700</td>\n",
1069
+ " </tr>\n",
1070
+ " <tr>\n",
1071
+ " <td>540</td>\n",
1072
+ " <td>1.094200</td>\n",
1073
+ " </tr>\n",
1074
+ " <tr>\n",
1075
+ " <td>550</td>\n",
1076
+ " <td>1.151400</td>\n",
1077
+ " </tr>\n",
1078
+ " <tr>\n",
1079
+ " <td>560</td>\n",
1080
+ " <td>1.215000</td>\n",
1081
+ " </tr>\n",
1082
+ " <tr>\n",
1083
+ " <td>570</td>\n",
1084
+ " <td>1.235600</td>\n",
1085
+ " </tr>\n",
1086
+ " <tr>\n",
1087
+ " <td>580</td>\n",
1088
+ " <td>1.139800</td>\n",
1089
+ " </tr>\n",
1090
+ " <tr>\n",
1091
+ " <td>590</td>\n",
1092
+ " <td>1.119600</td>\n",
1093
+ " </tr>\n",
1094
+ " <tr>\n",
1095
+ " <td>600</td>\n",
1096
+ " <td>1.148000</td>\n",
1097
+ " </tr>\n",
1098
+ " <tr>\n",
1099
+ " <td>610</td>\n",
1100
+ " <td>1.057300</td>\n",
1101
+ " </tr>\n",
1102
+ " <tr>\n",
1103
+ " <td>620</td>\n",
1104
+ " <td>1.039700</td>\n",
1105
+ " </tr>\n",
1106
+ " <tr>\n",
1107
+ " <td>630</td>\n",
1108
+ " <td>1.081300</td>\n",
1109
+ " </tr>\n",
1110
+ " <tr>\n",
1111
+ " <td>640</td>\n",
1112
+ " <td>0.960300</td>\n",
1113
+ " </tr>\n",
1114
+ " <tr>\n",
1115
+ " <td>650</td>\n",
1116
+ " <td>1.026400</td>\n",
1117
+ " </tr>\n",
1118
+ " <tr>\n",
1119
+ " <td>660</td>\n",
1120
+ " <td>1.049900</td>\n",
1121
+ " </tr>\n",
1122
+ " <tr>\n",
1123
+ " <td>670</td>\n",
1124
+ " <td>0.967600</td>\n",
1125
+ " </tr>\n",
1126
+ " <tr>\n",
1127
+ " <td>680</td>\n",
1128
+ " <td>0.902100</td>\n",
1129
+ " </tr>\n",
1130
+ " <tr>\n",
1131
+ " <td>690</td>\n",
1132
+ " <td>0.950900</td>\n",
1133
+ " </tr>\n",
1134
+ " <tr>\n",
1135
+ " <td>700</td>\n",
1136
+ " <td>0.998500</td>\n",
1137
+ " </tr>\n",
1138
+ " <tr>\n",
1139
+ " <td>710</td>\n",
1140
+ " <td>1.043500</td>\n",
1141
+ " </tr>\n",
1142
+ " <tr>\n",
1143
+ " <td>720</td>\n",
1144
+ " <td>0.877700</td>\n",
1145
+ " </tr>\n",
1146
+ " <tr>\n",
1147
+ " <td>730</td>\n",
1148
+ " <td>0.818800</td>\n",
1149
+ " </tr>\n",
1150
+ " <tr>\n",
1151
+ " <td>740</td>\n",
1152
+ " <td>0.949500</td>\n",
1153
+ " </tr>\n",
1154
+ " <tr>\n",
1155
+ " <td>750</td>\n",
1156
+ " <td>1.032200</td>\n",
1157
+ " </tr>\n",
1158
+ " <tr>\n",
1159
+ " <td>760</td>\n",
1160
+ " <td>0.813600</td>\n",
1161
+ " </tr>\n",
1162
+ " <tr>\n",
1163
+ " <td>770</td>\n",
1164
+ " <td>0.871600</td>\n",
1165
+ " </tr>\n",
1166
+ " <tr>\n",
1167
+ " <td>780</td>\n",
1168
+ " <td>0.877400</td>\n",
1169
+ " </tr>\n",
1170
+ " <tr>\n",
1171
+ " <td>790</td>\n",
1172
+ " <td>0.952400</td>\n",
1173
+ " </tr>\n",
1174
+ " <tr>\n",
1175
+ " <td>800</td>\n",
1176
+ " <td>0.819600</td>\n",
1177
+ " </tr>\n",
1178
+ " <tr>\n",
1179
+ " <td>810</td>\n",
1180
+ " <td>0.852700</td>\n",
1181
+ " </tr>\n",
1182
+ " <tr>\n",
1183
+ " <td>820</td>\n",
1184
+ " <td>0.848300</td>\n",
1185
+ " </tr>\n",
1186
+ " <tr>\n",
1187
+ " <td>830</td>\n",
1188
+ " <td>0.834200</td>\n",
1189
+ " </tr>\n",
1190
+ " <tr>\n",
1191
+ " <td>840</td>\n",
1192
+ " <td>0.900900</td>\n",
1193
+ " </tr>\n",
1194
+ " <tr>\n",
1195
+ " <td>850</td>\n",
1196
+ " <td>0.830800</td>\n",
1197
+ " </tr>\n",
1198
+ " <tr>\n",
1199
+ " <td>860</td>\n",
1200
+ " <td>0.864700</td>\n",
1201
+ " </tr>\n",
1202
+ " <tr>\n",
1203
+ " <td>870</td>\n",
1204
+ " <td>0.842200</td>\n",
1205
+ " </tr>\n",
1206
+ " <tr>\n",
1207
+ " <td>880</td>\n",
1208
+ " <td>0.865000</td>\n",
1209
+ " </tr>\n",
1210
+ " </tbody>\n",
1211
+ "</table><p>"
1212
+ ]
1213
+ },
1214
+ "metadata": {}
1215
+ },
1216
+ {
1217
+ "output_type": "execute_result",
1218
+ "data": {
1219
+ "text/plain": [
1220
+ "TrainOutput(global_step=880, training_loss=1.5622584277933294, metrics={'train_runtime': 525.9662, 'train_samples_per_second': 26.237, 'train_steps_per_second': 1.673, 'total_flos': 901457510400000.0, 'train_loss': 1.5622584277933294, 'epoch': 20.0})"
1221
+ ]
1222
+ },
1223
+ "metadata": {},
1224
+ "execution_count": 13
1225
+ }
1226
+ ]
1227
+ },
1228
+ {
1229
+ "cell_type": "code",
1230
+ "source": [
1231
+ "# Save the model\n",
1232
+ "trainer.save_model('med_info_model')"
1233
+ ],
1234
+ "metadata": {
1235
+ "id": "4UrH8iP0u6Cp"
1236
+ },
1237
+ "execution_count": null,
1238
+ "outputs": []
1239
+ },
1240
+ {
1241
+ "cell_type": "markdown",
1242
+ "source": [
1243
+ "### Testing"
1244
+ ],
1245
+ "metadata": {
1246
+ "id": "VhXRJT6jeTuz"
1247
+ }
1248
+ },
1249
+ {
1250
+ "cell_type": "code",
1251
+ "source": [
1252
+ "# Function to generate a response based on a user prompt (testing the model)\n",
1253
+ "def generate_response(prompt):\n",
1254
+ " inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to('cuda')\n",
1255
+ " outputs = model.generate(inputs, max_length=150, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)\n",
1256
+ "\n",
1257
+ " # Decode the generated output\n",
1258
+ " response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
1259
+ "\n",
1260
+ " # Remove the prompt from the response\n",
1261
+ " if response.startswith(prompt):\n",
1262
+ " response = response[len(prompt):].strip() # Remove the prompt from the response\n",
1263
+ "\n",
1264
+ " return response"
1265
+ ],
1266
+ "metadata": {
1267
+ "id": "JbMs8UuSu5_R"
1268
+ },
1269
+ "execution_count": null,
1270
+ "outputs": []
1271
+ },
1272
+ {
1273
+ "cell_type": "code",
1274
+ "source": [
1275
+ "# Example conversation\n",
1276
+ "user_input = \"what is desonide ointment used for\"\n",
1277
+ "bot_response = generate_response(user_input)\n",
1278
+ "print(\"Bot Response:\", bot_response)"
1279
+ ],
1280
+ "metadata": {
1281
+ "colab": {
1282
+ "base_uri": "https://localhost:8080/"
1283
+ },
1284
+ "id": "qsHAT1-uxC4_",
1285
+ "outputId": "89b73c5f-0ae9-449d-8eb4-3df1a7c146bb"
1286
+ },
1287
+ "execution_count": null,
1288
+ "outputs": [
1289
+ {
1290
+ "output_type": "stream",
1291
+ "name": "stdout",
1292
+ "text": [
1293
+ "Bot Response: desonide ointment is used to treat a variety of conditions it is used to treat allergies and other skin conditions it is also used to treat certain types of infections it is also used to treat skin infections caused by bacteria that are on skin desonide is in a class of medications called antimicrobials it works by killing bacteria that cause skin infections desonide is in a class of medications called antibiotics it works by killing bacteria that cause skin infections\n"
1294
+ ]
1295
+ }
1296
+ ]
1297
+ },
1298
+ {
1299
+ "cell_type": "code",
1300
+ "source": [
1301
+ "# Copying the model to Google Drive (optional)\n",
1302
+ "import shutil\n",
1303
+ "\n",
1304
+ "# Path to the file in Colab\n",
1305
+ "colab_file_path = '/content/med_info_model/model.safetensors'\n",
1306
+ "\n",
1307
+ "# Path to your Google Drive\n",
1308
+ "drive_file_path = '/content/drive/MyDrive'\n",
1309
+ "\n",
1310
+ "# Copy the file\n",
1311
+ "shutil.copy(colab_file_path, drive_file_path)"
1312
+ ],
1313
+ "metadata": {
1314
+ "colab": {
1315
+ "base_uri": "https://localhost:8080/",
1316
+ "height": 36
1317
+ },
1318
+ "id": "aP4IEboMxDWG",
1319
+ "outputId": "c00d1d74-e389-4de4-a151-d20736b6bccd"
1320
+ },
1321
+ "execution_count": null,
1322
+ "outputs": [
1323
+ {
1324
+ "output_type": "execute_result",
1325
+ "data": {
1326
+ "text/plain": [
1327
+ "'/content/drive/MyDrive/model.safetensors'"
1328
+ ],
1329
+ "application/vnd.google.colaboratory.intrinsic+json": {
1330
+ "type": "string"
1331
+ }
1332
+ },
1333
+ "metadata": {},
1334
+ "execution_count": 22
1335
+ }
1336
+ ]
1337
+ },
1338
+ {
1339
+ "cell_type": "code",
1340
+ "source": [],
1341
+ "metadata": {
1342
+ "id": "uKYwYe5XyXgx"
1343
+ },
1344
+ "execution_count": null,
1345
+ "outputs": []
1346
+ }
1347
+ ]
1348
+ }