ksterx committed on
Commit
bbdb282
·
verified ·
1 Parent(s): 6b369b5

Upload 3 files

Files changed (3)
  1. README.md +120 -0
  2. loss_comparison.png +0 -0
  3. retention.gif +0 -0
README.md CHANGED
@@ -1,3 +1,123 @@
1
  ---
 
 
 
2
  license: mit
 
3
  ---
1
  ---
2
+ language:
3
+ - ja
4
+ - en
5
  license: mit
6
+ library_name: transformers
7
  ---
8
+
9
+ ![SpiralAI RetNet-3b-ja-base](logo.png)
10
+
11
+ # SpiralAI RetNet-3b-ja-base
12
+
13
+ We pre-trained a 3B-parameter model based on the RetNet architecture (https://arxiv.org/abs/2307.08621) from scratch on a mixed Japanese and English dataset.
14
+ This model is released primarily for basic research on the retention mechanism.
15
+
16
+ # Model Description
17
+
18
+ - **Developed by:** [SpiralAI](https://go-spiral.ai/)
19
+ - **Model type:** `SpiralAI RetNet-3b-ja-base` is a language model equipped with a retention mechanism. It uses the `cyberagent/calm2-7b-chat` tokenizer.
20
+ - **Languages:** Japanese, English.
21
+ - **License:** MIT
22
+ - **Training:** Trained on 80B tokens.
23
+ - **Context Length:** 2,048 tokens.
24
+
25
+ # Installation
26
+
27
+ ```bash
28
+ pip install transformers==4.38 # Pin this version: top_k_top_p_filtering has been removed in later versions.
29
+
30
+ ```
31
+
32
+ Clone the repository from **`https://github.com/syncdoth/RetNet`** and follow the *Getting Started* guide provided there.
33
+
34
+ Example:
35
+
36
+ ```bash
37
+ git clone https://github.com/syncdoth/RetNet.git
38
+ pip install torch transformers timm
39
+ cd RetNet
40
+
41
+ ```
42
+
43
+ # Usage
44
+
45
+ ```python
46
+ from transformers import AutoTokenizer
47
+
48
+ from retnet.modeling_retnet import RetNetForCausalLM
49
+
50
+ tokenizer = AutoTokenizer.from_pretrained("cyberagent/calm2-7b-chat")
51
+ tokenizer.pad_token = tokenizer.eos_token
52
+
53
+ model = RetNetForCausalLM.from_pretrained(
54
+ "Spiral-AI/RetNet-3b-ja-base", device_map="auto"
55
+ )
56
+ inputs = tokenizer("最近、秋葉原周辺で興味深い", return_tensors="pt")
57
+ input_ids = inputs["input_ids"].to(model.device)
58
+ generated = model.generate(
59
+ input_ids,
60
+ max_new_tokens=32,
61
+ repetition_penalty=1.2, # setting a repetition penalty is recommended for a 3B-parameter model
62
+ )
63
+ print(tokenizer.decode(generated[0]))
64
+
65
+ ```
66
+
67
+ ## Examples
68
+ ```
69
+ input: 最近、秋葉原周辺で興味深い
70
+ output: お店がいくつかあります。
71
+ 1. 神田カレー街「カレーハウスCoCo壱番屋」
72
+ 2016年7月3日オープン
73
+ ```
74
+
75
+ ```
76
+ input: 近年、AI技術の進歩によって
77
+ output: 人間の仕事が奪われるのではないかという懸念がある。
78
+ しかしながら、AIは人間に取って代わるものではなく、「人間がコンピュータに仕事をさせる」という考え方
79
+ ```
80
+
81
+ ```
82
+ input: When I was a child, I used to play with
83
+ output: 3-D glasses. They were so much fun!
84
+ I have been playing around in the world of video games for years now and it is amazing how
85
+ ```
86
+
87
+ # Basic study
88
+
89
+ ## Visualization of the retention mechanism
90
+
91
+ ![retention](retention.gif)
92
+ This visualization shows the retention mechanism in action. The token being generated is represented by `*`.
93
+ The blue bars show how strongly each preceding token is weighted during generation.
94
+
95
+ Using the mathematical equivalence between "recurrent mode" and "parallel mode", we apply a visualization technique similar to the one used for the attention mechanism:
96
+ we take the absolute values of the inner products between queries and keys and sum them over all heads.
97
+ Here we show the result of the last layer.
98
+
99
+ ## Test loss comparison
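The equivalence mentioned above can be checked numerically. The following is a minimal sketch for a single head in pure Python, assuming a simplified retention without the position rotation or normalization used in the full RetNet; the decay value `GAMMA`, the helper names, and the random inputs are all hypothetical. `token_weights` follows the description above (sum over heads of absolute query-key inner products); whether the decay factor enters the published visualization is not stated, so it is left out here as an assumption.

```python
import random

GAMMA = 0.9  # hypothetical decay; the real model uses its own per-head values


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def retention_parallel(q, k, v, gamma=GAMMA):
    """Parallel mode: out[n] = sum over m <= n of gamma^(n-m) * (q[n].k[m]) * v[m]."""
    out = []
    for n in range(len(q)):
        acc = [0.0] * len(v[0])
        for m in range(n + 1):
            w = gamma ** (n - m) * dot(q[n], k[m])
            acc = [a + w * x for a, x in zip(acc, v[m])]
        out.append(acc)
    return out


def retention_recurrent(q, k, v, gamma=GAMMA):
    """Recurrent mode: S[n] = gamma * S[n-1] + outer(k[n], v[n]); out[n] = q[n] S[n]."""
    d_k, d_v = len(k[0]), len(v[0])
    state = [[0.0] * d_v for _ in range(d_k)]
    out = []
    for q_n, k_n, v_n in zip(q, k, v):
        state = [
            [gamma * state[i][j] + k_n[i] * v_n[j] for j in range(d_v)]
            for i in range(d_k)
        ]
        out.append([sum(q_n[i] * state[i][j] for i in range(d_k)) for j in range(d_v)])
    return out


def token_weights(q_heads, k_heads, n):
    """Weight of each token m <= n at step n: sum over heads of |q[n].k[m]|."""
    return [
        sum(abs(dot(q[n], k[m])) for q, k in zip(q_heads, k_heads))
        for m in range(n + 1)
    ]


def random_matrix(rows, cols):
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
q, k, v = random_matrix(5, 4), random_matrix(5, 4), random_matrix(5, 4)
par = retention_parallel(q, k, v)
rec = retention_recurrent(q, k, v)
max_diff = max(abs(a - b) for row_p, row_r in zip(par, rec) for a, b in zip(row_p, row_r))
print("max difference between modes:", max_diff)
print("weights for the last token:", token_weights([q], [k], 4))
```

The two modes agree up to floating-point error, which is why query-key products from the parallel form can be read as per-token weights even when generation runs in recurrent mode.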
100
+
101
+ We compared the test loss of `Spiral-AI/RetNet-3b-ja-base` and `cyberagent/open-calm-3b` across different token lengths.
102
+ The test dataset consists of the first 100 examples extracted from `wikipedia-ja`.
103
+
104
+ ![test_loss](loss_comparison.png)
105
+
106
+ Key findings are:
107
+
108
+ - The test loss of `Spiral-AI/RetNet-3b-ja-base` is as low as that of `cyberagent/open-calm-3b`, showing the effectiveness of the retention mechanism.
109
+ - The sharp increase in test loss is suppressed in `Spiral-AI/RetNet-3b-ja-base` when the context exceeds 2,048 tokens, the maximum context length of the training data. Note that `cyberagent/open-calm-3b` was trained with the same context length.
110
+
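The exact evaluation procedure is not described here. As a rough sketch of one way to compare loss at several token lengths, the snippet below averages per-token negative log-likelihoods over the first `n` tokens of a document. The function name `mean_loss_by_length` and the sample values are hypothetical, and in practice the per-token losses would come from the model's logits.

```python
def mean_loss_by_length(token_nll, lengths):
    """Mean per-token negative log-likelihood over the first n tokens, for each n.

    Lengths longer than the document are skipped rather than padded.
    """
    result = {}
    for n in lengths:
        if n <= 0 or n > len(token_nll):
            continue
        result[n] = sum(token_nll[:n]) / n
    return result


# Hypothetical per-token losses for one short document.
token_nll = [4.0, 3.0, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0]
losses = mean_loss_by_length(token_nll, [2, 4, 8, 16])
print(losses)
```

Averaging such curves over all test documents gives one loss value per token length, which is the kind of quantity plotted above.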
111
+ # Training Datasets
112
+
113
+ - [izumi-lab/cc100-ja-filter-ja-normal](https://huggingface.co/datasets/izumi-lab/cc100-ja-filter-ja-normal) (Japanese)
114
+ - [izumi-lab/wikipedia-ja-20230720](https://huggingface.co/datasets/izumi-lab/wikipedia-ja-20230720) (Japanese)
115
+ - [wikipedia](https://huggingface.co/datasets/wikipedia/tree/main/data/20220301.en) (English)
116
+ - [uonlp/CulturaX](https://huggingface.co/datasets/uonlp/CulturaX) (English, Japanese)
117
+
118
+ # Limitations
119
+
120
+ This model is designed for broad applicability, but it may not fully meet the specific needs or contexts of all uses.
121
+ Pre-training data may contain inappropriate content, which could be reflected in the texts generated by the model. Therefore, when using this model, it is important to carefully review its output and avoid situations where it might cause discomfort or harm to individuals or groups.
122
+
123
+ There are no specific restrictions on commercial use, but users are responsible for addressing any ethical or legal issues that may arise in connection with the use of the model.
loss_comparison.png ADDED
retention.gif ADDED