shamashel commited on
Commit
2a3cf64
·
1 Parent(s): 55e1276

Retrieve dataset via the `datasets` library instead of a local text file

Browse files
Files changed (7) hide show
  1. bigram.py +7 -7
  2. dataset.py +19 -0
  3. encoder.py +3 -2
  4. input.txt +0 -0
  5. model.pth +1 -1
  6. poetry.lock +0 -0
  7. pyproject.toml +4 -0
bigram.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  import torch.nn as nn
4
  from torch.nn import functional as F
5
  import numpy as np
 
6
 
7
  from encoder import encode, decode
8
  from self_attention import Head, MultiHead
@@ -13,13 +14,12 @@ class Batcher():
13
  self.device = device
14
  self.batch_size = batch_size
15
  self.block_size = block_size
16
- with open('input.txt', 'r', encoding='utf-8') as f:
17
- text = f.read()
18
- my_tensors = torch.tensor(encode(text), dtype=torch.long)
19
- n = int(0.9*len(my_tensors))
20
- self.train_data = my_tensors[:n]
21
- self.val_data = my_tensors[n:]
22
- self.vocab = set(text)
23
 
24
  def get_batch(self, split: str = 'val'):
25
  data = self.train_data if split == 'train' else self.val_data
 
3
  import torch.nn as nn
4
  from torch.nn import functional as F
5
  import numpy as np
6
+ from datasets import load_dataset
7
 
8
  from encoder import encode, decode
9
  from self_attention import Head, MultiHead
 
14
  self.device = device
15
  self.batch_size = batch_size
16
  self.block_size = block_size
17
+ from dataset import make_dataset
18
+ train_data = make_dataset('train')
19
+ val_data = make_dataset('validation')
20
+ self.train_data = torch.tensor(encode(train_data), dtype=torch.long)
21
+ self.val_data = torch.tensor(encode(val_data), dtype=torch.long)
22
+ self.vocab = set(train_data + val_data)
 
23
 
24
  def get_batch(self, split: str = 'val'):
25
  data = self.train_data if split == 'train' else self.val_data
dataset.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import Literal
from datasets import load_dataset, Dataset

# Per-split cache so each split is downloaded/decoded at most once.
_datasets = {
    'train': None,
    'validation': None,
    'test': None,
}


def make_dataset(split: Literal['train', 'validation', 'test'] = 'train') -> str:
    """Return the raw text of the requested tiny_shakespeare split, lazily loaded.

    The first call for a split fetches it with ``datasets.load_dataset`` and
    caches the text in ``_datasets``; subsequent calls return the cached string.
    """
    if _datasets[split] is None:
        ds: Dataset = load_dataset(
            "karpathy/tiny_shakespeare", split=split, trust_remote_code=True)
        # Each split holds a single row whose 'text' field is the whole corpus;
        # index row 0 directly instead of materializing the split with list().
        _datasets[split] = str(ds[0]['text'])
    return _datasets[split]
encoder.py CHANGED
@@ -1,5 +1,6 @@
1
- with open('input.txt', 'r', encoding='utf-8') as f:
2
- text = f.read()
 
3
 
4
  chars = sorted(list(set(text)))
5
  stoi = {ch: i for i, ch in enumerate(chars)}
 
1
from dataset import make_dataset

# Build the character vocabulary from both splits so encode/decode
# cover every character that can appear in train or validation data.
text = make_dataset('train') + make_dataset('validation')

# sorted() accepts any iterable directly; the intermediate list() is redundant.
chars = sorted(set(text))
# Map each character to its index in the sorted vocabulary.
stoi = {ch: i for i, ch in enumerate(chars)}
input.txt DELETED
The diff for this file is too large to render. See raw diff
 
model.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9036021f6ede2817a9b030d5dd605d38cf89ee627129bc791f5e9cc0b948aae1
3
  size 139095034
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f623313ac43af74c994754923b93641afef4c026c03a09b28d0e06640875675
3
  size 139095034
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -10,8 +10,12 @@ package-mode = false
10
  python = "^3.10"
11
  torch = "^2.3.0"
12
  numpy = "^1.26.4"
 
13
 
14
 
 
 
 
15
  [build-system]
16
  requires = ["poetry-core"]
17
  build-backend = "poetry.core.masonry.api"
 
10
  python = "^3.10"
11
  torch = "^2.3.0"
12
  numpy = "^1.26.4"
13
+ datasets = "^2.19.0"
14
 
15
 
16
+ [tool.poetry.group.dev.dependencies]
17
+ ipykernel = "^6.29.4"
18
+
19
  [build-system]
20
  requires = ["poetry-core"]
21
  build-backend = "poetry.core.masonry.api"