AbinayaM02 commited on
Commit
900235d
·
1 Parent(s): bc5106b

GPT2 - Oscar + Indic Corpus model

Browse files
Files changed (50) hide show
  1. .flake8 +8 -0
  2. .gitattributes +2 -1
  3. .gitignore +129 -0
  4. .pre-commit-config.yaml +28 -0
  5. LICENSE +21 -0
  6. README.md +13 -0
  7. dataset/README.md +1 -0
  8. demo/README.md +1 -0
  9. demo/tamil_generator.py +31 -0
  10. gpt-2-tamil/config.json +36 -0
  11. gpt-2-tamil/events.out.tfevents.1626336540.t1v-n-ebe36c53-w-0.751183.3.v2 +3 -0
  12. gpt-2-tamil/events.out.tfevents.1626339585.t1v-n-ebe36c53-w-0.759145.3.v2 +3 -0
  13. gpt-2-tamil/events.out.tfevents.1626340740.t1v-n-ebe36c53-w-0.765413.3.v2 +3 -0
  14. gpt-2-tamil/events.out.tfevents.1626341319.t1v-n-ebe36c53-w-0.768105.3.v2 +3 -0
  15. gpt-2-tamil/flax_model.msgpack +3 -0
  16. gpt-2-tamil/tokenizer.json +0 -0
  17. model/README.md +1 -0
  18. notebook/README.md +1 -0
  19. pyproject.toml +31 -0
  20. requirements.txt +8 -0
  21. scripts/train_gpt2-oscar-tamil.sh +25 -0
  22. scripts/wandb/latest-run +1 -0
  23. scripts/wandb/run-20210712_164633-1ddv4131/run-1ddv4131.wandb +3 -0
  24. scripts/wandb/run-20210715_080856-2mpx5n1j/files/config.yaml +305 -0
  25. scripts/wandb/run-20210715_080856-2mpx5n1j/files/events.out.tfevents.1626336540.t1v-n-ebe36c53-w-0.751183.3.v2 +1 -0
  26. scripts/wandb/run-20210715_080856-2mpx5n1j/files/requirements.txt +123 -0
  27. scripts/wandb/run-20210715_080856-2mpx5n1j/files/wandb-metadata.json +49 -0
  28. scripts/wandb/run-20210715_080856-2mpx5n1j/files/wandb-summary.json +1 -0
  29. scripts/wandb/run-20210715_080856-2mpx5n1j/run-2mpx5n1j.wandb +3 -0
  30. scripts/wandb/run-20210715_085943-1ize2alk/files/config.yaml +301 -0
  31. scripts/wandb/run-20210715_085943-1ize2alk/files/events.out.tfevents.1626339585.t1v-n-ebe36c53-w-0.759145.3.v2 +1 -0
  32. scripts/wandb/run-20210715_085943-1ize2alk/files/requirements.txt +123 -0
  33. scripts/wandb/run-20210715_085943-1ize2alk/files/wandb-metadata.json +49 -0
  34. scripts/wandb/run-20210715_085943-1ize2alk/files/wandb-summary.json +1 -0
  35. scripts/wandb/run-20210715_085943-1ize2alk/run-1ize2alk.wandb +3 -0
  36. scripts/wandb/run-20210715_091856-2v0tf7h4/files/config.yaml +305 -0
  37. scripts/wandb/run-20210715_091856-2v0tf7h4/files/events.out.tfevents.1626340740.t1v-n-ebe36c53-w-0.765413.3.v2 +1 -0
  38. scripts/wandb/run-20210715_091856-2v0tf7h4/files/requirements.txt +123 -0
  39. scripts/wandb/run-20210715_091856-2v0tf7h4/files/wandb-metadata.json +49 -0
  40. scripts/wandb/run-20210715_091856-2v0tf7h4/files/wandb-summary.json +1 -0
  41. scripts/wandb/run-20210715_091856-2v0tf7h4/run-2v0tf7h4.wandb +3 -0
  42. scripts/wandb/run-20210715_092837-watdq7ib/files/config.yaml +301 -0
  43. scripts/wandb/run-20210715_092837-watdq7ib/files/events.out.tfevents.1626341319.t1v-n-ebe36c53-w-0.768105.3.v2 +1 -0
  44. scripts/wandb/run-20210715_092837-watdq7ib/files/requirements.txt +123 -0
  45. scripts/wandb/run-20210715_092837-watdq7ib/files/wandb-metadata.json +49 -0
  46. scripts/wandb/run-20210715_092837-watdq7ib/files/wandb-summary.json +1 -0
  47. scripts/wandb/run-20210715_092837-watdq7ib/run-watdq7ib.wandb +3 -0
  48. src/create_config.py +8 -0
  49. src/run_clm_flax.py +661 -0
  50. src/train_tokenizer.py +40 -0
.flake8 ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ [flake8]
2
+ exclude = venv
3
+ ignore = E501, W503, E226, E203
4
+ max-line-length = 85
5
+
6
+ # E501: Line too long
7
+ # W503: Line break occurred before binary operator
8
+ # E226: Missing white space around arithmetic operator
.gitattributes CHANGED
@@ -12,6 +12,7 @@
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
  *.pb filter=lfs diff=lfs merge=lfs -text
15
- *.pt filter=lfs diff=lfs merge=lfs -text
 
16
  *.pth filter=lfs diff=lfs merge=lfs -text
17
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
  *.pb filter=lfs diff=lfs merge=lfs -text
15
+ *.log filter=lfs diff=lfs merge=lfs -text
16
+ *.wandb filter=lfs diff=lfs merge=lfs -text
17
  *.pth filter=lfs diff=lfs merge=lfs -text
18
  *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # See https://pre-commit.com for more information
2
+ # See https://pre-commit.com/hooks.html for more hooks
3
+ repos:
4
+ - repo: https://github.com/pre-commit/pre-commit-hooks
5
+ rev: v3.4.0
6
+ hooks:
7
+ - id: trailing-whitespace
8
+ - id: check-yaml
9
+ - id: check-ast
10
+ - id: check-json
11
+ - id: check-merge-conflict
12
+ - id: detect-private-key
13
+ - repo: https://github.com/psf/black
14
+ rev: 21.6b0
15
+ hooks:
16
+ - id: black
17
+ args: []
18
+ files: .
19
+ - repo: https://gitlab.com/PyCQA/flake8
20
+ rev: 3.9.2
21
+ hooks:
22
+ - id: flake8
23
+ - repo: https://github.com/PyCQA/isort
24
+ rev: 5.9.1
25
+ hooks:
26
+ - id: isort
27
+ args: []
28
+ files: .
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2021 Abinaya Mahendiran
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GPT2-Tamil
2
+
3
+ This repository is created as part of the Flax/Jax community week by Huggingface. The aim of this project is to pre-train a language model using GPT-2 specifically for Tamil language.
4
+
5
+ ## Setup [Todo]:
6
+
7
+ ## Dataset Used [Todo]:
8
+
9
+ ## Preprocess Data [Todo]:
10
+
11
+ ## Train (Flax) [Todo]:
12
+
13
+ ## Demo [Todo]:
dataset/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ Details of the dataset can go here. The folder can also contain dataset (downloaded locally).
demo/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ Streamlit demo can go here.
demo/tamil_generator.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import locale
3
+ print(locale.getpreferredencoding())
4
+
5
+
6
+ from transformers import AutoConfig, AutoModelForCausalLM,pipeline,AutoTokenizer
7
+ from datasets import load_dataset
8
+
9
+ MODEL_DIR = "/home/deepak/sources/gpt2-tamil/gpt2-tamil/"
10
+
11
+
12
+
13
+
14
+
15
+ #get prompt from dataset, will be replaced by manual prompt once I figure out how to render tamil font
16
+ dataset = load_dataset("oscar", "unshuffled_deduplicated_ta", split="train")
17
+ id =232
18
+ print(dataset[id]['text'])
19
+ tamil_prompt =dataset[id]['text']
20
+
21
+ # Get configuration and the model
22
+ config = AutoConfig.from_pretrained(MODEL_DIR)
23
+ model = AutoModelForCausalLM.from_config(config)
24
+
25
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
26
+
27
+
28
+ generator= pipeline('text-generation', model=model, tokenizer=tokenizer)
29
+ model_output = generator(tamil_prompt, max_length=30, num_return_sequences=5)
30
+ print(model_output)
31
+
gpt-2-tamil/config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_function": "gelu_new",
3
+ "architectures": [
4
+ "GPT2LMHeadModel"
5
+ ],
6
+ "attn_pdrop": 0.0,
7
+ "bos_token_id": 50256,
8
+ "embd_pdrop": 0.0,
9
+ "eos_token_id": 50256,
10
+ "gradient_checkpointing": false,
11
+ "initializer_range": 0.02,
12
+ "layer_norm_epsilon": 1e-05,
13
+ "model_type": "gpt2",
14
+ "n_ctx": 1024,
15
+ "n_embd": 768,
16
+ "n_head": 12,
17
+ "n_inner": null,
18
+ "n_layer": 12,
19
+ "n_positions": 1024,
20
+ "resid_pdrop": 0.0,
21
+ "scale_attn_weights": true,
22
+ "summary_activation": null,
23
+ "summary_first_dropout": 0.1,
24
+ "summary_proj_to_labels": true,
25
+ "summary_type": "cls_index",
26
+ "summary_use_proj": true,
27
+ "task_specific_params": {
28
+ "text-generation": {
29
+ "do_sample": true,
30
+ "max_length": 50
31
+ }
32
+ },
33
+ "transformers_version": "4.9.0.dev0",
34
+ "use_cache": true,
35
+ "vocab_size": 50257
36
+ }
gpt-2-tamil/events.out.tfevents.1626336540.t1v-n-ebe36c53-w-0.751183.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1799847ce42c1a5f9fe25dfa8d8da9e1a6ff57595979b2bd0daea658d9ea785
3
+ size 40
gpt-2-tamil/events.out.tfevents.1626339585.t1v-n-ebe36c53-w-0.759145.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b47918f07e65192c48181c8f775cbf29f08585ac3a559e67df1e3f13fb1ca01
3
+ size 40
gpt-2-tamil/events.out.tfevents.1626340740.t1v-n-ebe36c53-w-0.765413.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5855b0a71977e29453739fe2c5055c32753a62fa6d3db8ea3f105fd8ca75357b
3
+ size 40
gpt-2-tamil/events.out.tfevents.1626341319.t1v-n-ebe36c53-w-0.768105.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:938ebc19608236e36e53fd65f7c12c9d7ad0de447d01d60627441645872ef573
3
+ size 22272043
gpt-2-tamil/flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:88b3f8ffc1e0cdd50358b8110910421ef1594f0559eea806e38bb95b186e0e03
3
+ size 497764120
gpt-2-tamil/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
model/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ Model card details can go here.
notebook/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ Notebook can go here.
pyproject.toml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Black formatting
2
+ [tool.black]
3
+ line-length = 85
4
+ include = '\.pyi?$'
5
+ exclude = '''
6
+ /(
7
+ \.eggs # exclude a few common directories in the
8
+ | \.git # root of the project
9
+ | \.hg
10
+ | \.mypy_cache
11
+ | \.tox
12
+ | \.venv
13
+ | _build
14
+ | buck-out
15
+ | build
16
+ | dist
17
+ | wandb
18
+ | model
19
+ | dataset
20
+ | notebook
21
+ )/
22
+ '''
23
+
24
+ # iSort
25
+ [tool.isort]
26
+ profile = "black"
27
+ line_length = 85
28
+ multi_line_output = 3
29
+ include_trailing_comma = true
30
+ skip_gitignore = true
31
+ virtual_env = "venv"
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ tqdm
2
+ transformers
3
+ datasets
4
+ jax
5
+ jaxlib
6
+ flax
7
+ optax
8
+ wandb
scripts/train_gpt2-oscar-tamil.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ python ../src/run_clm_flax.py \
3
+ --output_dir="${MODEL_DIR}" \
4
+ --model_type="gpt2" \
5
+ --config_name="${MODEL_DIR}" \
6
+ --tokenizer_name="${MODEL_DIR}" \
7
+ --dataset_name="oscar" \
8
+ --dataset_config_name="unshuffled_deduplicated_ta" \
9
+ --do_train --do_eval \
10
+ --block_size="512" \
11
+ --per_device_train_batch_size="64" \
12
+ --per_device_eval_batch_size="64" \
13
+ --learning_rate="3e-5" \
14
+ --warmup_steps="1000" \
15
+ --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
16
+ --overwrite_output_dir \
17
+ --num_train_epochs="10" \
18
+ --report_to wandb \
19
+ --run_name trial \
20
+ --logging_steps="500" \
21
+ --save_steps="2500" \
22
+ --eval_steps="2500" \
23
+ --preprocessing_num_workers="90" \
24
+ #--push_to_hub
25
+ 2>&1 | tee run.log
scripts/wandb/latest-run ADDED
@@ -0,0 +1 @@
 
 
1
+ run-20210715_092837-watdq7ib
scripts/wandb/run-20210712_164633-1ddv4131/run-1ddv4131.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d8211487b4d0a0489ae4728120abad1be7ee4190520afc47fdae166087ae6068
3
+ size 60817322
scripts/wandb/run-20210715_080856-2mpx5n1j/files/config.yaml ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_version: 1
2
+
3
+ __cached__setup_devices:
4
+ desc: null
5
+ value: cpu
6
+ _n_gpu:
7
+ desc: null
8
+ value: 0
9
+ _wandb:
10
+ desc: null
11
+ value:
12
+ cli_version: 0.10.33
13
+ framework: huggingface
14
+ huggingface_version: 4.9.0.dev0
15
+ is_jupyter_run: false
16
+ is_kaggle_kernel: false
17
+ python_version: 3.8.10
18
+ t:
19
+ 1:
20
+ - 1
21
+ - 3
22
+ - 11
23
+ 2:
24
+ - 1
25
+ - 3
26
+ - 11
27
+ 4: 3.8.10
28
+ 5: 0.10.33
29
+ 6: 4.9.0.dev0
30
+ 8:
31
+ - 5
32
+ adafactor:
33
+ desc: null
34
+ value: false
35
+ adam_beta1:
36
+ desc: null
37
+ value: 0.9
38
+ adam_beta2:
39
+ desc: null
40
+ value: 0.98
41
+ adam_epsilon:
42
+ desc: null
43
+ value: 1.0e-08
44
+ block_size:
45
+ desc: null
46
+ value: 512
47
+ cache_dir:
48
+ desc: null
49
+ value: null
50
+ config_name:
51
+ desc: null
52
+ value: ../gpt-2-tamil
53
+ dataloader_drop_last:
54
+ desc: null
55
+ value: false
56
+ dataloader_num_workers:
57
+ desc: null
58
+ value: 0
59
+ dataloader_pin_memory:
60
+ desc: null
61
+ value: true
62
+ dataset_config_name:
63
+ desc: null
64
+ value: unshuffled_deduplicated_ta
65
+ dataset_name:
66
+ desc: null
67
+ value: oscar
68
+ ddp_find_unused_parameters:
69
+ desc: null
70
+ value: null
71
+ debug:
72
+ desc: null
73
+ value: []
74
+ deepspeed:
75
+ desc: null
76
+ value: null
77
+ disable_tqdm:
78
+ desc: null
79
+ value: false
80
+ do_eval:
81
+ desc: null
82
+ value: true
83
+ do_predict:
84
+ desc: null
85
+ value: false
86
+ do_train:
87
+ desc: null
88
+ value: true
89
+ dtype:
90
+ desc: null
91
+ value: float32
92
+ eval_accumulation_steps:
93
+ desc: null
94
+ value: null
95
+ eval_steps:
96
+ desc: null
97
+ value: 2500
98
+ evaluation_strategy:
99
+ desc: null
100
+ value: IntervalStrategy.NO
101
+ fp16:
102
+ desc: null
103
+ value: false
104
+ fp16_backend:
105
+ desc: null
106
+ value: auto
107
+ fp16_full_eval:
108
+ desc: null
109
+ value: false
110
+ fp16_opt_level:
111
+ desc: null
112
+ value: O1
113
+ gradient_accumulation_steps:
114
+ desc: null
115
+ value: 1
116
+ greater_is_better:
117
+ desc: null
118
+ value: null
119
+ group_by_length:
120
+ desc: null
121
+ value: false
122
+ ignore_data_skip:
123
+ desc: null
124
+ value: false
125
+ label_names:
126
+ desc: null
127
+ value: null
128
+ label_smoothing_factor:
129
+ desc: null
130
+ value: 0.0
131
+ learning_rate:
132
+ desc: null
133
+ value: 3.0e-05
134
+ length_column_name:
135
+ desc: null
136
+ value: length
137
+ load_best_model_at_end:
138
+ desc: null
139
+ value: false
140
+ local_rank:
141
+ desc: null
142
+ value: -1
143
+ log_level:
144
+ desc: null
145
+ value: -1
146
+ log_level_replica:
147
+ desc: null
148
+ value: -1
149
+ log_on_each_node:
150
+ desc: null
151
+ value: true
152
+ logging_dir:
153
+ desc: null
154
+ value: ../gpt-2-tamil/runs/Jul15_06-31-48_t1v-n-ebe36c53-w-0
155
+ logging_first_step:
156
+ desc: null
157
+ value: false
158
+ logging_steps:
159
+ desc: null
160
+ value: 500
161
+ logging_strategy:
162
+ desc: null
163
+ value: IntervalStrategy.STEPS
164
+ lr_scheduler_type:
165
+ desc: null
166
+ value: SchedulerType.LINEAR
167
+ max_eval_samples:
168
+ desc: null
169
+ value: null
170
+ max_grad_norm:
171
+ desc: null
172
+ value: 1.0
173
+ max_steps:
174
+ desc: null
175
+ value: -1
176
+ max_train_samples:
177
+ desc: null
178
+ value: null
179
+ metric_for_best_model:
180
+ desc: null
181
+ value: null
182
+ model_name_or_path:
183
+ desc: null
184
+ value: null
185
+ model_type:
186
+ desc: null
187
+ value: gpt2
188
+ mp_parameters:
189
+ desc: null
190
+ value: ''
191
+ no_cuda:
192
+ desc: null
193
+ value: false
194
+ num_train_epochs:
195
+ desc: null
196
+ value: 10.0
197
+ output_dir:
198
+ desc: null
199
+ value: ../gpt-2-tamil
200
+ overwrite_cache:
201
+ desc: null
202
+ value: false
203
+ overwrite_output_dir:
204
+ desc: null
205
+ value: true
206
+ past_index:
207
+ desc: null
208
+ value: -1
209
+ per_device_eval_batch_size:
210
+ desc: null
211
+ value: 128
212
+ per_device_train_batch_size:
213
+ desc: null
214
+ value: 128
215
+ per_gpu_eval_batch_size:
216
+ desc: null
217
+ value: null
218
+ per_gpu_train_batch_size:
219
+ desc: null
220
+ value: null
221
+ prediction_loss_only:
222
+ desc: null
223
+ value: false
224
+ preprocessing_num_workers:
225
+ desc: null
226
+ value: 90
227
+ push_to_hub:
228
+ desc: null
229
+ value: false
230
+ push_to_hub_model_id:
231
+ desc: null
232
+ value: gpt-2-tamil
233
+ push_to_hub_organization:
234
+ desc: null
235
+ value: null
236
+ push_to_hub_token:
237
+ desc: null
238
+ value: null
239
+ remove_unused_columns:
240
+ desc: null
241
+ value: true
242
+ report_to:
243
+ desc: null
244
+ value:
245
+ - wandb
246
+ resume_from_checkpoint:
247
+ desc: null
248
+ value: null
249
+ run_name:
250
+ desc: null
251
+ value: trial
252
+ save_on_each_node:
253
+ desc: null
254
+ value: false
255
+ save_steps:
256
+ desc: null
257
+ value: 2500
258
+ save_strategy:
259
+ desc: null
260
+ value: IntervalStrategy.STEPS
261
+ save_total_limit:
262
+ desc: null
263
+ value: null
264
+ seed:
265
+ desc: null
266
+ value: 42
267
+ sharded_ddp:
268
+ desc: null
269
+ value: []
270
+ skip_memory_metrics:
271
+ desc: null
272
+ value: true
273
+ tokenizer_name:
274
+ desc: null
275
+ value: ../gpt-2-tamil
276
+ tpu_metrics_debug:
277
+ desc: null
278
+ value: false
279
+ tpu_num_cores:
280
+ desc: null
281
+ value: null
282
+ train_file:
283
+ desc: null
284
+ value: null
285
+ use_fast_tokenizer:
286
+ desc: null
287
+ value: true
288
+ use_legacy_prediction_loop:
289
+ desc: null
290
+ value: false
291
+ validation_file:
292
+ desc: null
293
+ value: null
294
+ validation_split_percentage:
295
+ desc: null
296
+ value: 5
297
+ warmup_ratio:
298
+ desc: null
299
+ value: 0.0
300
+ warmup_steps:
301
+ desc: null
302
+ value: 1000
303
+ weight_decay:
304
+ desc: null
305
+ value: 0.01
scripts/wandb/run-20210715_080856-2mpx5n1j/files/events.out.tfevents.1626336540.t1v-n-ebe36c53-w-0.751183.3.v2 ADDED
@@ -0,0 +1 @@
 
 
1
+ /home/tweety_abi/GPT2-Tamil/gpt-2-tamil/events.out.tfevents.1626336540.t1v-n-ebe36c53-w-0.751183.3.v2
scripts/wandb/run-20210715_080856-2mpx5n1j/files/requirements.txt ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==0.13.0
2
+ aiohttp==3.7.4.post0
3
+ appdirs==1.4.4
4
+ astunparse==1.6.3
5
+ async-timeout==3.0.1
6
+ attrs==21.2.0
7
+ backcall==0.2.0
8
+ black==21.6b0
9
+ cachetools==4.2.2
10
+ certifi==2021.5.30
11
+ cfgv==3.3.0
12
+ chardet==4.0.0
13
+ chex==0.0.7
14
+ click==8.0.1
15
+ configparser==5.0.2
16
+ cycler==0.10.0
17
+ datasets==1.8.1.dev0
18
+ decorator==5.0.9
19
+ dill==0.3.4
20
+ distlib==0.3.2
21
+ dm-tree==0.1.6
22
+ docker-pycreds==0.4.0
23
+ filelock==3.0.12
24
+ flake8==3.9.2
25
+ flatbuffers==1.12
26
+ flax==0.3.4
27
+ fsspec==2021.6.1
28
+ gast==0.4.0
29
+ gitdb==4.0.7
30
+ gitpython==3.1.18
31
+ google-auth-oauthlib==0.4.4
32
+ google-auth==1.32.1
33
+ google-pasta==0.2.0
34
+ grpcio==1.34.1
35
+ h5py==3.1.0
36
+ huggingface-hub==0.0.12
37
+ identify==2.2.10
38
+ idna==2.10
39
+ ipython-genutils==0.2.0
40
+ ipython==7.25.0
41
+ isort==5.9.1
42
+ jax==0.2.16
43
+ jaxlib==0.1.68
44
+ jedi==0.18.0
45
+ joblib==1.0.1
46
+ keras-nightly==2.5.0.dev2021032900
47
+ keras-preprocessing==1.1.2
48
+ kiwisolver==1.3.1
49
+ libtpu-nightly==0.1.dev20210615
50
+ markdown==3.3.4
51
+ matplotlib-inline==0.1.2
52
+ matplotlib==3.4.2
53
+ mccabe==0.6.1
54
+ msgpack==1.0.2
55
+ multidict==5.1.0
56
+ multiprocess==0.70.12.2
57
+ mypy-extensions==0.4.3
58
+ nodeenv==1.6.0
59
+ numpy==1.19.5
60
+ oauthlib==3.1.1
61
+ opt-einsum==3.3.0
62
+ optax==0.0.8
63
+ packaging==20.9
64
+ pandas==1.2.5
65
+ parso==0.8.2
66
+ pathspec==0.8.1
67
+ pathtools==0.1.2
68
+ pexpect==4.8.0
69
+ pickleshare==0.7.5
70
+ pillow==8.3.0
71
+ pip==20.0.2
72
+ pkg-resources==0.0.0
73
+ pre-commit==2.13.0
74
+ promise==2.3
75
+ prompt-toolkit==3.0.19
76
+ protobuf==3.17.3
77
+ psutil==5.8.0
78
+ ptyprocess==0.7.0
79
+ pyarrow==4.0.1
80
+ pyasn1-modules==0.2.8
81
+ pyasn1==0.4.8
82
+ pycodestyle==2.7.0
83
+ pyflakes==2.3.1
84
+ pygments==2.9.0
85
+ pyparsing==2.4.7
86
+ python-dateutil==2.8.1
87
+ pytz==2021.1
88
+ pyyaml==5.4.1
89
+ regex==2021.7.1
90
+ requests-oauthlib==1.3.0
91
+ requests==2.25.1
92
+ rsa==4.7.2
93
+ sacremoses==0.0.45
94
+ scipy==1.7.0
95
+ sentry-sdk==1.3.0
96
+ setuptools==44.0.0
97
+ shortuuid==1.0.1
98
+ six==1.15.0
99
+ smmap==4.0.0
100
+ subprocess32==3.5.4
101
+ tensorboard-data-server==0.6.1
102
+ tensorboard-plugin-wit==1.8.0
103
+ tensorboard==2.5.0
104
+ tensorflow-estimator==2.5.0
105
+ tensorflow==2.5.0
106
+ termcolor==1.1.0
107
+ tokenizers==0.10.3
108
+ toml==0.10.2
109
+ toolz==0.11.1
110
+ torch==1.9.0
111
+ tqdm==4.61.1
112
+ traitlets==5.0.5
113
+ transformers==4.9.0.dev0
114
+ typing-extensions==3.7.4.3
115
+ urllib3==1.26.6
116
+ virtualenv==20.4.7
117
+ wandb==0.10.33
118
+ wcwidth==0.2.5
119
+ werkzeug==2.0.1
120
+ wheel==0.36.2
121
+ wrapt==1.12.1
122
+ xxhash==2.0.2
123
+ yarl==1.6.3
scripts/wandb/run-20210715_080856-2mpx5n1j/files/wandb-metadata.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2021-07-15T08:09:00.134255",
5
+ "startedAt": "2021-07-15T08:08:56.269238",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--output_dir=../gpt-2-tamil",
11
+ "--model_type=gpt2",
12
+ "--config_name=../gpt-2-tamil",
13
+ "--tokenizer_name=../gpt-2-tamil",
14
+ "--dataset_name=oscar",
15
+ "--dataset_config_name=unshuffled_deduplicated_ta",
16
+ "--do_train",
17
+ "--do_eval",
18
+ "--block_size=512",
19
+ "--per_device_train_batch_size=128",
20
+ "--per_device_eval_batch_size=128",
21
+ "--learning_rate=3e-5",
22
+ "--warmup_steps=1000",
23
+ "--adam_beta1=0.9",
24
+ "--adam_beta2=0.98",
25
+ "--weight_decay=0.01",
26
+ "--overwrite_output_dir",
27
+ "--num_train_epochs=10",
28
+ "--report_to",
29
+ "wandb",
30
+ "--run_name",
31
+ "trial",
32
+ "--logging_steps=500",
33
+ "--save_steps=2500",
34
+ "--eval_steps=2500",
35
+ "--preprocessing_num_workers=90"
36
+ ],
37
+ "state": "running",
38
+ "program": "../src/run_clm_flax.py",
39
+ "codePath": "src/run_clm_flax.py",
40
+ "git": {
41
+ "remote": "https://github.com/AbinayaM02/GPT2-Tamil.git",
42
+ "commit": "69c9b7bf75b708a8f62cf5833d1b89acf5d1760b"
43
+ },
44
+ "email": "[email protected]",
45
+ "root": "/home/tweety_abi/GPT2-Tamil",
46
+ "host": "t1v-n-ebe36c53-w-0",
47
+ "username": "tweety_abi",
48
+ "executable": "/home/tweety_abi/gpt2_env/bin/python"
49
+ }
scripts/wandb/run-20210715_080856-2mpx5n1j/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
scripts/wandb/run-20210715_080856-2mpx5n1j/run-2mpx5n1j.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad2816f7f07dec6835ab15fdfb6fa81ca124f1b3f1dfbaccb9b2f3658286d158
3
+ size 38211
scripts/wandb/run-20210715_085943-1ize2alk/files/config.yaml ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_version: 1
2
+
3
+ __cached__setup_devices:
4
+ desc: null
5
+ value: cpu
6
+ _n_gpu:
7
+ desc: null
8
+ value: 0
9
+ _wandb:
10
+ desc: null
11
+ value:
12
+ cli_version: 0.10.33
13
+ framework: huggingface
14
+ huggingface_version: 4.9.0.dev0
15
+ is_jupyter_run: false
16
+ is_kaggle_kernel: false
17
+ python_version: 3.8.10
18
+ t:
19
+ 1:
20
+ - 1
21
+ - 3
22
+ - 11
23
+ 4: 3.8.10
24
+ 5: 0.10.33
25
+ 6: 4.9.0.dev0
26
+ 8:
27
+ - 5
28
+ adafactor:
29
+ desc: null
30
+ value: false
31
+ adam_beta1:
32
+ desc: null
33
+ value: 0.9
34
+ adam_beta2:
35
+ desc: null
36
+ value: 0.98
37
+ adam_epsilon:
38
+ desc: null
39
+ value: 1.0e-08
40
+ block_size:
41
+ desc: null
42
+ value: 512
43
+ cache_dir:
44
+ desc: null
45
+ value: null
46
+ config_name:
47
+ desc: null
48
+ value: ../gpt-2-tamil
49
+ dataloader_drop_last:
50
+ desc: null
51
+ value: false
52
+ dataloader_num_workers:
53
+ desc: null
54
+ value: 0
55
+ dataloader_pin_memory:
56
+ desc: null
57
+ value: true
58
+ dataset_config_name:
59
+ desc: null
60
+ value: unshuffled_deduplicated_ta
61
+ dataset_name:
62
+ desc: null
63
+ value: oscar
64
+ ddp_find_unused_parameters:
65
+ desc: null
66
+ value: null
67
+ debug:
68
+ desc: null
69
+ value: []
70
+ deepspeed:
71
+ desc: null
72
+ value: null
73
+ disable_tqdm:
74
+ desc: null
75
+ value: false
76
+ do_eval:
77
+ desc: null
78
+ value: true
79
+ do_predict:
80
+ desc: null
81
+ value: false
82
+ do_train:
83
+ desc: null
84
+ value: true
85
+ dtype:
86
+ desc: null
87
+ value: float32
88
+ eval_accumulation_steps:
89
+ desc: null
90
+ value: null
91
+ eval_steps:
92
+ desc: null
93
+ value: 2500
94
+ evaluation_strategy:
95
+ desc: null
96
+ value: IntervalStrategy.NO
97
+ fp16:
98
+ desc: null
99
+ value: false
100
+ fp16_backend:
101
+ desc: null
102
+ value: auto
103
+ fp16_full_eval:
104
+ desc: null
105
+ value: false
106
+ fp16_opt_level:
107
+ desc: null
108
+ value: O1
109
+ gradient_accumulation_steps:
110
+ desc: null
111
+ value: 1
112
+ greater_is_better:
113
+ desc: null
114
+ value: null
115
+ group_by_length:
116
+ desc: null
117
+ value: false
118
+ ignore_data_skip:
119
+ desc: null
120
+ value: false
121
+ label_names:
122
+ desc: null
123
+ value: null
124
+ label_smoothing_factor:
125
+ desc: null
126
+ value: 0.0
127
+ learning_rate:
128
+ desc: null
129
+ value: 3.0e-05
130
+ length_column_name:
131
+ desc: null
132
+ value: length
133
+ load_best_model_at_end:
134
+ desc: null
135
+ value: false
136
+ local_rank:
137
+ desc: null
138
+ value: -1
139
+ log_level:
140
+ desc: null
141
+ value: -1
142
+ log_level_replica:
143
+ desc: null
144
+ value: -1
145
+ log_on_each_node:
146
+ desc: null
147
+ value: true
148
+ logging_dir:
149
+ desc: null
150
+ value: ../gpt-2-tamil/runs/Jul15_07-55-49_t1v-n-ebe36c53-w-0
151
+ logging_first_step:
152
+ desc: null
153
+ value: false
154
+ logging_steps:
155
+ desc: null
156
+ value: 500
157
+ logging_strategy:
158
+ desc: null
159
+ value: IntervalStrategy.STEPS
160
+ lr_scheduler_type:
161
+ desc: null
162
+ value: SchedulerType.LINEAR
163
+ max_eval_samples:
164
+ desc: null
165
+ value: null
166
+ max_grad_norm:
167
+ desc: null
168
+ value: 1.0
169
+ max_steps:
170
+ desc: null
171
+ value: -1
172
+ max_train_samples:
173
+ desc: null
174
+ value: null
175
+ metric_for_best_model:
176
+ desc: null
177
+ value: null
178
+ model_name_or_path:
179
+ desc: null
180
+ value: null
181
+ model_type:
182
+ desc: null
183
+ value: gpt2
184
+ mp_parameters:
185
+ desc: null
186
+ value: ''
187
+ no_cuda:
188
+ desc: null
189
+ value: false
190
+ num_train_epochs:
191
+ desc: null
192
+ value: 10.0
193
+ output_dir:
194
+ desc: null
195
+ value: ../gpt-2-tamil
196
+ overwrite_cache:
197
+ desc: null
198
+ value: false
199
+ overwrite_output_dir:
200
+ desc: null
201
+ value: true
202
+ past_index:
203
+ desc: null
204
+ value: -1
205
+ per_device_eval_batch_size:
206
+ desc: null
207
+ value: 128
208
+ per_device_train_batch_size:
209
+ desc: null
210
+ value: 128
211
+ per_gpu_eval_batch_size:
212
+ desc: null
213
+ value: null
214
+ per_gpu_train_batch_size:
215
+ desc: null
216
+ value: null
217
+ prediction_loss_only:
218
+ desc: null
219
+ value: false
220
+ preprocessing_num_workers:
221
+ desc: null
222
+ value: 90
223
+ push_to_hub:
224
+ desc: null
225
+ value: false
226
+ push_to_hub_model_id:
227
+ desc: null
228
+ value: gpt-2-tamil
229
+ push_to_hub_organization:
230
+ desc: null
231
+ value: null
232
+ push_to_hub_token:
233
+ desc: null
234
+ value: null
235
+ remove_unused_columns:
236
+ desc: null
237
+ value: true
238
+ report_to:
239
+ desc: null
240
+ value:
241
+ - wandb
242
+ resume_from_checkpoint:
243
+ desc: null
244
+ value: null
245
+ run_name:
246
+ desc: null
247
+ value: trial
248
+ save_on_each_node:
249
+ desc: null
250
+ value: false
251
+ save_steps:
252
+ desc: null
253
+ value: 2500
254
+ save_strategy:
255
+ desc: null
256
+ value: IntervalStrategy.STEPS
257
+ save_total_limit:
258
+ desc: null
259
+ value: null
260
+ seed:
261
+ desc: null
262
+ value: 42
263
+ sharded_ddp:
264
+ desc: null
265
+ value: []
266
+ skip_memory_metrics:
267
+ desc: null
268
+ value: true
269
+ tokenizer_name:
270
+ desc: null
271
+ value: ../gpt-2-tamil
272
+ tpu_metrics_debug:
273
+ desc: null
274
+ value: false
275
+ tpu_num_cores:
276
+ desc: null
277
+ value: null
278
+ train_file:
279
+ desc: null
280
+ value: null
281
+ use_fast_tokenizer:
282
+ desc: null
283
+ value: true
284
+ use_legacy_prediction_loop:
285
+ desc: null
286
+ value: false
287
+ validation_file:
288
+ desc: null
289
+ value: null
290
+ validation_split_percentage:
291
+ desc: null
292
+ value: 5
293
+ warmup_ratio:
294
+ desc: null
295
+ value: 0.0
296
+ warmup_steps:
297
+ desc: null
298
+ value: 1000
299
+ weight_decay:
300
+ desc: null
301
+ value: 0.01
scripts/wandb/run-20210715_085943-1ize2alk/files/events.out.tfevents.1626339585.t1v-n-ebe36c53-w-0.759145.3.v2 ADDED
@@ -0,0 +1 @@
 
 
1
+ /home/tweety_abi/GPT2-Tamil/gpt-2-tamil/events.out.tfevents.1626339585.t1v-n-ebe36c53-w-0.759145.3.v2
scripts/wandb/run-20210715_085943-1ize2alk/files/requirements.txt ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==0.13.0
2
+ aiohttp==3.7.4.post0
3
+ appdirs==1.4.4
4
+ astunparse==1.6.3
5
+ async-timeout==3.0.1
6
+ attrs==21.2.0
7
+ backcall==0.2.0
8
+ black==21.6b0
9
+ cachetools==4.2.2
10
+ certifi==2021.5.30
11
+ cfgv==3.3.0
12
+ chardet==4.0.0
13
+ chex==0.0.7
14
+ click==8.0.1
15
+ configparser==5.0.2
16
+ cycler==0.10.0
17
+ datasets==1.8.1.dev0
18
+ decorator==5.0.9
19
+ dill==0.3.4
20
+ distlib==0.3.2
21
+ dm-tree==0.1.6
22
+ docker-pycreds==0.4.0
23
+ filelock==3.0.12
24
+ flake8==3.9.2
25
+ flatbuffers==1.12
26
+ flax==0.3.4
27
+ fsspec==2021.6.1
28
+ gast==0.4.0
29
+ gitdb==4.0.7
30
+ gitpython==3.1.18
31
+ google-auth-oauthlib==0.4.4
32
+ google-auth==1.32.1
33
+ google-pasta==0.2.0
34
+ grpcio==1.34.1
35
+ h5py==3.1.0
36
+ huggingface-hub==0.0.12
37
+ identify==2.2.10
38
+ idna==2.10
39
+ ipython-genutils==0.2.0
40
+ ipython==7.25.0
41
+ isort==5.9.1
42
+ jax==0.2.16
43
+ jaxlib==0.1.68
44
+ jedi==0.18.0
45
+ joblib==1.0.1
46
+ keras-nightly==2.5.0.dev2021032900
47
+ keras-preprocessing==1.1.2
48
+ kiwisolver==1.3.1
49
+ libtpu-nightly==0.1.dev20210615
50
+ markdown==3.3.4
51
+ matplotlib-inline==0.1.2
52
+ matplotlib==3.4.2
53
+ mccabe==0.6.1
54
+ msgpack==1.0.2
55
+ multidict==5.1.0
56
+ multiprocess==0.70.12.2
57
+ mypy-extensions==0.4.3
58
+ nodeenv==1.6.0
59
+ numpy==1.19.5
60
+ oauthlib==3.1.1
61
+ opt-einsum==3.3.0
62
+ optax==0.0.8
63
+ packaging==20.9
64
+ pandas==1.2.5
65
+ parso==0.8.2
66
+ pathspec==0.8.1
67
+ pathtools==0.1.2
68
+ pexpect==4.8.0
69
+ pickleshare==0.7.5
70
+ pillow==8.3.0
71
+ pip==20.0.2
72
+ pkg-resources==0.0.0
73
+ pre-commit==2.13.0
74
+ promise==2.3
75
+ prompt-toolkit==3.0.19
76
+ protobuf==3.17.3
77
+ psutil==5.8.0
78
+ ptyprocess==0.7.0
79
+ pyarrow==4.0.1
80
+ pyasn1-modules==0.2.8
81
+ pyasn1==0.4.8
82
+ pycodestyle==2.7.0
83
+ pyflakes==2.3.1
84
+ pygments==2.9.0
85
+ pyparsing==2.4.7
86
+ python-dateutil==2.8.1
87
+ pytz==2021.1
88
+ pyyaml==5.4.1
89
+ regex==2021.7.1
90
+ requests-oauthlib==1.3.0
91
+ requests==2.25.1
92
+ rsa==4.7.2
93
+ sacremoses==0.0.45
94
+ scipy==1.7.0
95
+ sentry-sdk==1.3.0
96
+ setuptools==44.0.0
97
+ shortuuid==1.0.1
98
+ six==1.15.0
99
+ smmap==4.0.0
100
+ subprocess32==3.5.4
101
+ tensorboard-data-server==0.6.1
102
+ tensorboard-plugin-wit==1.8.0
103
+ tensorboard==2.5.0
104
+ tensorflow-estimator==2.5.0
105
+ tensorflow==2.5.0
106
+ termcolor==1.1.0
107
+ tokenizers==0.10.3
108
+ toml==0.10.2
109
+ toolz==0.11.1
110
+ torch==1.9.0
111
+ tqdm==4.61.1
112
+ traitlets==5.0.5
113
+ transformers==4.9.0.dev0
114
+ typing-extensions==3.7.4.3
115
+ urllib3==1.26.6
116
+ virtualenv==20.4.7
117
+ wandb==0.10.33
118
+ wcwidth==0.2.5
119
+ werkzeug==2.0.1
120
+ wheel==0.36.2
121
+ wrapt==1.12.1
122
+ xxhash==2.0.2
123
+ yarl==1.6.3
scripts/wandb/run-20210715_085943-1ize2alk/files/wandb-metadata.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2021-07-15T08:59:45.122600",
5
+ "startedAt": "2021-07-15T08:59:43.232731",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--output_dir=../gpt-2-tamil",
11
+ "--model_type=gpt2",
12
+ "--config_name=../gpt-2-tamil",
13
+ "--tokenizer_name=../gpt-2-tamil",
14
+ "--dataset_name=oscar",
15
+ "--dataset_config_name=unshuffled_deduplicated_ta",
16
+ "--do_train",
17
+ "--do_eval",
18
+ "--block_size=512",
19
+ "--per_device_train_batch_size=128",
20
+ "--per_device_eval_batch_size=128",
21
+ "--learning_rate=3e-5",
22
+ "--warmup_steps=1000",
23
+ "--adam_beta1=0.9",
24
+ "--adam_beta2=0.98",
25
+ "--weight_decay=0.01",
26
+ "--overwrite_output_dir",
27
+ "--num_train_epochs=10",
28
+ "--report_to",
29
+ "wandb",
30
+ "--run_name",
31
+ "trial",
32
+ "--logging_steps=500",
33
+ "--save_steps=2500",
34
+ "--eval_steps=2500",
35
+ "--preprocessing_num_workers=90"
36
+ ],
37
+ "state": "running",
38
+ "program": "../src/run_clm_flax.py",
39
+ "codePath": "src/run_clm_flax.py",
40
+ "git": {
41
+ "remote": "https://github.com/AbinayaM02/GPT2-Tamil.git",
42
+ "commit": "69c9b7bf75b708a8f62cf5833d1b89acf5d1760b"
43
+ },
44
+ "email": "[email protected]",
45
+ "root": "/home/tweety_abi/GPT2-Tamil",
46
+ "host": "t1v-n-ebe36c53-w-0",
47
+ "username": "tweety_abi",
48
+ "executable": "/home/tweety_abi/gpt2_env/bin/python"
49
+ }
scripts/wandb/run-20210715_085943-1ize2alk/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
scripts/wandb/run-20210715_085943-1ize2alk/run-1ize2alk.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ddd483c4184ad35f642b4c9ddd01c8f4915a2cd4d811fb5e6395adec23ec07e
3
+ size 11149
scripts/wandb/run-20210715_091856-2v0tf7h4/files/config.yaml ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_version: 1
2
+
3
+ __cached__setup_devices:
4
+ desc: null
5
+ value: cpu
6
+ _n_gpu:
7
+ desc: null
8
+ value: 0
9
+ _wandb:
10
+ desc: null
11
+ value:
12
+ cli_version: 0.10.33
13
+ framework: huggingface
14
+ huggingface_version: 4.9.0.dev0
15
+ is_jupyter_run: false
16
+ is_kaggle_kernel: false
17
+ python_version: 3.8.10
18
+ t:
19
+ 1:
20
+ - 1
21
+ - 3
22
+ - 11
23
+ 2:
24
+ - 1
25
+ - 3
26
+ - 11
27
+ 4: 3.8.10
28
+ 5: 0.10.33
29
+ 6: 4.9.0.dev0
30
+ 8:
31
+ - 5
32
+ adafactor:
33
+ desc: null
34
+ value: false
35
+ adam_beta1:
36
+ desc: null
37
+ value: 0.9
38
+ adam_beta2:
39
+ desc: null
40
+ value: 0.98
41
+ adam_epsilon:
42
+ desc: null
43
+ value: 1.0e-08
44
+ block_size:
45
+ desc: null
46
+ value: 512
47
+ cache_dir:
48
+ desc: null
49
+ value: null
50
+ config_name:
51
+ desc: null
52
+ value: ../gpt-2-tamil
53
+ dataloader_drop_last:
54
+ desc: null
55
+ value: false
56
+ dataloader_num_workers:
57
+ desc: null
58
+ value: 0
59
+ dataloader_pin_memory:
60
+ desc: null
61
+ value: true
62
+ dataset_config_name:
63
+ desc: null
64
+ value: unshuffled_deduplicated_ta
65
+ dataset_name:
66
+ desc: null
67
+ value: oscar
68
+ ddp_find_unused_parameters:
69
+ desc: null
70
+ value: null
71
+ debug:
72
+ desc: null
73
+ value: []
74
+ deepspeed:
75
+ desc: null
76
+ value: null
77
+ disable_tqdm:
78
+ desc: null
79
+ value: false
80
+ do_eval:
81
+ desc: null
82
+ value: true
83
+ do_predict:
84
+ desc: null
85
+ value: false
86
+ do_train:
87
+ desc: null
88
+ value: true
89
+ dtype:
90
+ desc: null
91
+ value: float32
92
+ eval_accumulation_steps:
93
+ desc: null
94
+ value: null
95
+ eval_steps:
96
+ desc: null
97
+ value: 2500
98
+ evaluation_strategy:
99
+ desc: null
100
+ value: IntervalStrategy.NO
101
+ fp16:
102
+ desc: null
103
+ value: false
104
+ fp16_backend:
105
+ desc: null
106
+ value: auto
107
+ fp16_full_eval:
108
+ desc: null
109
+ value: false
110
+ fp16_opt_level:
111
+ desc: null
112
+ value: O1
113
+ gradient_accumulation_steps:
114
+ desc: null
115
+ value: 1
116
+ greater_is_better:
117
+ desc: null
118
+ value: null
119
+ group_by_length:
120
+ desc: null
121
+ value: false
122
+ ignore_data_skip:
123
+ desc: null
124
+ value: false
125
+ label_names:
126
+ desc: null
127
+ value: null
128
+ label_smoothing_factor:
129
+ desc: null
130
+ value: 0.0
131
+ learning_rate:
132
+ desc: null
133
+ value: 3.0e-05
134
+ length_column_name:
135
+ desc: null
136
+ value: length
137
+ load_best_model_at_end:
138
+ desc: null
139
+ value: false
140
+ local_rank:
141
+ desc: null
142
+ value: -1
143
+ log_level:
144
+ desc: null
145
+ value: -1
146
+ log_level_replica:
147
+ desc: null
148
+ value: -1
149
+ log_on_each_node:
150
+ desc: null
151
+ value: true
152
+ logging_dir:
153
+ desc: null
154
+ value: ../gpt-2-tamil/runs/Jul15_09-16-14_t1v-n-ebe36c53-w-0
155
+ logging_first_step:
156
+ desc: null
157
+ value: false
158
+ logging_steps:
159
+ desc: null
160
+ value: 500
161
+ logging_strategy:
162
+ desc: null
163
+ value: IntervalStrategy.STEPS
164
+ lr_scheduler_type:
165
+ desc: null
166
+ value: SchedulerType.LINEAR
167
+ max_eval_samples:
168
+ desc: null
169
+ value: null
170
+ max_grad_norm:
171
+ desc: null
172
+ value: 1.0
173
+ max_steps:
174
+ desc: null
175
+ value: -1
176
+ max_train_samples:
177
+ desc: null
178
+ value: null
179
+ metric_for_best_model:
180
+ desc: null
181
+ value: null
182
+ model_name_or_path:
183
+ desc: null
184
+ value: null
185
+ model_type:
186
+ desc: null
187
+ value: gpt2
188
+ mp_parameters:
189
+ desc: null
190
+ value: ''
191
+ no_cuda:
192
+ desc: null
193
+ value: false
194
+ num_train_epochs:
195
+ desc: null
196
+ value: 10.0
197
+ output_dir:
198
+ desc: null
199
+ value: ../gpt-2-tamil
200
+ overwrite_cache:
201
+ desc: null
202
+ value: false
203
+ overwrite_output_dir:
204
+ desc: null
205
+ value: true
206
+ past_index:
207
+ desc: null
208
+ value: -1
209
+ per_device_eval_batch_size:
210
+ desc: null
211
+ value: 128
212
+ per_device_train_batch_size:
213
+ desc: null
214
+ value: 128
215
+ per_gpu_eval_batch_size:
216
+ desc: null
217
+ value: null
218
+ per_gpu_train_batch_size:
219
+ desc: null
220
+ value: null
221
+ prediction_loss_only:
222
+ desc: null
223
+ value: false
224
+ preprocessing_num_workers:
225
+ desc: null
226
+ value: 90
227
+ push_to_hub:
228
+ desc: null
229
+ value: false
230
+ push_to_hub_model_id:
231
+ desc: null
232
+ value: gpt-2-tamil
233
+ push_to_hub_organization:
234
+ desc: null
235
+ value: null
236
+ push_to_hub_token:
237
+ desc: null
238
+ value: null
239
+ remove_unused_columns:
240
+ desc: null
241
+ value: true
242
+ report_to:
243
+ desc: null
244
+ value:
245
+ - wandb
246
+ resume_from_checkpoint:
247
+ desc: null
248
+ value: null
249
+ run_name:
250
+ desc: null
251
+ value: trial
252
+ save_on_each_node:
253
+ desc: null
254
+ value: false
255
+ save_steps:
256
+ desc: null
257
+ value: 2500
258
+ save_strategy:
259
+ desc: null
260
+ value: IntervalStrategy.STEPS
261
+ save_total_limit:
262
+ desc: null
263
+ value: null
264
+ seed:
265
+ desc: null
266
+ value: 42
267
+ sharded_ddp:
268
+ desc: null
269
+ value: []
270
+ skip_memory_metrics:
271
+ desc: null
272
+ value: true
273
+ tokenizer_name:
274
+ desc: null
275
+ value: ../gpt-2-tamil
276
+ tpu_metrics_debug:
277
+ desc: null
278
+ value: false
279
+ tpu_num_cores:
280
+ desc: null
281
+ value: null
282
+ train_file:
283
+ desc: null
284
+ value: null
285
+ use_fast_tokenizer:
286
+ desc: null
287
+ value: true
288
+ use_legacy_prediction_loop:
289
+ desc: null
290
+ value: false
291
+ validation_file:
292
+ desc: null
293
+ value: null
294
+ validation_split_percentage:
295
+ desc: null
296
+ value: 5
297
+ warmup_ratio:
298
+ desc: null
299
+ value: 0.0
300
+ warmup_steps:
301
+ desc: null
302
+ value: 1000
303
+ weight_decay:
304
+ desc: null
305
+ value: 0.01
scripts/wandb/run-20210715_091856-2v0tf7h4/files/events.out.tfevents.1626340740.t1v-n-ebe36c53-w-0.765413.3.v2 ADDED
@@ -0,0 +1 @@
 
 
1
+ /home/tweety_abi/GPT2-Tamil/gpt-2-tamil/events.out.tfevents.1626340740.t1v-n-ebe36c53-w-0.765413.3.v2
scripts/wandb/run-20210715_091856-2v0tf7h4/files/requirements.txt ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==0.13.0
2
+ aiohttp==3.7.4.post0
3
+ appdirs==1.4.4
4
+ astunparse==1.6.3
5
+ async-timeout==3.0.1
6
+ attrs==21.2.0
7
+ backcall==0.2.0
8
+ black==21.6b0
9
+ cachetools==4.2.2
10
+ certifi==2021.5.30
11
+ cfgv==3.3.0
12
+ chardet==4.0.0
13
+ chex==0.0.7
14
+ click==8.0.1
15
+ configparser==5.0.2
16
+ cycler==0.10.0
17
+ datasets==1.8.1.dev0
18
+ decorator==5.0.9
19
+ dill==0.3.4
20
+ distlib==0.3.2
21
+ dm-tree==0.1.6
22
+ docker-pycreds==0.4.0
23
+ filelock==3.0.12
24
+ flake8==3.9.2
25
+ flatbuffers==1.12
26
+ flax==0.3.4
27
+ fsspec==2021.6.1
28
+ gast==0.4.0
29
+ gitdb==4.0.7
30
+ gitpython==3.1.18
31
+ google-auth-oauthlib==0.4.4
32
+ google-auth==1.32.1
33
+ google-pasta==0.2.0
34
+ grpcio==1.34.1
35
+ h5py==3.1.0
36
+ huggingface-hub==0.0.12
37
+ identify==2.2.10
38
+ idna==2.10
39
+ ipython-genutils==0.2.0
40
+ ipython==7.25.0
41
+ isort==5.9.1
42
+ jax==0.2.16
43
+ jaxlib==0.1.68
44
+ jedi==0.18.0
45
+ joblib==1.0.1
46
+ keras-nightly==2.5.0.dev2021032900
47
+ keras-preprocessing==1.1.2
48
+ kiwisolver==1.3.1
49
+ libtpu-nightly==0.1.dev20210615
50
+ markdown==3.3.4
51
+ matplotlib-inline==0.1.2
52
+ matplotlib==3.4.2
53
+ mccabe==0.6.1
54
+ msgpack==1.0.2
55
+ multidict==5.1.0
56
+ multiprocess==0.70.12.2
57
+ mypy-extensions==0.4.3
58
+ nodeenv==1.6.0
59
+ numpy==1.19.5
60
+ oauthlib==3.1.1
61
+ opt-einsum==3.3.0
62
+ optax==0.0.8
63
+ packaging==20.9
64
+ pandas==1.2.5
65
+ parso==0.8.2
66
+ pathspec==0.8.1
67
+ pathtools==0.1.2
68
+ pexpect==4.8.0
69
+ pickleshare==0.7.5
70
+ pillow==8.3.0
71
+ pip==20.0.2
72
+ pkg-resources==0.0.0
73
+ pre-commit==2.13.0
74
+ promise==2.3
75
+ prompt-toolkit==3.0.19
76
+ protobuf==3.17.3
77
+ psutil==5.8.0
78
+ ptyprocess==0.7.0
79
+ pyarrow==4.0.1
80
+ pyasn1-modules==0.2.8
81
+ pyasn1==0.4.8
82
+ pycodestyle==2.7.0
83
+ pyflakes==2.3.1
84
+ pygments==2.9.0
85
+ pyparsing==2.4.7
86
+ python-dateutil==2.8.1
87
+ pytz==2021.1
88
+ pyyaml==5.4.1
89
+ regex==2021.7.1
90
+ requests-oauthlib==1.3.0
91
+ requests==2.25.1
92
+ rsa==4.7.2
93
+ sacremoses==0.0.45
94
+ scipy==1.7.0
95
+ sentry-sdk==1.3.0
96
+ setuptools==44.0.0
97
+ shortuuid==1.0.1
98
+ six==1.15.0
99
+ smmap==4.0.0
100
+ subprocess32==3.5.4
101
+ tensorboard-data-server==0.6.1
102
+ tensorboard-plugin-wit==1.8.0
103
+ tensorboard==2.5.0
104
+ tensorflow-estimator==2.5.0
105
+ tensorflow==2.5.0
106
+ termcolor==1.1.0
107
+ tokenizers==0.10.3
108
+ toml==0.10.2
109
+ toolz==0.11.1
110
+ torch==1.9.0
111
+ tqdm==4.61.1
112
+ traitlets==5.0.5
113
+ transformers==4.9.0.dev0
114
+ typing-extensions==3.7.4.3
115
+ urllib3==1.26.6
116
+ virtualenv==20.4.7
117
+ wandb==0.10.33
118
+ wcwidth==0.2.5
119
+ werkzeug==2.0.1
120
+ wheel==0.36.2
121
+ wrapt==1.12.1
122
+ xxhash==2.0.2
123
+ yarl==1.6.3
scripts/wandb/run-20210715_091856-2v0tf7h4/files/wandb-metadata.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2021-07-15T09:19:00.102585",
5
+ "startedAt": "2021-07-15T09:18:56.277815",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--output_dir=../gpt-2-tamil",
11
+ "--model_type=gpt2",
12
+ "--config_name=../gpt-2-tamil",
13
+ "--tokenizer_name=../gpt-2-tamil",
14
+ "--dataset_name=oscar",
15
+ "--dataset_config_name=unshuffled_deduplicated_ta",
16
+ "--do_train",
17
+ "--do_eval",
18
+ "--block_size=512",
19
+ "--per_device_train_batch_size=128",
20
+ "--per_device_eval_batch_size=128",
21
+ "--learning_rate=3e-5",
22
+ "--warmup_steps=1000",
23
+ "--adam_beta1=0.9",
24
+ "--adam_beta2=0.98",
25
+ "--weight_decay=0.01",
26
+ "--overwrite_output_dir",
27
+ "--num_train_epochs=10",
28
+ "--report_to",
29
+ "wandb",
30
+ "--run_name",
31
+ "trial",
32
+ "--logging_steps=500",
33
+ "--save_steps=2500",
34
+ "--eval_steps=2500",
35
+ "--preprocessing_num_workers=90"
36
+ ],
37
+ "state": "running",
38
+ "program": "../src/run_clm_flax.py",
39
+ "codePath": "src/run_clm_flax.py",
40
+ "git": {
41
+ "remote": "https://github.com/AbinayaM02/GPT2-Tamil.git",
42
+ "commit": "69c9b7bf75b708a8f62cf5833d1b89acf5d1760b"
43
+ },
44
+ "email": "[email protected]",
45
+ "root": "/home/tweety_abi/GPT2-Tamil",
46
+ "host": "t1v-n-ebe36c53-w-0",
47
+ "username": "tweety_abi",
48
+ "executable": "/home/tweety_abi/gpt2_env/bin/python"
49
+ }
scripts/wandb/run-20210715_091856-2v0tf7h4/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
scripts/wandb/run-20210715_091856-2v0tf7h4/run-2v0tf7h4.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:74c24264810cc8a5625c9a6fd0093d95ea89e0980f556fce2e873e00ba0254c5
3
+ size 38212
scripts/wandb/run-20210715_092837-watdq7ib/files/config.yaml ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_version: 1
2
+
3
+ __cached__setup_devices:
4
+ desc: null
5
+ value: cpu
6
+ _n_gpu:
7
+ desc: null
8
+ value: 0
9
+ _wandb:
10
+ desc: null
11
+ value:
12
+ cli_version: 0.10.33
13
+ framework: huggingface
14
+ huggingface_version: 4.9.0.dev0
15
+ is_jupyter_run: false
16
+ is_kaggle_kernel: false
17
+ python_version: 3.8.10
18
+ t:
19
+ 1:
20
+ - 1
21
+ - 3
22
+ - 11
23
+ 4: 3.8.10
24
+ 5: 0.10.33
25
+ 6: 4.9.0.dev0
26
+ 8:
27
+ - 5
28
+ adafactor:
29
+ desc: null
30
+ value: false
31
+ adam_beta1:
32
+ desc: null
33
+ value: 0.9
34
+ adam_beta2:
35
+ desc: null
36
+ value: 0.98
37
+ adam_epsilon:
38
+ desc: null
39
+ value: 1.0e-08
40
+ block_size:
41
+ desc: null
42
+ value: 512
43
+ cache_dir:
44
+ desc: null
45
+ value: null
46
+ config_name:
47
+ desc: null
48
+ value: ../gpt-2-tamil
49
+ dataloader_drop_last:
50
+ desc: null
51
+ value: false
52
+ dataloader_num_workers:
53
+ desc: null
54
+ value: 0
55
+ dataloader_pin_memory:
56
+ desc: null
57
+ value: true
58
+ dataset_config_name:
59
+ desc: null
60
+ value: unshuffled_deduplicated_ta
61
+ dataset_name:
62
+ desc: null
63
+ value: oscar
64
+ ddp_find_unused_parameters:
65
+ desc: null
66
+ value: null
67
+ debug:
68
+ desc: null
69
+ value: []
70
+ deepspeed:
71
+ desc: null
72
+ value: null
73
+ disable_tqdm:
74
+ desc: null
75
+ value: false
76
+ do_eval:
77
+ desc: null
78
+ value: true
79
+ do_predict:
80
+ desc: null
81
+ value: false
82
+ do_train:
83
+ desc: null
84
+ value: true
85
+ dtype:
86
+ desc: null
87
+ value: float32
88
+ eval_accumulation_steps:
89
+ desc: null
90
+ value: null
91
+ eval_steps:
92
+ desc: null
93
+ value: 2500
94
+ evaluation_strategy:
95
+ desc: null
96
+ value: IntervalStrategy.NO
97
+ fp16:
98
+ desc: null
99
+ value: false
100
+ fp16_backend:
101
+ desc: null
102
+ value: auto
103
+ fp16_full_eval:
104
+ desc: null
105
+ value: false
106
+ fp16_opt_level:
107
+ desc: null
108
+ value: O1
109
+ gradient_accumulation_steps:
110
+ desc: null
111
+ value: 1
112
+ greater_is_better:
113
+ desc: null
114
+ value: null
115
+ group_by_length:
116
+ desc: null
117
+ value: false
118
+ ignore_data_skip:
119
+ desc: null
120
+ value: false
121
+ label_names:
122
+ desc: null
123
+ value: null
124
+ label_smoothing_factor:
125
+ desc: null
126
+ value: 0.0
127
+ learning_rate:
128
+ desc: null
129
+ value: 3.0e-05
130
+ length_column_name:
131
+ desc: null
132
+ value: length
133
+ load_best_model_at_end:
134
+ desc: null
135
+ value: false
136
+ local_rank:
137
+ desc: null
138
+ value: -1
139
+ log_level:
140
+ desc: null
141
+ value: -1
142
+ log_level_replica:
143
+ desc: null
144
+ value: -1
145
+ log_on_each_node:
146
+ desc: null
147
+ value: true
148
+ logging_dir:
149
+ desc: null
150
+ value: ../gpt-2-tamil/runs/Jul15_09-27-21_t1v-n-ebe36c53-w-0
151
+ logging_first_step:
152
+ desc: null
153
+ value: false
154
+ logging_steps:
155
+ desc: null
156
+ value: 500
157
+ logging_strategy:
158
+ desc: null
159
+ value: IntervalStrategy.STEPS
160
+ lr_scheduler_type:
161
+ desc: null
162
+ value: SchedulerType.LINEAR
163
+ max_eval_samples:
164
+ desc: null
165
+ value: null
166
+ max_grad_norm:
167
+ desc: null
168
+ value: 1.0
169
+ max_steps:
170
+ desc: null
171
+ value: -1
172
+ max_train_samples:
173
+ desc: null
174
+ value: null
175
+ metric_for_best_model:
176
+ desc: null
177
+ value: null
178
+ model_name_or_path:
179
+ desc: null
180
+ value: null
181
+ model_type:
182
+ desc: null
183
+ value: gpt2
184
+ mp_parameters:
185
+ desc: null
186
+ value: ''
187
+ no_cuda:
188
+ desc: null
189
+ value: false
190
+ num_train_epochs:
191
+ desc: null
192
+ value: 10.0
193
+ output_dir:
194
+ desc: null
195
+ value: ../gpt-2-tamil
196
+ overwrite_cache:
197
+ desc: null
198
+ value: false
199
+ overwrite_output_dir:
200
+ desc: null
201
+ value: true
202
+ past_index:
203
+ desc: null
204
+ value: -1
205
+ per_device_eval_batch_size:
206
+ desc: null
207
+ value: 64
208
+ per_device_train_batch_size:
209
+ desc: null
210
+ value: 64
211
+ per_gpu_eval_batch_size:
212
+ desc: null
213
+ value: null
214
+ per_gpu_train_batch_size:
215
+ desc: null
216
+ value: null
217
+ prediction_loss_only:
218
+ desc: null
219
+ value: false
220
+ preprocessing_num_workers:
221
+ desc: null
222
+ value: 90
223
+ push_to_hub:
224
+ desc: null
225
+ value: false
226
+ push_to_hub_model_id:
227
+ desc: null
228
+ value: gpt-2-tamil
229
+ push_to_hub_organization:
230
+ desc: null
231
+ value: null
232
+ push_to_hub_token:
233
+ desc: null
234
+ value: null
235
+ remove_unused_columns:
236
+ desc: null
237
+ value: true
238
+ report_to:
239
+ desc: null
240
+ value:
241
+ - wandb
242
+ resume_from_checkpoint:
243
+ desc: null
244
+ value: null
245
+ run_name:
246
+ desc: null
247
+ value: trial
248
+ save_on_each_node:
249
+ desc: null
250
+ value: false
251
+ save_steps:
252
+ desc: null
253
+ value: 2500
254
+ save_strategy:
255
+ desc: null
256
+ value: IntervalStrategy.STEPS
257
+ save_total_limit:
258
+ desc: null
259
+ value: null
260
+ seed:
261
+ desc: null
262
+ value: 42
263
+ sharded_ddp:
264
+ desc: null
265
+ value: []
266
+ skip_memory_metrics:
267
+ desc: null
268
+ value: true
269
+ tokenizer_name:
270
+ desc: null
271
+ value: ../gpt-2-tamil
272
+ tpu_metrics_debug:
273
+ desc: null
274
+ value: false
275
+ tpu_num_cores:
276
+ desc: null
277
+ value: null
278
+ train_file:
279
+ desc: null
280
+ value: null
281
+ use_fast_tokenizer:
282
+ desc: null
283
+ value: true
284
+ use_legacy_prediction_loop:
285
+ desc: null
286
+ value: false
287
+ validation_file:
288
+ desc: null
289
+ value: null
290
+ validation_split_percentage:
291
+ desc: null
292
+ value: 5
293
+ warmup_ratio:
294
+ desc: null
295
+ value: 0.0
296
+ warmup_steps:
297
+ desc: null
298
+ value: 1000
299
+ weight_decay:
300
+ desc: null
301
+ value: 0.01
scripts/wandb/run-20210715_092837-watdq7ib/files/events.out.tfevents.1626341319.t1v-n-ebe36c53-w-0.768105.3.v2 ADDED
@@ -0,0 +1 @@
 
 
1
+ /home/tweety_abi/GPT2-Tamil/gpt-2-tamil/events.out.tfevents.1626341319.t1v-n-ebe36c53-w-0.768105.3.v2
scripts/wandb/run-20210715_092837-watdq7ib/files/requirements.txt ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==0.13.0
2
+ aiohttp==3.7.4.post0
3
+ appdirs==1.4.4
4
+ astunparse==1.6.3
5
+ async-timeout==3.0.1
6
+ attrs==21.2.0
7
+ backcall==0.2.0
8
+ black==21.6b0
9
+ cachetools==4.2.2
10
+ certifi==2021.5.30
11
+ cfgv==3.3.0
12
+ chardet==4.0.0
13
+ chex==0.0.7
14
+ click==8.0.1
15
+ configparser==5.0.2
16
+ cycler==0.10.0
17
+ datasets==1.8.1.dev0
18
+ decorator==5.0.9
19
+ dill==0.3.4
20
+ distlib==0.3.2
21
+ dm-tree==0.1.6
22
+ docker-pycreds==0.4.0
23
+ filelock==3.0.12
24
+ flake8==3.9.2
25
+ flatbuffers==1.12
26
+ flax==0.3.4
27
+ fsspec==2021.6.1
28
+ gast==0.4.0
29
+ gitdb==4.0.7
30
+ gitpython==3.1.18
31
+ google-auth-oauthlib==0.4.4
32
+ google-auth==1.32.1
33
+ google-pasta==0.2.0
34
+ grpcio==1.34.1
35
+ h5py==3.1.0
36
+ huggingface-hub==0.0.12
37
+ identify==2.2.10
38
+ idna==2.10
39
+ ipython-genutils==0.2.0
40
+ ipython==7.25.0
41
+ isort==5.9.1
42
+ jax==0.2.16
43
+ jaxlib==0.1.68
44
+ jedi==0.18.0
45
+ joblib==1.0.1
46
+ keras-nightly==2.5.0.dev2021032900
47
+ keras-preprocessing==1.1.2
48
+ kiwisolver==1.3.1
49
+ libtpu-nightly==0.1.dev20210615
50
+ markdown==3.3.4
51
+ matplotlib-inline==0.1.2
52
+ matplotlib==3.4.2
53
+ mccabe==0.6.1
54
+ msgpack==1.0.2
55
+ multidict==5.1.0
56
+ multiprocess==0.70.12.2
57
+ mypy-extensions==0.4.3
58
+ nodeenv==1.6.0
59
+ numpy==1.19.5
60
+ oauthlib==3.1.1
61
+ opt-einsum==3.3.0
62
+ optax==0.0.8
63
+ packaging==20.9
64
+ pandas==1.2.5
65
+ parso==0.8.2
66
+ pathspec==0.8.1
67
+ pathtools==0.1.2
68
+ pexpect==4.8.0
69
+ pickleshare==0.7.5
70
+ pillow==8.3.0
71
+ pip==20.0.2
72
+ pkg-resources==0.0.0
73
+ pre-commit==2.13.0
74
+ promise==2.3
75
+ prompt-toolkit==3.0.19
76
+ protobuf==3.17.3
77
+ psutil==5.8.0
78
+ ptyprocess==0.7.0
79
+ pyarrow==4.0.1
80
+ pyasn1-modules==0.2.8
81
+ pyasn1==0.4.8
82
+ pycodestyle==2.7.0
83
+ pyflakes==2.3.1
84
+ pygments==2.9.0
85
+ pyparsing==2.4.7
86
+ python-dateutil==2.8.1
87
+ pytz==2021.1
88
+ pyyaml==5.4.1
89
+ regex==2021.7.1
90
+ requests-oauthlib==1.3.0
91
+ requests==2.25.1
92
+ rsa==4.7.2
93
+ sacremoses==0.0.45
94
+ scipy==1.7.0
95
+ sentry-sdk==1.3.0
96
+ setuptools==44.0.0
97
+ shortuuid==1.0.1
98
+ six==1.15.0
99
+ smmap==4.0.0
100
+ subprocess32==3.5.4
101
+ tensorboard-data-server==0.6.1
102
+ tensorboard-plugin-wit==1.8.0
103
+ tensorboard==2.5.0
104
+ tensorflow-estimator==2.5.0
105
+ tensorflow==2.5.0
106
+ termcolor==1.1.0
107
+ tokenizers==0.10.3
108
+ toml==0.10.2
109
+ toolz==0.11.1
110
+ torch==1.9.0
111
+ tqdm==4.61.1
112
+ traitlets==5.0.5
113
+ transformers==4.9.0.dev0
114
+ typing-extensions==3.7.4.3
115
+ urllib3==1.26.6
116
+ virtualenv==20.4.7
117
+ wandb==0.10.33
118
+ wcwidth==0.2.5
119
+ werkzeug==2.0.1
120
+ wheel==0.36.2
121
+ wrapt==1.12.1
122
+ xxhash==2.0.2
123
+ yarl==1.6.3
scripts/wandb/run-20210715_092837-watdq7ib/files/wandb-metadata.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2021-07-15T09:28:39.248463",
5
+ "startedAt": "2021-07-15T09:28:37.215410",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--output_dir=../gpt-2-tamil",
11
+ "--model_type=gpt2",
12
+ "--config_name=../gpt-2-tamil",
13
+ "--tokenizer_name=../gpt-2-tamil",
14
+ "--dataset_name=oscar",
15
+ "--dataset_config_name=unshuffled_deduplicated_ta",
16
+ "--do_train",
17
+ "--do_eval",
18
+ "--block_size=512",
19
+ "--per_device_train_batch_size=64",
20
+ "--per_device_eval_batch_size=64",
21
+ "--learning_rate=3e-5",
22
+ "--warmup_steps=1000",
23
+ "--adam_beta1=0.9",
24
+ "--adam_beta2=0.98",
25
+ "--weight_decay=0.01",
26
+ "--overwrite_output_dir",
27
+ "--num_train_epochs=10",
28
+ "--report_to",
29
+ "wandb",
30
+ "--run_name",
31
+ "trial",
32
+ "--logging_steps=500",
33
+ "--save_steps=2500",
34
+ "--eval_steps=2500",
35
+ "--preprocessing_num_workers=90"
36
+ ],
37
+ "state": "running",
38
+ "program": "../src/run_clm_flax.py",
39
+ "codePath": "src/run_clm_flax.py",
40
+ "git": {
41
+ "remote": "https://github.com/AbinayaM02/GPT2-Tamil.git",
42
+ "commit": "69c9b7bf75b708a8f62cf5833d1b89acf5d1760b"
43
+ },
44
+ "email": "[email protected]",
45
+ "root": "/home/tweety_abi/GPT2-Tamil",
46
+ "host": "t1v-n-ebe36c53-w-0",
47
+ "username": "tweety_abi",
48
+ "executable": "/home/tweety_abi/gpt2_env/bin/python"
49
+ }
scripts/wandb/run-20210715_092837-watdq7ib/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"global_step": 158500, "_timestamp": 1626511475.977111, "train_time": 3007799.75, "train_learning_rate": 2.7665698780765524e-06, "_step": 316049, "train_loss": 1.1194136142730713, "eval_loss": 1.1329445838928223, "eval_perplexity": 3.104785442352295}
scripts/wandb/run-20210715_092837-watdq7ib/run-watdq7ib.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17ccbfb69a2e91865a50d34837db9291fa2687143f65c6f6c712e23f40a46343
3
+ size 71362583
src/create_config.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2Config
2
+
3
+ model_dir = "../gpt-2-tamil" # ${MODEL_DIR}
4
+
5
+ config = GPT2Config.from_pretrained(
6
+ "gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0
7
+ )
8
+ config.save_pretrained(model_dir)
src/run_clm_flax.py ADDED
@@ -0,0 +1,661 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Pre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
18
+
19
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
20
+ https://huggingface.co/models?filter=causal-lm
21
+ """
22
+ # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
23
+
24
+ import logging
25
+ import math
26
+ import os
27
+ import sys
28
+ import time
29
+ from dataclasses import dataclass, field
30
+ from pathlib import Path
31
+ from typing import Callable, Optional
32
+
33
+ import datasets
34
+ from datasets import Dataset, load_dataset, concatenate_datasets
35
+ from tqdm import tqdm
36
+
37
+ import jax
38
+ import jax.numpy as jnp
39
+ import optax
40
+ import transformers
41
+ import wandb
42
+ from flax import jax_utils, traverse_util
43
+ from flax.jax_utils import unreplicate
44
+ from flax.training import train_state
45
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
46
+ from transformers import (
47
+ CONFIG_MAPPING,
48
+ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
49
+ AutoConfig,
50
+ AutoTokenizer,
51
+ FlaxAutoModelForCausalLM,
52
+ HfArgumentParser,
53
+ TrainingArguments,
54
+ is_tensorboard_available,
55
+ )
56
+ from transformers.testing_utils import CaptureLogger
57
+
58
+
59
+ logger = logging.getLogger(__name__)
60
+
61
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
62
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
63
+
64
+
65
+ @dataclass
66
+ class ModelArguments:
67
+ """
68
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
69
+ """
70
+
71
+ model_name_or_path: Optional[str] = field(
72
+ default=None,
73
+ metadata={
74
+ "help": "The model checkpoint for weights initialization."
75
+ "Don't set if you want to train a model from scratch."
76
+ },
77
+ )
78
+ model_type: Optional[str] = field(
79
+ default=None,
80
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
81
+ )
82
+ config_name: Optional[str] = field(
83
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
84
+ )
85
+ tokenizer_name: Optional[str] = field(
86
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
87
+ )
88
+ cache_dir: Optional[str] = field(
89
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
90
+ )
91
+ use_fast_tokenizer: bool = field(
92
+ default=True,
93
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
94
+ )
95
+ dtype: Optional[str] = field(
96
+ default="float32",
97
+ metadata={
98
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
99
+ },
100
+ )
101
+
102
+
103
+ @dataclass
104
+ class DataTrainingArguments:
105
+ """
106
+ Arguments pertaining to what data we are going to input our model for training and eval.
107
+ """
108
+
109
+ dataset_name: Optional[str] = field(
110
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
111
+ )
112
+ dataset_config_name: Optional[str] = field(
113
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
114
+ )
115
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
116
+ validation_file: Optional[str] = field(
117
+ default=None,
118
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
119
+ )
120
+ max_train_samples: Optional[int] = field(
121
+ default=None,
122
+ metadata={
123
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
124
+ "value if set."
125
+ },
126
+ )
127
+ max_eval_samples: Optional[int] = field(
128
+ default=None,
129
+ metadata={
130
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
131
+ "value if set."
132
+ },
133
+ )
134
+ overwrite_cache: bool = field(
135
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
136
+ )
137
+ validation_split_percentage: Optional[int] = field(
138
+ default=5,
139
+ metadata={
140
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
141
+ },
142
+ )
143
+ block_size: Optional[int] = field(
144
+ default=None,
145
+ metadata={
146
+ "help": "Optional input sequence length after tokenization. "
147
+ "The training dataset will be truncated in block of this size for training. "
148
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
149
+ },
150
+ )
151
+ overwrite_cache: bool = field(
152
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
153
+ )
154
+ preprocessing_num_workers: Optional[int] = field(
155
+ default=None,
156
+ metadata={"help": "The number of processes to use for the preprocessing."},
157
+ )
158
+
159
+ def __post_init__(self):
160
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
161
+ raise ValueError("Need either a dataset name or a training/validation file.")
162
+ else:
163
+ if self.train_file is not None:
164
+ extension = self.train_file.split(".")[-1]
165
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
166
+ if self.validation_file is not None:
167
+ extension = self.validation_file.split(".")[-1]
168
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
169
+
170
+
171
+ class TrainState(train_state.TrainState):
172
+ dropout_rng: jnp.ndarray
173
+
174
+ def replicate(self):
175
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
176
+
177
+
178
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
179
+ """
180
+ Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
181
+ Shuffle batches if `shuffle` is `True`.
182
+ """
183
+ steps_per_epoch = len(dataset) // batch_size
184
+
185
+ if shuffle:
186
+ batch_idx = jax.random.permutation(rng, len(dataset))
187
+ else:
188
+ batch_idx = jnp.arange(len(dataset))
189
+
190
+ batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
191
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
192
+
193
+ for idx in batch_idx:
194
+ batch = dataset[idx]
195
+ batch = {k: jnp.array(v) for k, v in batch.items()}
196
+
197
+ batch = shard(batch)
198
+
199
+ yield batch
200
+
201
+
202
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
203
+ summary_writer.scalar("train_time", train_time, step)
204
+
205
+ train_metrics = get_metrics(train_metrics)
206
+ for key, vals in train_metrics.items():
207
+ tag = f"train_{key}"
208
+ for i, val in enumerate(vals):
209
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
210
+
211
+
212
+ def write_eval_metric(summary_writer, eval_metrics, step):
213
+ for metric_name, value in eval_metrics.items():
214
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
215
+
216
+
217
+ def create_learning_rate_fn(
218
+ train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
219
+ ) -> Callable[[int], jnp.array]:
220
+ """Returns a linear warmup, linear_decay learning rate function."""
221
+ steps_per_epoch = train_ds_size // train_batch_size
222
+ num_train_steps = steps_per_epoch * num_train_epochs
223
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
224
+ decay_fn = optax.linear_schedule(
225
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
226
+ )
227
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
228
+ return schedule_fn
229
+
230
+
231
+ def main():
232
+ # See all possible arguments in src/transformers/training_args.py
233
+ # or by passing the --help flag to this script.
234
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
235
+
236
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
237
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
238
+ # If we pass only one argument to the script and it's the path to a json file,
239
+ # let's parse it to get our arguments.
240
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
241
+ else:
242
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
243
+
244
+ if (
245
+ os.path.exists(training_args.output_dir)
246
+ and os.listdir(training_args.output_dir)
247
+ and training_args.do_train
248
+ and not training_args.overwrite_output_dir
249
+ ):
250
+ raise ValueError(
251
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
252
+ "Use --overwrite_output_dir to overcome."
253
+ )
254
+
255
+ # Make one log on every process with the configuration for debugging.
256
+ logging.basicConfig(
257
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
258
+ datefmt="%m/%d/%Y %H:%M:%S",
259
+ level=logging.INFO,
260
+ )
261
+ # Setup logging, we only want one process per machine to log things on the screen.
262
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
263
+ if jax.process_index() == 0:
264
+ datasets.utils.logging.set_verbosity_warning()
265
+ transformers.utils.logging.set_verbosity_info()
266
+ else:
267
+ datasets.utils.logging.set_verbosity_error()
268
+ transformers.utils.logging.set_verbosity_error()
269
+
270
+ # Set the verbosity to info of the Transformers logger (on main process only):
271
+ logger.info(f"Training/evaluation parameters {training_args}")
272
+
273
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
274
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
275
+ # (the dataset will be downloaded automatically from the datasets Hub).
276
+ #
277
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
278
+ # 'text' is found. You can easily tweak this behavior (see below).
279
+ #
280
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
281
+ # download the dataset.
282
+
283
+ #GPT-2 tamil
284
+ logger.info(f"Loading dataset....")
285
+ print("Loading indic corp tamil dataset")
286
+ indic_tamil = load_dataset("csv",data_files="/tmp/indic_corp/ta.csv")
287
+
288
+ if data_args.dataset_name is not None:
289
+ # Downloading and loading a dataset from the hub.
290
+ dataset = load_dataset(
291
+ data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
292
+ )
293
+
294
+ if "validation" not in dataset.keys():
295
+ dataset["validation"] = load_dataset(
296
+ data_args.dataset_name,
297
+ data_args.dataset_config_name,
298
+ split=f"train[:{data_args.validation_split_percentage}%]",
299
+ cache_dir=model_args.cache_dir,
300
+ )
301
+ dataset["train"] = load_dataset(
302
+ data_args.dataset_name,
303
+ data_args.dataset_config_name,
304
+ split=f"train[{data_args.validation_split_percentage}%:]",
305
+ cache_dir=model_args.cache_dir,
306
+ )
307
+ ## GPT2-tamil - adding indic_corp dataset manually
308
+ print("Concatenating datasets")
309
+ #pdb.set_trace()
310
+ dataset['train'] = concatenate_datasets([indic_tamil['train'],dataset['train']])
311
+ else:
312
+ data_files = {}
313
+ if data_args.train_file is not None:
314
+ data_files["train"] = data_args.train_file
315
+ if data_args.validation_file is not None:
316
+ data_files["validation"] = data_args.validation_file
317
+ extension = data_args.train_file.split(".")[-1]
318
+ if extension == "txt":
319
+ extension = "text"
320
+ dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
321
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
322
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
323
+
324
+ # Load pretrained model and tokenizer
325
+
326
+ # Distributed training:
327
+ # The .from_pretrained methods guarantee that only one local process can concurrently
328
+ # download model & vocab.
329
+ if model_args.config_name:
330
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
331
+ elif model_args.model_name_or_path:
332
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
333
+ else:
334
+ config = CONFIG_MAPPING[model_args.model_type]()
335
+ logger.warning("You are instantiating a new config instance from scratch.")
336
+
337
+ if model_args.tokenizer_name:
338
+ tokenizer = AutoTokenizer.from_pretrained(
339
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
340
+ )
341
+ elif model_args.model_name_or_path:
342
+ tokenizer = AutoTokenizer.from_pretrained(
343
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
344
+ )
345
+ else:
346
+ raise ValueError(
347
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
348
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
349
+ )
350
+
351
+ if model_args.model_name_or_path:
352
+ model = FlaxAutoModelForCausalLM.from_pretrained(
353
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
354
+ )
355
+ else:
356
+ model = FlaxAutoModelForCausalLM.from_config(
357
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
358
+ )
359
+
360
+ # Preprocessing the datasets.
361
+ # First we tokenize all the texts.
362
+ if training_args.do_train:
363
+ column_names = dataset["train"].column_names
364
+ else:
365
+ column_names = dataset["validation"].column_names
366
+ text_column_name = "text" if "text" in column_names else column_names[0]
367
+
368
+ # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
369
+ tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
370
+
371
+ def tokenize_function(examples):
372
+ with CaptureLogger(tok_logger) as cl:
373
+ output = tokenizer(examples[text_column_name])
374
+ # clm input could be much much longer than block_size
375
+ if "Token indices sequence length is longer than the" in cl.out:
376
+ tok_logger.warning(
377
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
378
+ )
379
+ return output
380
+
381
+ tokenized_datasets = dataset.map(
382
+ tokenize_function,
383
+ batched=True,
384
+ num_proc=data_args.preprocessing_num_workers,
385
+ remove_columns=column_names,
386
+ load_from_cache_file=not data_args.overwrite_cache,
387
+ )
388
+
389
+ if data_args.block_size is None:
390
+ block_size = tokenizer.model_max_length
391
+ if block_size > config.max_position_embeddings:
392
+ logger.warning(
393
+ f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
394
+ "Picking 1024 instead. You can change that default value by passing --block_size xxx."
395
+ )
396
+ block_size = 1024
397
+ else:
398
+ if data_args.block_size > tokenizer.model_max_length:
399
+ logger.warning(
400
+ f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
401
+ f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
402
+ )
403
+ block_size = min(data_args.block_size, tokenizer.model_max_length)
404
+
405
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
406
+ def group_texts(examples):
407
+ # Concatenate all texts.
408
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
409
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
410
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
411
+ # customize this part to your needs.
412
+ if total_length >= block_size:
413
+ total_length = (total_length // block_size) * block_size
414
+ # Split by chunks of max_len.
415
+ result = {
416
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
417
+ for k, t in concatenated_examples.items()
418
+ }
419
+ result["labels"] = result["input_ids"].copy()
420
+ return result
421
+
422
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
423
+ # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
424
+ # to preprocess.
425
+ #
426
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
427
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
428
+
429
+ lm_datasets = tokenized_datasets.map(
430
+ group_texts,
431
+ batched=True,
432
+ num_proc=data_args.preprocessing_num_workers,
433
+ load_from_cache_file=not data_args.overwrite_cache,
434
+ )
435
+
436
+ if training_args.do_train:
437
+ if "train" not in tokenized_datasets:
438
+ raise ValueError("--do_train requires a train dataset")
439
+ train_dataset = lm_datasets["train"]
440
+ if data_args.max_train_samples is not None:
441
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
442
+
443
+ if training_args.do_eval:
444
+ if "validation" not in tokenized_datasets:
445
+ raise ValueError("--do_eval requires a validation dataset")
446
+ eval_dataset = lm_datasets["validation"]
447
+ if data_args.max_eval_samples is not None:
448
+ eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
449
+
450
+ # Enable tensorboard only on the master node
451
+ has_tensorboard = is_tensorboard_available()
452
+ if has_tensorboard and jax.process_index() == 0:
453
+ wandb.init(
454
+ entity='wandb',
455
+ project='hf-flax-gpt2-tamil',
456
+ sync_tensorboard=True
457
+ )
458
+
459
+ wandb.config.update(training_args) # optional, log your configs
460
+ wandb.config.update(model_args) # optional, log your configs
461
+ wandb.config.update(data_args) # optional, log your configs
462
+
463
+ try:
464
+ from flax.metrics.tensorboard import SummaryWriter
465
+
466
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
467
+ except ImportError as ie:
468
+ has_tensorboard = False
469
+ logger.warning(
470
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
471
+ )
472
+ else:
473
+ logger.warning(
474
+ "Unable to display metrics through TensorBoard because the package is not installed: "
475
+ "Please run pip install tensorboard to enable."
476
+ )
477
+
478
+ # Initialize our training
479
+ rng = jax.random.PRNGKey(training_args.seed)
480
+ rng, dropout_rng = jax.random.split(rng)
481
+
482
+ # Store some constant
483
+ num_epochs = int(training_args.num_train_epochs)
484
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
485
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
486
+ steps_per_epoch = len(train_dataset) // train_batch_size
487
+ total_train_steps = steps_per_epoch * num_epochs
488
+
489
+ # Create learning rate schedule
490
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
491
+ len(train_dataset),
492
+ train_batch_size,
493
+ training_args.num_train_epochs,
494
+ training_args.warmup_steps,
495
+ training_args.learning_rate,
496
+ )
497
+
498
+ # We use Optax's "masking" functionality to not apply weight decay
499
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
500
+ # mask boolean with the same structure as the parameters.
501
+ # The mask is True for parameters that should be decayed.
502
+ # Note that this mask is specifically adapted for FlaxGPT2.
503
+ # For other models, one should correct the layer norm parameter naming
504
+ # accordingly.
505
+ def decay_mask_fn(params):
506
+ flat_params = traverse_util.flatten_dict(params)
507
+ flat_mask = {
508
+ path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
509
+ for path in flat_params
510
+ }
511
+ return traverse_util.unflatten_dict(flat_mask)
512
+
513
+ # create adam optimizer
514
+ if training_args.adafactor:
515
+ # We use the default parameters here to initialize adafactor,
516
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
517
+ optimizer = optax.adafactor(
518
+ learning_rate=linear_decay_lr_schedule_fn,
519
+ )
520
+ else:
521
+ optimizer = optax.adamw(
522
+ learning_rate=linear_decay_lr_schedule_fn,
523
+ b1=training_args.adam_beta1,
524
+ b2=training_args.adam_beta2,
525
+ eps=training_args.adam_epsilon,
526
+ weight_decay=training_args.weight_decay,
527
+ mask=decay_mask_fn,
528
+ )
529
+
530
+ # Setup train state
531
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
532
+
533
+ def loss_fn(logits, labels):
534
+ shift_logits = logits[..., :-1, :]
535
+ shift_labels = labels[..., 1:]
536
+ loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
537
+ return loss.mean()
538
+
539
+ # Define gradient update step fn
540
+ def train_step(state, batch):
541
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
542
+
543
+ def compute_loss(params):
544
+ labels = batch.pop("labels")
545
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
546
+ loss = loss_fn(logits, labels)
547
+ return loss
548
+
549
+ grad_fn = jax.value_and_grad(compute_loss)
550
+ loss, grad = grad_fn(state.params)
551
+ grad = jax.lax.pmean(grad, "batch")
552
+
553
+ new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
554
+
555
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
556
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
557
+
558
+ return new_state, metrics
559
+
560
+ # Define eval fn
561
+ def eval_step(params, batch):
562
+ labels = batch.pop("labels")
563
+ logits = model(**batch, params=params, train=False)[0]
564
+ loss = loss_fn(logits, labels)
565
+
566
+ # summarize metrics
567
+ metrics = {"loss": loss}
568
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
569
+ return metrics
570
+
571
+ # Create parallel version of the train and eval step
572
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
573
+ p_eval_step = jax.pmap(eval_step, "batch")
574
+
575
+ # Replicate the train state on each device
576
+ state = state.replicate()
577
+
578
+ logger.info("***** Running training *****")
579
+ logger.info(f" Num examples = {len(train_dataset)}")
580
+ logger.info(f" Num Epochs = {num_epochs}")
581
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
582
+ logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
583
+ logger.info(f" Total optimization steps = {total_train_steps}")
584
+
585
+ train_time = 0
586
+ train_metrics = []
587
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
588
+ for epoch in epochs:
589
+ # ======================== Training ================================
590
+ train_start = time.time()
591
+
592
+ # Create sampling rng
593
+ rng, input_rng = jax.random.split(rng)
594
+
595
+ # Generate an epoch by shuffling sampling indices from the train dataset
596
+ train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
597
+ steps_per_epoch = len(train_dataset) // train_batch_size
598
+ # train
599
+ for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
600
+ batch = next(train_loader)
601
+ state, train_metric = p_train_step(state, batch)
602
+ train_metrics.append(train_metric)
603
+
604
+ cur_step = epoch * (len(train_dataset) // train_batch_size) + step
605
+
606
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
607
+ # Save metrics
608
+ train_metric = unreplicate(train_metric)
609
+ train_time += time.time() - train_start
610
+ if has_tensorboard and jax.process_index() == 0:
611
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
612
+
613
+ epochs.write(
614
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
615
+ )
616
+
617
+ train_metrics = []
618
+
619
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
620
+ # ======================== Evaluating ==============================
621
+ eval_metrics = []
622
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
623
+ eval_steps = len(eval_dataset) // eval_batch_size
624
+ for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
625
+ # Model forward
626
+ batch = next(eval_loader)
627
+ metrics = p_eval_step(state.params, batch)
628
+ eval_metrics.append(metrics)
629
+
630
+ # normalize eval metrics
631
+ eval_metrics = get_metrics(eval_metrics)
632
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
633
+
634
+ try:
635
+ eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
636
+ except OverflowError:
637
+ eval_metrics["perplexity"] = float("inf")
638
+
639
+ # Print metrics and update progress bar
640
+ desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
641
+ epochs.write(desc)
642
+ epochs.desc = desc
643
+
644
+ # Save metrics
645
+ if has_tensorboard and jax.process_index() == 0:
646
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
647
+
648
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
649
+ # save checkpoint after each epoch and push checkpoint to the hub
650
+ if jax.process_index() == 0:
651
+ params = jax.device_get(unreplicate(state.params))
652
+ model.save_pretrained(
653
+ training_args.output_dir,
654
+ params=params,
655
+ push_to_hub=training_args.push_to_hub,
656
+ commit_message=f"Saving weights and logs of step {cur_step}",
657
+ )
658
+
659
+
660
+ if __name__ == "__main__":
661
+ main()
src/train_tokenizer.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from datasets import load_dataset,concatenate_datasets
3
+ from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
4
+
5
+
6
+ from datasets import load_dataset
7
+ from tokenizers import ByteLevelBPETokenizer # Tokenizer, normalizers, trainers
8
+
9
+ model_dir = "../gpt-2-tamil" # ${MODEL_DIR}
10
+
11
+
12
+ # load dataset
13
+ dataset = load_dataset("oscar", "unshuffled_deduplicated_ta", split="train")
14
+ indic_tamil = load_dataset("csv",data_files="/tmp/indic_corp/ta.csv")
15
+ dataset = concatenate_datasets([dataset,indic_tamil['train']])
16
+ # Instantiate tokenizer
17
+ tokenizer = ByteLevelBPETokenizer()
18
+
19
+
20
+ def batch_iterator(batch_size=1000):
21
+ for i in range(0, len(dataset), batch_size):
22
+ yield dataset[i : i + batch_size]["text"]
23
+
24
+
25
+ # Customized training
26
+ tokenizer.train_from_iterator(
27
+ batch_iterator(),
28
+ vocab_size=50265,
29
+ min_frequency=2,
30
+ special_tokens=[
31
+ "<s>",
32
+ "<pad>",
33
+ "</s>",
34
+ "<unk>",
35
+ "<mask>",
36
+ ],
37
+ )
38
+
39
+ # Save files to disk
40
+ tokenizer.save(f"{model_dir}/tokenizer.json")