Retrieve dataset via the `datasets` library instead of a text file
- bigram.py +7 -7
- dataset.py +19 -0
- encoder.py +3 -2
- input.txt +0 -0
- model.pth +1 -1
- poetry.lock +0 -0
- pyproject.toml +4 -0
bigram.py
CHANGED
@@ -3,6 +3,7 @@ import torch
 import torch.nn as nn
 from torch.nn import functional as F
 import numpy as np
+from datasets import load_dataset
 
 from encoder import encode, decode
 from self_attention import Head, MultiHead
@@ -13,13 +14,12 @@ class Batcher():
         self.device = device
         self.batch_size = batch_size
         self.block_size = block_size
-
-
-
-
-
-
-        self.vocab = set(text)
+        from dataset import make_dataset
+        train_data = make_dataset('train')
+        val_data = make_dataset('validation')
+        self.train_data = torch.tensor(encode(train_data), dtype=torch.long)
+        self.val_data = torch.tensor(encode(val_data), dtype=torch.long)
+        self.vocab = set(train_data + val_data)
 
     def get_batch(self, split: str = 'val'):
         data = self.train_data if split == 'train' else self.val_data
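With this change, Batcher no longer takes a pre-read `text` string: it fetches both splits through `make_dataset`, encodes them to `torch.long` tensors, and builds its character vocabulary from the combined text (the removed lines, blanked by the diff viewer, presumably held the old input.txt loading). A sketch of how the updated class might be driven; the constructor arguments are assumptions inferred from the attributes stored in the diff:

from bigram import Batcher  # assuming Batcher is importable from bigram.py

# Hypothetical construction: the diff only shows device, batch_size and
# block_size being stored, so we assume they are the constructor parameters.
batcher = Batcher(device='cpu', batch_size=4, block_size=8)

xb, yb = batcher.get_batch('train')  # assumed (inputs, targets) pair
print(xb.shape)                      # expected: torch.Size([4, 8])
print(len(batcher.vocab))            # characters seen in train + validation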
dataset.py
ADDED
@@ -0,0 +1,19 @@
+from typing import Literal
+from datasets import load_dataset, Dataset
+
+_datasets = {
+    'train': None,
+    'validation': None,
+    'test': None,
+}
+
+# Lazy load the dataset
+
+
+def make_dataset(split: Literal['train', 'validation', 'test'] = 'train'):
+    if _datasets[split] is None:
+        ds: Dataset = load_dataset(
+            "karpathy/tiny_shakespeare", split=split, trust_remote_code=True)
+        out = str(list(ds)[0]['text'])
+        _datasets[split] = out
+    return _datasets[split]
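dataset.py memoizes each split in the module-level `_datasets` dict, so the hub is hit at most once per split and per process. The `list(ds)[0]['text']` line relies on each split of karpathy/tiny_shakespeare being a single record whose 'text' field holds the entire split. A minimal usage sketch:

from dataset import make_dataset

# The first call per split downloads via load_dataset (trust_remote_code=True
# is needed because tiny_shakespeare is a script-based dataset); every later
# call returns the string cached in the module-level _datasets dict.
train_text = make_dataset('train')
assert make_dataset('train') is train_text  # cached: the same object comes back

print(train_text[:100])  # raw Shakespeare text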
encoder.py
CHANGED
@@ -1,5 +1,6 @@
-
-
+from dataset import make_dataset
+
+text = make_dataset('train') + make_dataset('validation')
 
 chars = sorted(list(set(text)))
 stoi = {ch: i for i, ch in enumerate(chars)}
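encoder.py now derives its character vocabulary from the concatenated train and validation texts; the test split is not included, so a character appearing only in test would fall outside `stoi`. Assuming the `encode`/`decode` functions defined below the shown lines are the usual stoi/itos lookups, a round-trip sketch:

from encoder import encode, decode  # assumed stoi/itos lookups, not shown in the diff

s = "Once more unto the breach"
ids = encode(s)          # assumed: list of vocabulary indices, one per character
assert decode(ids) == s  # assumed: decode inverts encode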
input.txt
DELETED
The diff for this file is too large to render.
See raw diff
model.pth
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:0f623313ac43af74c994754923b93641afef4c026c03a09b28d0e06640875675
 size 139095034
poetry.lock
ADDED
The diff for this file is too large to render.
See raw diff
pyproject.toml
CHANGED
@@ -10,8 +10,12 @@ package-mode = false
 python = "^3.10"
 torch = "^2.3.0"
 numpy = "^1.26.4"
+datasets = "^2.19.0"
 
 
+[tool.poetry.group.dev.dependencies]
+ipykernel = "^6.29.4"
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"