jdchang commited on
Commit
b3ca691
·
verified ·
1 Parent(s): 5c0204c

Upload model

Browse files
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "AutoModelForCausalLMWithRM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "modeling_hf.RewardModelConfig",
9
+ "AutoModel": "modeling_hf.AutoModelForCausalLMWithRM"
10
+ },
11
+ "base_config": {
12
+ "_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",
13
+ "architectures": [
14
+ "LlamaForCausalLM"
15
+ ],
16
+ "bos_token_id": 128000,
17
+ "eos_token_id": 128009,
18
+ "hidden_size": 4096,
19
+ "intermediate_size": 14336,
20
+ "max_position_embeddings": 8192,
21
+ "model_type": "llama",
22
+ "num_attention_heads": 32,
23
+ "num_hidden_layers": 32,
24
+ "num_key_value_heads": 8,
25
+ "rms_norm_eps": 1e-05,
26
+ "rope_theta": 500000.0,
27
+ "torch_dtype": "bfloat16",
28
+ "vocab_size": 128256
29
+ },
30
+ "base_model": "meta-llama/Meta-Llama-3-8B-Instruct",
31
+ "bias": 0.0,
32
+ "bos_token_id": 128000,
33
+ "eos_token_id": 128009,
34
+ "hidden_act": "silu",
35
+ "hidden_size": 4096,
36
+ "initializer_range": 0.02,
37
+ "intermediate_size": 14336,
38
+ "max_position_embeddings": 8192,
39
+ "mlp_bias": false,
40
+ "model_type": "pairwise_rm",
41
+ "n_labels": 1,
42
+ "num_attention_heads": 32,
43
+ "num_hidden_layers": 32,
44
+ "num_key_value_heads": 8,
45
+ "p_dropout": 0.0,
46
+ "pretrain_cfg": {},
47
+ "pretrained": false,
48
+ "pretraining_tp": 1,
49
+ "return_logits": false,
50
+ "rms_norm_eps": 1e-05,
51
+ "rope_scaling": null,
52
+ "rope_theta": 500000.0,
53
+ "tie_word_embeddings": false,
54
+ "torch_dtype": "float32",
55
+ "transformers_version": "4.43.3",
56
+ "use_cache": true,
57
+ "vocab_size": 128256
58
+ }
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 128000,
4
+ "eos_token_id": 128009,
5
+ "transformers_version": "4.43.3"
6
+ }
model-00001-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44c28b7123dbfdc86fe5731cecdb2720c8405aaa080c265abf1ee8ad9da3eace
3
+ size 4886466552
model-00002-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9173c728a034d0f60cac1f2fd972c931454b089d6fcfbfba28e0c9926f028a48
3
+ size 4832008016
model-00003-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f416de79a97f722588268f9b295581dfa88c9d48c793f357bd816e5f27dbf1c8
3
+ size 4999813744
model-00004-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:79495bbe326d3bf5bce6cd67e37be87f970375a95f17299f87a421d6f6caa993
3
+ size 4999813760
model-00005-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e78edf98cffc3f5f591f192f2bc2143be27a7a9fbbb14d0d005e7f6d76fdba6b
3
+ size 4832008064
model-00006-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2546a266419f7c08a19af5863bd95ebde878bce61b7adf2bf4a21ec71791963d
3
+ size 4999813760
model-00007-of-00007.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70e2fab457fce951308aceda3dc5eeb1eb35dc11233cea432ff9cf9f50013615
3
+ size 2638300276
model.safetensors.index.json ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 32188186628
4
+ },
5
+ "weight_map": {
6
+ "lm_backbone.lm_head.weight": "model-00007-of-00007.safetensors",
7
+ "lm_backbone.model.embed_tokens.weight": "model-00001-of-00007.safetensors",
8
+ "lm_backbone.model.layers.0.input_layernorm.weight": "model-00001-of-00007.safetensors",
9
+ "lm_backbone.model.layers.0.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
10
+ "lm_backbone.model.layers.0.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
11
+ "lm_backbone.model.layers.0.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
12
+ "lm_backbone.model.layers.0.post_attention_layernorm.weight": "model-00001-of-00007.safetensors",
13
+ "lm_backbone.model.layers.0.self_attn.k_proj.weight": "model-00001-of-00007.safetensors",
14
+ "lm_backbone.model.layers.0.self_attn.o_proj.weight": "model-00001-of-00007.safetensors",
15
+ "lm_backbone.model.layers.0.self_attn.q_proj.weight": "model-00001-of-00007.safetensors",
16
+ "lm_backbone.model.layers.0.self_attn.v_proj.weight": "model-00001-of-00007.safetensors",
17
+ "lm_backbone.model.layers.1.input_layernorm.weight": "model-00001-of-00007.safetensors",
18
+ "lm_backbone.model.layers.1.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
19
+ "lm_backbone.model.layers.1.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
20
+ "lm_backbone.model.layers.1.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
21
+ "lm_backbone.model.layers.1.post_attention_layernorm.weight": "model-00001-of-00007.safetensors",
22
+ "lm_backbone.model.layers.1.self_attn.k_proj.weight": "model-00001-of-00007.safetensors",
23
+ "lm_backbone.model.layers.1.self_attn.o_proj.weight": "model-00001-of-00007.safetensors",
24
+ "lm_backbone.model.layers.1.self_attn.q_proj.weight": "model-00001-of-00007.safetensors",
25
+ "lm_backbone.model.layers.1.self_attn.v_proj.weight": "model-00001-of-00007.safetensors",
26
+ "lm_backbone.model.layers.10.input_layernorm.weight": "model-00003-of-00007.safetensors",
27
+ "lm_backbone.model.layers.10.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
28
+ "lm_backbone.model.layers.10.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
29
+ "lm_backbone.model.layers.10.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
30
+ "lm_backbone.model.layers.10.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
31
+ "lm_backbone.model.layers.10.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
32
+ "lm_backbone.model.layers.10.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
33
+ "lm_backbone.model.layers.10.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
34
+ "lm_backbone.model.layers.10.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
35
+ "lm_backbone.model.layers.11.input_layernorm.weight": "model-00003-of-00007.safetensors",
36
+ "lm_backbone.model.layers.11.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
37
+ "lm_backbone.model.layers.11.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
38
+ "lm_backbone.model.layers.11.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
39
+ "lm_backbone.model.layers.11.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
40
+ "lm_backbone.model.layers.11.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
41
+ "lm_backbone.model.layers.11.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
42
+ "lm_backbone.model.layers.11.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
43
+ "lm_backbone.model.layers.11.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
44
+ "lm_backbone.model.layers.12.input_layernorm.weight": "model-00003-of-00007.safetensors",
45
+ "lm_backbone.model.layers.12.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
46
+ "lm_backbone.model.layers.12.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
47
+ "lm_backbone.model.layers.12.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
48
+ "lm_backbone.model.layers.12.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
49
+ "lm_backbone.model.layers.12.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
50
+ "lm_backbone.model.layers.12.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
51
+ "lm_backbone.model.layers.12.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
52
+ "lm_backbone.model.layers.12.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
53
+ "lm_backbone.model.layers.13.input_layernorm.weight": "model-00003-of-00007.safetensors",
54
+ "lm_backbone.model.layers.13.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
55
+ "lm_backbone.model.layers.13.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
56
+ "lm_backbone.model.layers.13.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
57
+ "lm_backbone.model.layers.13.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
58
+ "lm_backbone.model.layers.13.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
59
+ "lm_backbone.model.layers.13.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
60
+ "lm_backbone.model.layers.13.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
61
+ "lm_backbone.model.layers.13.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
62
+ "lm_backbone.model.layers.14.input_layernorm.weight": "model-00004-of-00007.safetensors",
63
+ "lm_backbone.model.layers.14.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
64
+ "lm_backbone.model.layers.14.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
65
+ "lm_backbone.model.layers.14.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
66
+ "lm_backbone.model.layers.14.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
67
+ "lm_backbone.model.layers.14.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
68
+ "lm_backbone.model.layers.14.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
69
+ "lm_backbone.model.layers.14.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
70
+ "lm_backbone.model.layers.14.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
71
+ "lm_backbone.model.layers.15.input_layernorm.weight": "model-00004-of-00007.safetensors",
72
+ "lm_backbone.model.layers.15.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
73
+ "lm_backbone.model.layers.15.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
74
+ "lm_backbone.model.layers.15.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
75
+ "lm_backbone.model.layers.15.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
76
+ "lm_backbone.model.layers.15.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
77
+ "lm_backbone.model.layers.15.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
78
+ "lm_backbone.model.layers.15.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
79
+ "lm_backbone.model.layers.15.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
80
+ "lm_backbone.model.layers.16.input_layernorm.weight": "model-00004-of-00007.safetensors",
81
+ "lm_backbone.model.layers.16.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
82
+ "lm_backbone.model.layers.16.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
83
+ "lm_backbone.model.layers.16.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
84
+ "lm_backbone.model.layers.16.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
85
+ "lm_backbone.model.layers.16.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
86
+ "lm_backbone.model.layers.16.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
87
+ "lm_backbone.model.layers.16.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
88
+ "lm_backbone.model.layers.16.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
89
+ "lm_backbone.model.layers.17.input_layernorm.weight": "model-00004-of-00007.safetensors",
90
+ "lm_backbone.model.layers.17.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
91
+ "lm_backbone.model.layers.17.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
92
+ "lm_backbone.model.layers.17.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
93
+ "lm_backbone.model.layers.17.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
94
+ "lm_backbone.model.layers.17.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
95
+ "lm_backbone.model.layers.17.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
96
+ "lm_backbone.model.layers.17.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
97
+ "lm_backbone.model.layers.17.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
98
+ "lm_backbone.model.layers.18.input_layernorm.weight": "model-00004-of-00007.safetensors",
99
+ "lm_backbone.model.layers.18.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
100
+ "lm_backbone.model.layers.18.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
101
+ "lm_backbone.model.layers.18.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
102
+ "lm_backbone.model.layers.18.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
103
+ "lm_backbone.model.layers.18.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
104
+ "lm_backbone.model.layers.18.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
105
+ "lm_backbone.model.layers.18.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
106
+ "lm_backbone.model.layers.18.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
107
+ "lm_backbone.model.layers.19.input_layernorm.weight": "model-00004-of-00007.safetensors",
108
+ "lm_backbone.model.layers.19.mlp.down_proj.weight": "model-00004-of-00007.safetensors",
109
+ "lm_backbone.model.layers.19.mlp.gate_proj.weight": "model-00004-of-00007.safetensors",
110
+ "lm_backbone.model.layers.19.mlp.up_proj.weight": "model-00004-of-00007.safetensors",
111
+ "lm_backbone.model.layers.19.post_attention_layernorm.weight": "model-00004-of-00007.safetensors",
112
+ "lm_backbone.model.layers.19.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
113
+ "lm_backbone.model.layers.19.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
114
+ "lm_backbone.model.layers.19.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
115
+ "lm_backbone.model.layers.19.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
116
+ "lm_backbone.model.layers.2.input_layernorm.weight": "model-00001-of-00007.safetensors",
117
+ "lm_backbone.model.layers.2.mlp.down_proj.weight": "model-00001-of-00007.safetensors",
118
+ "lm_backbone.model.layers.2.mlp.gate_proj.weight": "model-00001-of-00007.safetensors",
119
+ "lm_backbone.model.layers.2.mlp.up_proj.weight": "model-00001-of-00007.safetensors",
120
+ "lm_backbone.model.layers.2.post_attention_layernorm.weight": "model-00001-of-00007.safetensors",
121
+ "lm_backbone.model.layers.2.self_attn.k_proj.weight": "model-00001-of-00007.safetensors",
122
+ "lm_backbone.model.layers.2.self_attn.o_proj.weight": "model-00001-of-00007.safetensors",
123
+ "lm_backbone.model.layers.2.self_attn.q_proj.weight": "model-00001-of-00007.safetensors",
124
+ "lm_backbone.model.layers.2.self_attn.v_proj.weight": "model-00001-of-00007.safetensors",
125
+ "lm_backbone.model.layers.20.input_layernorm.weight": "model-00005-of-00007.safetensors",
126
+ "lm_backbone.model.layers.20.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
127
+ "lm_backbone.model.layers.20.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
128
+ "lm_backbone.model.layers.20.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
129
+ "lm_backbone.model.layers.20.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
130
+ "lm_backbone.model.layers.20.self_attn.k_proj.weight": "model-00004-of-00007.safetensors",
131
+ "lm_backbone.model.layers.20.self_attn.o_proj.weight": "model-00004-of-00007.safetensors",
132
+ "lm_backbone.model.layers.20.self_attn.q_proj.weight": "model-00004-of-00007.safetensors",
133
+ "lm_backbone.model.layers.20.self_attn.v_proj.weight": "model-00004-of-00007.safetensors",
134
+ "lm_backbone.model.layers.21.input_layernorm.weight": "model-00005-of-00007.safetensors",
135
+ "lm_backbone.model.layers.21.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
136
+ "lm_backbone.model.layers.21.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
137
+ "lm_backbone.model.layers.21.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
138
+ "lm_backbone.model.layers.21.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
139
+ "lm_backbone.model.layers.21.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
140
+ "lm_backbone.model.layers.21.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
141
+ "lm_backbone.model.layers.21.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
142
+ "lm_backbone.model.layers.21.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
143
+ "lm_backbone.model.layers.22.input_layernorm.weight": "model-00005-of-00007.safetensors",
144
+ "lm_backbone.model.layers.22.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
145
+ "lm_backbone.model.layers.22.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
146
+ "lm_backbone.model.layers.22.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
147
+ "lm_backbone.model.layers.22.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
148
+ "lm_backbone.model.layers.22.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
149
+ "lm_backbone.model.layers.22.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
150
+ "lm_backbone.model.layers.22.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
151
+ "lm_backbone.model.layers.22.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
152
+ "lm_backbone.model.layers.23.input_layernorm.weight": "model-00005-of-00007.safetensors",
153
+ "lm_backbone.model.layers.23.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
154
+ "lm_backbone.model.layers.23.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
155
+ "lm_backbone.model.layers.23.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
156
+ "lm_backbone.model.layers.23.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
157
+ "lm_backbone.model.layers.23.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
158
+ "lm_backbone.model.layers.23.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
159
+ "lm_backbone.model.layers.23.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
160
+ "lm_backbone.model.layers.23.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
161
+ "lm_backbone.model.layers.24.input_layernorm.weight": "model-00005-of-00007.safetensors",
162
+ "lm_backbone.model.layers.24.mlp.down_proj.weight": "model-00005-of-00007.safetensors",
163
+ "lm_backbone.model.layers.24.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
164
+ "lm_backbone.model.layers.24.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
165
+ "lm_backbone.model.layers.24.post_attention_layernorm.weight": "model-00005-of-00007.safetensors",
166
+ "lm_backbone.model.layers.24.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
167
+ "lm_backbone.model.layers.24.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
168
+ "lm_backbone.model.layers.24.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
169
+ "lm_backbone.model.layers.24.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
170
+ "lm_backbone.model.layers.25.input_layernorm.weight": "model-00006-of-00007.safetensors",
171
+ "lm_backbone.model.layers.25.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
172
+ "lm_backbone.model.layers.25.mlp.gate_proj.weight": "model-00005-of-00007.safetensors",
173
+ "lm_backbone.model.layers.25.mlp.up_proj.weight": "model-00005-of-00007.safetensors",
174
+ "lm_backbone.model.layers.25.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
175
+ "lm_backbone.model.layers.25.self_attn.k_proj.weight": "model-00005-of-00007.safetensors",
176
+ "lm_backbone.model.layers.25.self_attn.o_proj.weight": "model-00005-of-00007.safetensors",
177
+ "lm_backbone.model.layers.25.self_attn.q_proj.weight": "model-00005-of-00007.safetensors",
178
+ "lm_backbone.model.layers.25.self_attn.v_proj.weight": "model-00005-of-00007.safetensors",
179
+ "lm_backbone.model.layers.26.input_layernorm.weight": "model-00006-of-00007.safetensors",
180
+ "lm_backbone.model.layers.26.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
181
+ "lm_backbone.model.layers.26.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
182
+ "lm_backbone.model.layers.26.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
183
+ "lm_backbone.model.layers.26.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
184
+ "lm_backbone.model.layers.26.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
185
+ "lm_backbone.model.layers.26.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
186
+ "lm_backbone.model.layers.26.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
187
+ "lm_backbone.model.layers.26.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
188
+ "lm_backbone.model.layers.27.input_layernorm.weight": "model-00006-of-00007.safetensors",
189
+ "lm_backbone.model.layers.27.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
190
+ "lm_backbone.model.layers.27.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
191
+ "lm_backbone.model.layers.27.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
192
+ "lm_backbone.model.layers.27.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
193
+ "lm_backbone.model.layers.27.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
194
+ "lm_backbone.model.layers.27.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
195
+ "lm_backbone.model.layers.27.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
196
+ "lm_backbone.model.layers.27.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
197
+ "lm_backbone.model.layers.28.input_layernorm.weight": "model-00006-of-00007.safetensors",
198
+ "lm_backbone.model.layers.28.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
199
+ "lm_backbone.model.layers.28.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
200
+ "lm_backbone.model.layers.28.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
201
+ "lm_backbone.model.layers.28.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
202
+ "lm_backbone.model.layers.28.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
203
+ "lm_backbone.model.layers.28.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
204
+ "lm_backbone.model.layers.28.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
205
+ "lm_backbone.model.layers.28.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
206
+ "lm_backbone.model.layers.29.input_layernorm.weight": "model-00006-of-00007.safetensors",
207
+ "lm_backbone.model.layers.29.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
208
+ "lm_backbone.model.layers.29.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
209
+ "lm_backbone.model.layers.29.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
210
+ "lm_backbone.model.layers.29.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
211
+ "lm_backbone.model.layers.29.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
212
+ "lm_backbone.model.layers.29.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
213
+ "lm_backbone.model.layers.29.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
214
+ "lm_backbone.model.layers.29.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
215
+ "lm_backbone.model.layers.3.input_layernorm.weight": "model-00002-of-00007.safetensors",
216
+ "lm_backbone.model.layers.3.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
217
+ "lm_backbone.model.layers.3.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
218
+ "lm_backbone.model.layers.3.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
219
+ "lm_backbone.model.layers.3.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
220
+ "lm_backbone.model.layers.3.self_attn.k_proj.weight": "model-00001-of-00007.safetensors",
221
+ "lm_backbone.model.layers.3.self_attn.o_proj.weight": "model-00001-of-00007.safetensors",
222
+ "lm_backbone.model.layers.3.self_attn.q_proj.weight": "model-00001-of-00007.safetensors",
223
+ "lm_backbone.model.layers.3.self_attn.v_proj.weight": "model-00001-of-00007.safetensors",
224
+ "lm_backbone.model.layers.30.input_layernorm.weight": "model-00006-of-00007.safetensors",
225
+ "lm_backbone.model.layers.30.mlp.down_proj.weight": "model-00006-of-00007.safetensors",
226
+ "lm_backbone.model.layers.30.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
227
+ "lm_backbone.model.layers.30.mlp.up_proj.weight": "model-00006-of-00007.safetensors",
228
+ "lm_backbone.model.layers.30.post_attention_layernorm.weight": "model-00006-of-00007.safetensors",
229
+ "lm_backbone.model.layers.30.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
230
+ "lm_backbone.model.layers.30.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
231
+ "lm_backbone.model.layers.30.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
232
+ "lm_backbone.model.layers.30.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
233
+ "lm_backbone.model.layers.31.input_layernorm.weight": "model-00007-of-00007.safetensors",
234
+ "lm_backbone.model.layers.31.mlp.down_proj.weight": "model-00007-of-00007.safetensors",
235
+ "lm_backbone.model.layers.31.mlp.gate_proj.weight": "model-00006-of-00007.safetensors",
236
+ "lm_backbone.model.layers.31.mlp.up_proj.weight": "model-00007-of-00007.safetensors",
237
+ "lm_backbone.model.layers.31.post_attention_layernorm.weight": "model-00007-of-00007.safetensors",
238
+ "lm_backbone.model.layers.31.self_attn.k_proj.weight": "model-00006-of-00007.safetensors",
239
+ "lm_backbone.model.layers.31.self_attn.o_proj.weight": "model-00006-of-00007.safetensors",
240
+ "lm_backbone.model.layers.31.self_attn.q_proj.weight": "model-00006-of-00007.safetensors",
241
+ "lm_backbone.model.layers.31.self_attn.v_proj.weight": "model-00006-of-00007.safetensors",
242
+ "lm_backbone.model.layers.4.input_layernorm.weight": "model-00002-of-00007.safetensors",
243
+ "lm_backbone.model.layers.4.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
244
+ "lm_backbone.model.layers.4.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
245
+ "lm_backbone.model.layers.4.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
246
+ "lm_backbone.model.layers.4.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
247
+ "lm_backbone.model.layers.4.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
248
+ "lm_backbone.model.layers.4.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
249
+ "lm_backbone.model.layers.4.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
250
+ "lm_backbone.model.layers.4.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
251
+ "lm_backbone.model.layers.5.input_layernorm.weight": "model-00002-of-00007.safetensors",
252
+ "lm_backbone.model.layers.5.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
253
+ "lm_backbone.model.layers.5.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
254
+ "lm_backbone.model.layers.5.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
255
+ "lm_backbone.model.layers.5.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
256
+ "lm_backbone.model.layers.5.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
257
+ "lm_backbone.model.layers.5.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
258
+ "lm_backbone.model.layers.5.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
259
+ "lm_backbone.model.layers.5.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
260
+ "lm_backbone.model.layers.6.input_layernorm.weight": "model-00002-of-00007.safetensors",
261
+ "lm_backbone.model.layers.6.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
262
+ "lm_backbone.model.layers.6.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
263
+ "lm_backbone.model.layers.6.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
264
+ "lm_backbone.model.layers.6.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
265
+ "lm_backbone.model.layers.6.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
266
+ "lm_backbone.model.layers.6.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
267
+ "lm_backbone.model.layers.6.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
268
+ "lm_backbone.model.layers.6.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
269
+ "lm_backbone.model.layers.7.input_layernorm.weight": "model-00002-of-00007.safetensors",
270
+ "lm_backbone.model.layers.7.mlp.down_proj.weight": "model-00002-of-00007.safetensors",
271
+ "lm_backbone.model.layers.7.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
272
+ "lm_backbone.model.layers.7.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
273
+ "lm_backbone.model.layers.7.post_attention_layernorm.weight": "model-00002-of-00007.safetensors",
274
+ "lm_backbone.model.layers.7.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
275
+ "lm_backbone.model.layers.7.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
276
+ "lm_backbone.model.layers.7.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
277
+ "lm_backbone.model.layers.7.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
278
+ "lm_backbone.model.layers.8.input_layernorm.weight": "model-00003-of-00007.safetensors",
279
+ "lm_backbone.model.layers.8.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
280
+ "lm_backbone.model.layers.8.mlp.gate_proj.weight": "model-00002-of-00007.safetensors",
281
+ "lm_backbone.model.layers.8.mlp.up_proj.weight": "model-00002-of-00007.safetensors",
282
+ "lm_backbone.model.layers.8.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
283
+ "lm_backbone.model.layers.8.self_attn.k_proj.weight": "model-00002-of-00007.safetensors",
284
+ "lm_backbone.model.layers.8.self_attn.o_proj.weight": "model-00002-of-00007.safetensors",
285
+ "lm_backbone.model.layers.8.self_attn.q_proj.weight": "model-00002-of-00007.safetensors",
286
+ "lm_backbone.model.layers.8.self_attn.v_proj.weight": "model-00002-of-00007.safetensors",
287
+ "lm_backbone.model.layers.9.input_layernorm.weight": "model-00003-of-00007.safetensors",
288
+ "lm_backbone.model.layers.9.mlp.down_proj.weight": "model-00003-of-00007.safetensors",
289
+ "lm_backbone.model.layers.9.mlp.gate_proj.weight": "model-00003-of-00007.safetensors",
290
+ "lm_backbone.model.layers.9.mlp.up_proj.weight": "model-00003-of-00007.safetensors",
291
+ "lm_backbone.model.layers.9.post_attention_layernorm.weight": "model-00003-of-00007.safetensors",
292
+ "lm_backbone.model.layers.9.self_attn.k_proj.weight": "model-00003-of-00007.safetensors",
293
+ "lm_backbone.model.layers.9.self_attn.o_proj.weight": "model-00003-of-00007.safetensors",
294
+ "lm_backbone.model.layers.9.self_attn.q_proj.weight": "model-00003-of-00007.safetensors",
295
+ "lm_backbone.model.layers.9.self_attn.v_proj.weight": "model-00003-of-00007.safetensors",
296
+ "lm_backbone.model.norm.weight": "model-00007-of-00007.safetensors",
297
+ "value_head.dense.bias": "model-00007-of-00007.safetensors",
298
+ "value_head.dense.weight": "model-00007-of-00007.safetensors",
299
+ "value_head.score.bias": "model-00007-of-00007.safetensors",
300
+ "value_head.score.weight": "model-00007-of-00007.safetensors"
301
+ }
302
+ }
modeling_hf.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 MosaicML LLM Foundry authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`."""
5
+
6
+ import os
7
+ from copy import deepcopy
8
+ import warnings
9
+ import numpy as np
10
+ import logging
11
+ from typing import (
12
+ TYPE_CHECKING,
13
+ Any,
14
+ List,
15
+ Mapping,
16
+ Optional,
17
+ Tuple,
18
+ Union,
19
+ Dict,
20
+ )
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ from types import SimpleNamespace
25
+ from composer.models.huggingface import peft_installed
26
+ from composer.utils import dist
27
+
28
+ from torchmetrics import Metric
29
+ from transformers import (
30
+ AutoConfig,
31
+ AutoModelForCausalLM,
32
+ PretrainedConfig,
33
+ PreTrainedModel,
34
+ PreTrainedTokenizerBase,
35
+ PreTrainedTokenizerFast,
36
+ PreTrainedTokenizer,
37
+ )
38
+
39
+ from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
40
+ from llmfoundry.models.layers.attention import is_flash_v2_installed
41
+ from llmfoundry.models.utils import init_empty_weights
42
+ from llmfoundry.utils.config_utils import get_hf_config_value
43
+
44
+ from composer.models.huggingface import HuggingFaceModel
45
+ from compose_rl.reward_learning.utils import prepare_hf_sequence_classification_model_for_fsdp, SequenceClassifierOutput
46
+
47
+ if TYPE_CHECKING:
48
+ from peft import PeftModel
49
+
50
+ __all__ = ['ComposerHFSequenceClassification']
51
+
52
+ log = logging.getLogger(__name__)
53
+
54
+
55
+ Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
56
+
57
+
58
+ def layer_init(layer: nn.Module, std: float=np.sqrt(2), bias_const: float=0.0):
59
+ torch.nn.init.normal_(layer.weight, std=std)
60
+ torch.nn.init.constant_(layer.bias, val=bias_const)
61
+ return layer
62
+
63
+
64
+ class RewardModelConfig(PretrainedConfig):
65
+ model_type = "pairwise_rm"
66
+
67
+ def __init__(
68
+ self,
69
+ base_model: str = "meta-llama/Meta-Llama-3-70B-Instruct",
70
+ base_config: PretrainedConfig = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-70B-Instruct"),
71
+ p_dropout: float = 0.0,
72
+ n_labels: int = 1,
73
+ bias: float = 0.0,
74
+ return_logits: bool = False,
75
+ pretrain_cfg: Dict[str, Any] = {},
76
+ pretrained: bool = False,
77
+ **kwargs: Any,
78
+ ):
79
+ super().__init__(**kwargs)
80
+ self.base_model = base_model
81
+ self.base_config = base_config
82
+ temp_config = deepcopy(base_config)
83
+ if not isinstance(base_config, dict):
84
+ temp_config = base_config.__dict__
85
+ for key, value in temp_config.items():
86
+ if key not in ["_name_or_path", "architectures"]:
87
+ setattr(self, key, value)
88
+ self.p_dropout = p_dropout
89
+ self.n_labels = n_labels
90
+ self.bias = bias
91
+ self.return_logits = return_logits
92
+ self.pretrain_cfg = pretrain_cfg
93
+ self.pretrained = pretrained
94
+
95
+
96
+ class ValueHead(nn.Module):
97
+
98
+ def __init__(self, config: RewardModelConfig):
99
+ super().__init__()
100
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
101
+ self.dropout = nn.Dropout(config.p_dropout)
102
+ self.score = layer_init(
103
+ nn.Linear(config.hidden_size, config.n_labels),
104
+ std=1 / np.sqrt(config.hidden_size + 1),
105
+ )
106
+ self.score = nn.Linear(config.hidden_size, config.n_labels)
107
+
108
+ def forward(self, hidden_states: torch.Tensor, **kwargs: Any):
109
+ hidden_states = self.dropout(hidden_states)
110
+ hidden_states = self.dense(hidden_states)
111
+ hidden_states = torch.tanh(hidden_states)
112
+ hidden_states = self.dropout(hidden_states)
113
+ output = self.score(hidden_states)
114
+ return output
115
+
116
+
117
+ class AutoModelForCausalLMWithRM(PreTrainedModel):
118
+ config_class = RewardModelConfig
119
+
120
+ def __init__(self, config: RewardModelConfig):
121
+ super().__init__(config)
122
+ self.config = config
123
+ pretrain_cfg = config.pretrain_cfg
124
+ pretrained = config.pretrained
125
+ if pretrained:
126
+ self.lm_backbone = AutoModelForCausalLM.from_pretrained(
127
+ config.base_model,
128
+ config=config.base_config,
129
+ **pretrain_cfg,
130
+ )
131
+ else:
132
+ #hack for now
133
+ if isinstance(config.base_config, dict):
134
+ config.base_config = AutoConfig.from_pretrained(config.base_model, **config.base_config)
135
+ self.lm_backbone = AutoModelForCausalLM.from_config(
136
+ config.base_config,
137
+ trust_remote_code=True,
138
+ )
139
+ self.value_head = ValueHead(config)
140
+
141
+ def generate(self, *args: Any, **kwargs: Any):
142
+ return self.lm_backbone.generate(**kwargs)
143
+
144
+ def resize_token_embeddings(
145
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
146
+ ) -> nn.Embedding:
147
+ # Note need to update vocab size in base config as well so lm_head modification happens
148
+ self.config.base_config.vocab_size = new_num_tokens
149
+ model_embeds = super().resize_token_embeddings(new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of)
150
+ return model_embeds
151
+
152
+ def set_input_embeddings(self, new_embeddings):
153
+ return self.lm_backbone.set_input_embeddings(new_embeddings)
154
+
155
+ def get_input_embeddings(self):
156
+ return self.lm_backbone.get_input_embeddings()
157
+
158
+ def set_output_embeddings(self, new_embeddings):
159
+ return self.lm_backbone.set_output_embeddings(new_embeddings)
160
+
161
+ def get_output_embeddings(self):
162
+ return self.lm_backbone.get_output_embeddings()
163
+
164
+ def forward(
165
+ self,
166
+ input_ids: torch.LongTensor = None,
167
+ attention_mask: Optional[torch.Tensor] = None,
168
+ position_ids: Optional[torch.LongTensor] = None,
169
+ past_key_values: Optional[Any] = None,
170
+ inputs_embeds: Optional[torch.FloatTensor] = None,
171
+ labels: Optional[torch.LongTensor] = None,
172
+ use_cache: Optional[bool] = None,
173
+ output_attentions: Optional[bool] = None,
174
+ output_hidden_states: Optional[bool] = None,
175
+ return_dict: Optional[bool] = None,
176
+ cache_position: Optional[torch.LongTensor] = None,
177
+ **kwargs: Any,
178
+ ):
179
+ output = self.lm_backbone(
180
+ input_ids=input_ids,
181
+ attention_mask=attention_mask,
182
+ position_ids=position_ids,
183
+ past_key_values=past_key_values,
184
+ inputs_embeds=inputs_embeds,
185
+ labels=labels,
186
+ use_cache=use_cache,
187
+ output_attentions=output_attentions,
188
+ output_hidden_states=True,
189
+ return_dict=True,
190
+ cache_position=cache_position,
191
+ )
192
+ scores = self.value_head(output.hidden_states[-1]).squeeze(-1) - self.config.bias
193
+
194
+ logits = None
195
+ if self.config.return_logits:
196
+ logits = output.logits
197
+
198
+ return SequenceClassifierOutput(
199
+ loss=output.loss,
200
+ scores=scores,
201
+ logits=logits,
202
+ past_key_values=output.past_key_values,
203
+ hidden_states=output.hidden_states,
204
+ attentions=output.attentions,
205
+ )
206
+
207
+
208
+ class ComposerHFSequenceClassification(HuggingFaceModel):
209
+
210
+ """Configures a :class:`.HuggingFaceModel` around a Causal LM.
211
+
212
+ Args:
213
+ pretrained_model_name_or_path (str): The name of or local path to
214
+ the HF Causal LM (e.g., `gpt2` to instantiate a GPT2LMHeadModel).
215
+ config_overrides (dict, optional): An optional dictionary of keyword
216
+ arguments that override the default configuration associated with
217
+ cfg.pretrained_model_name_or_path.
218
+ pretrained (bool): Whether to instantiate the model with pre-trained
219
+ weights coming from cfg.pretrained_model_name_or_path. If ``True``,
220
+ cfg.config_overrides must be compatible with the pre-trained weights.
221
+ init_device ('cpu' | 'meta'): Which device, 'cpu' or 'meta', to
222
+ initialize the model on. Currently, `meta` is only supported when
223
+ cfg.pretrained is ``False``. Default: ``'cpu'``.
224
+ peft_config (dict, optional): An optional dictionary of keyword arguments to be
225
+ passed to the PeftConfig constructor. If provided, the model will be wrapped in a PeftModel.
226
+ trust_remote_code (bool, optional): Whether to trust remote code when loading from Hugging Face
227
+ Hub. Default: ``True``.
228
+ use_auth_token (bool, optional): Whether to use the Hugging Face authentication token when
229
+ loading from Hugging Face Hub. Default: ``False``.
230
+ use_train_metrics (bool, optional): Whether to use training metrics. Default: ``True``.
231
+ load_in_8bit (bool, optional): Whether to load the model in 8-bit mode. Default: ``False``.
232
+ init_device (str, optional): Which device to initialize the model on. Default: ``'cpu'``.
233
+ use_flash_attention_2 (bool, optional): Whether to use flash-attention 2. Default: ``False``.
234
+ tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
235
+ """
236
+
237
+ def __init__(
238
+ self,
239
+ tokenizer: PreTrainedTokenizerBase,
240
+ pretrained_model_name_or_path: str,
241
+ pretrained: bool = True,
242
+ pretrained_lora_id_or_path: Optional[str] = None,
243
+ trust_remote_code: bool = True,
244
+ use_auth_token: bool = False,
245
+ use_flash_attention_2: bool = False,
246
+ load_in_8bit: bool = False,
247
+ init_device: str = 'cpu',
248
+ config_overrides: Optional[Dict[str, Any]] = None,
249
+ peft_config: Optional[Dict[str, Any]] = None,
250
+ use_train_metrics: bool = True,
251
+ additional_train_metrics: Optional[List] = None,
252
+ additional_eval_metrics: Optional[List] = None,
253
+ return_lm_logits: Optional[bool] = False,
254
+ ):
255
+
256
+ config_overrides = config_overrides or {}
257
+
258
+ model = ComposerHFSequenceClassification.build_inner_model(
259
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
260
+ pretrained_lora_id_or_path=pretrained_lora_id_or_path,
261
+ trust_remote_code=trust_remote_code,
262
+ init_device=init_device,
263
+ use_flash_attention_2=use_flash_attention_2,
264
+ use_auth_token=use_auth_token,
265
+ config_overrides=config_overrides,
266
+ load_in_8bit=load_in_8bit,
267
+ pretrained=pretrained,
268
+ prepare_for_fsdp=True,
269
+ return_lm_logits=return_lm_logits,
270
+ )
271
+
272
+ train_metrics, eval_metrics = ComposerHFSequenceClassification.build_metrics(
273
+ use_train_metrics=use_train_metrics,
274
+ additional_train_metrics=additional_train_metrics,
275
+ additional_eval_metrics=additional_eval_metrics,
276
+ )
277
+
278
+ if peft_config is not None and not peft_installed:
279
+ raise NotImplementedError("PEFT is not supported")
280
+
281
+ peft_config_object = None
282
+ if peft_config is not None:
283
+ peft_config_object = self._get_peft_config(peft_config)
284
+
285
+ # Set up config args for the model construction and base classes
286
+ super().__init__(
287
+ model=model,
288
+ shift_labels=True,
289
+ tokenizer=tokenizer,
290
+ metrics=train_metrics,
291
+ eval_metrics=eval_metrics,
292
+ peft_config=peft_config_object,
293
+ allow_embedding_resizing=True,
294
+ )
295
+ #self.model.config.vocab_size = len(self.tokenizer)
296
+ #self.model.config.base_config.vocab_size = len(self.tokenizer)
297
+ self.model.config.pretrained = False
298
+
299
+ @staticmethod
300
+ def build_metrics(
301
+ use_train_metrics: bool,
302
+ additional_train_metrics: Optional[List[str]] = None,
303
+ additional_eval_metrics: Optional[List[str]] = None,
304
+ ) -> Tuple[List[Metric], List[Metric]]:
305
+ """Builds the training and evaluation metrics for the model.
306
+
307
+ Args:
308
+ use_train_metrics (bool): Whether to use training metrics.
309
+ additional_train_metrics (Optional[List[str]]): Additional training metrics to include.
310
+ additional_eval_metrics (Optional[List[str]]): Additional evaluation metrics to include.
311
+
312
+ Returns:
313
+ Tuple[List[Metric], List[Metric]]: A tuple containing the list of training metrics and evaluation metrics.
314
+ """
315
+ from llmfoundry.utils.builders import build_metric
316
+ train_metric_names = additional_train_metrics if additional_train_metrics is not None else []
317
+ eval_metric_names = additional_eval_metrics if additional_eval_metrics is not None else []
318
+ train_metrics = [
319
+ build_metric(metric, {}) for metric in train_metric_names
320
+ ] if use_train_metrics else []
321
+ eval_metrics = [
322
+ build_metric(metric, {}) for metric in eval_metric_names
323
+ ]
324
+ return train_metrics, eval_metrics
325
+
326
+ @staticmethod
327
+ def build_inner_model(
328
+ pretrained_model_name_or_path: str,
329
+ pretrained_lora_id_or_path: Optional[str],
330
+ trust_remote_code: bool,
331
+ init_device: str,
332
+ use_flash_attention_2: bool,
333
+ use_auth_token: bool,
334
+ config_overrides: Dict[str, Any],
335
+ load_in_8bit: bool,
336
+ pretrained: bool,
337
+ prepare_for_fsdp: bool = False,
338
+ return_lm_logits: bool = False,
339
+ ) -> Union[PreTrainedModel, 'PeftModel']:
340
+ """Builds the inner model for the ComposerHFCausalLM.
341
+
342
+ Args:
343
+ pretrained_model_name_or_path (str): The pretrained model name or path.
344
+ pretrained_lora_id_or_path (Optional[str]): The pretrained LORA ID or path.
345
+ trust_remote_code (bool): Whether to trust remote code.
346
+ init_device (str): The initialization device.
347
+ use_flash_attention_2 (bool): Whether to use flash attention 2.
348
+ use_auth_token (bool): Whether to use an authentication token.
349
+ config_overrides (Dict[str, Any]): The configuration overrides.
350
+ load_in_8bit (bool): Whether to load in 8-bit.
351
+ prepare_for_fsdp (bool, optional): Whether to prepare the model for FSDP wrapping. Default: False.
352
+
353
+ Returns:
354
+ Union[PreTrainedModel, 'PeftModel']: The built inner model.
355
+ prepare_for_fsdp (bool): Whether to prepare the model for FSDP wrapping. Default: ``False``.
356
+ """
357
+ if not trust_remote_code and pretrained_model_name_or_path.startswith(
358
+ 'mosaicml/mpt',
359
+ ):
360
+ raise ValueError(
361
+ 'trust_remote_code must be set to True for MPT models. Without this, the MPT model code will come from the transformers library, '
362
+ +
363
+ 'which is significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.',
364
+ )
365
+ # Resolve "mixed" init device to either "cpu" or "meta"
366
+ resolved_init_device = hf_get_init_device(init_device)
367
+ requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'
368
+
369
+ if use_flash_attention_2 and not is_flash_v2_installed():
370
+ raise ValueError(
371
+ 'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
372
+ + 'Please `pip install llm-foundry[gpu]`.',
373
+ )
374
+
375
+ # Construct the Hugging Face config to use
376
+ base_config = AutoConfig.from_pretrained(
377
+ pretrained_model_name_or_path,
378
+ trust_remote_code=trust_remote_code,
379
+ token=True,
380
+ attn_implementation=requested_attention_implementation,
381
+ use_cache=False, # Necessary due to https://github.com/huggingface/transformers/issues/28056
382
+ #num_hidden_layers=2, hidden_dim=128, # For Testing
383
+ )
384
+
385
+ config = RewardModelConfig(
386
+ base_model=pretrained_model_name_or_path,
387
+ base_config=base_config,
388
+ hidden_size=base_config.hidden_size,
389
+ torch_dtype=base_config.torch_dtype,
390
+ return_logits=return_lm_logits,
391
+ vocab_size=base_config.vocab_size,
392
+ )
393
+
394
+
395
+ # This is not ideal, however Hugging Face's _autoset_attn_implementation function
396
+ # forces you to load the model in fp16/bf16 if you want to use flash attention. Rather than loading
397
+ # the model and then casting it back to fp32, we are monkeypatching their check.
398
+ # https://github.com/huggingface/transformers/issues/28052
399
+ def _autoset_attn_implementation_monkeypatch(
400
+ cls, # type: ignore
401
+ config, # type: ignore
402
+ *args, # type: ignore
403
+ **kwargs, # type: ignore
404
+ ): # type: ignore
405
+ config._attn_implementation = requested_attention_implementation
406
+ return config
407
+
408
+ PreTrainedModel._autoset_attn_implementation = classmethod(
409
+ _autoset_attn_implementation_monkeypatch,
410
+ )
411
+
412
+ # set config overrides
413
+ for k, v in config_overrides.items():
414
+ if not hasattr(config, k):
415
+ raise ValueError(
416
+ f'config does not have attribute "{k}" to override ({k}: {v}).',
417
+ )
418
+
419
+ attr = getattr(config, k)
420
+ # attempt to disallow typos in nested configs
421
+ if isinstance(attr, Mapping):
422
+ extra_keys = [_k for _k in v.keys() if _k not in attr.keys()]
423
+ if extra_keys:
424
+ raise ValueError(
425
+ f'Config dict override got unknown keys. ' +
426
+ f'Extra keys: {extra_keys}. ' +
427
+ f'Expected (a subset of) keys: {list(attr.keys())}.',
428
+ )
429
+ getattr(config, k).update(v)
430
+ # necessary case to allow for rope_scaling to be overriden in llama config
431
+ elif attr is None and isinstance(v, Mapping):
432
+ setattr(config, k, {})
433
+ getattr(config, k).update(v)
434
+ elif isinstance(attr, PretrainedConfig):
435
+ if not isinstance(v, Mapping):
436
+ raise ValueError(
437
+ f'Expected a dictionary for config override {k}, but got {v}.',
438
+ )
439
+
440
+ for _k, _v in v.items():
441
+ if not hasattr(attr, _k):
442
+ raise ValueError(
443
+ f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).',
444
+ )
445
+ setattr(attr, _k, _v)
446
+ else:
447
+ setattr(config, k, v)
448
+
449
+ if hasattr(config, 'attn_config') and get_hf_config_value(
450
+ config.attn_config,
451
+ 'seq_parallel_world_size',
452
+ ) is not None:
453
+ raise NotImplementedError(
454
+ 'Sequence Parallelism is not supported for HuggingFace models.',
455
+ )
456
+
457
+ # We need to have all non-zero local ranks be not-pretrained
458
+ # Rank 0 will still be pretrained, and distribute the weights appropriately
459
+ if dist.get_local_rank() != 0 and init_device == 'mixed':
460
+ pretrained = False
461
+
462
+ # Hugging Face copies the modules into the
463
+ # transformers modules cache. On particular systems, this operation seems to cause contention between
464
+ # the different processes. To avoid this contention, we first create the model (on meta device) on local rank
465
+ # zero. This will set up the transformers model cache and avoid the future contention.
466
+ if dist.get_local_rank() == 0:
467
+ if os.path.isdir(pretrained_model_name_or_path):
468
+ with init_empty_weights(include_buffers=False):
469
+ with warnings.catch_warnings():
470
+ warnings.simplefilter('ignore', UserWarning)
471
+ AutoModelForCausalLM.from_pretrained(
472
+ pretrained_model_name_or_path,
473
+ trust_remote_code=trust_remote_code,
474
+ token=True,
475
+ config=base_config,
476
+ )
477
+ else:
478
+ with init_empty_weights(include_buffers=False):
479
+ AutoModelForCausalLM.from_config(
480
+ base_config,
481
+ trust_remote_code=trust_remote_code,
482
+ )
483
+
484
+ dist.barrier()
485
+
486
+ # initialize the model on the correct device
487
+ config.pretrained = pretrained
488
+ if resolved_init_device == 'cpu':
489
+ if pretrained:
490
+ config.pretrain_cfg = {
491
+ "trust_remote_code": trust_remote_code,
492
+ "token": True,
493
+ "load_in_8bit": load_in_8bit,
494
+ }
495
+ model = AutoModelForCausalLMWithRM(config)
496
+ else:
497
+ config.pretrain_cfg = {
498
+ "trust_remote_code": trust_remote_code,
499
+ }
500
+ model = AutoModelForCausalLMWithRM(config)
501
+ elif resolved_init_device == 'meta':
502
+ if pretrained:
503
+ raise ValueError(
504
+ 'Setting cfg.pretrained=True is not supported when init_device="meta".',
505
+ )
506
+ with init_empty_weights(include_buffers=False):
507
+ config.pretrain_cfg = {
508
+ "trust_remote_code": trust_remote_code,
509
+ }
510
+ model = AutoModelForCausalLMWithRM(config)
511
+ else:
512
+ raise ValueError(
513
+ f'init_device="{init_device}" must be either "cpu" or "meta".',
514
+ )
515
+
516
+ signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed'
517
+ if dist.get_local_rank() == 0:
518
+ with open(signal_file_path, 'wb') as f:
519
+ f.write(b'local_rank0_completed_download')
520
+
521
+ # Avoid the collective call until the local rank zero has finished trying to download the checkpoint
522
+ # so that we don't timeout for large downloads. This syncs all processes on the node
523
+ with dist.local_rank_zero_download_and_wait(signal_file_path):
524
+ # Then, wait to ensure every node has finished downloading the checkpoint
525
+ dist.barrier()
526
+
527
+ if dist.get_local_rank() == 0:
528
+ os.remove(signal_file_path)
529
+
530
+ # Hugging Face's weight tying does not succeed if the model is inited on meta device
531
+ # so we manually apply the weight tying here
532
+ if model.config.tie_word_embeddings and resolved_init_device == 'meta':
533
+ model.tie_weights()
534
+
535
+ if pretrained_lora_id_or_path is not None:
536
+ """TODO not supported"""
537
+ raise NotImplementedError("PEFT IS NOT SUPPORTED")
538
+
539
+ if prepare_for_fsdp:
540
+ # Note: We need to add the FSDP related attributes to the model AFTER the super init,
541
+ # so that the (possible) embedding resizing doesn't destroy them
542
+ prepare_hf_sequence_classification_model_for_fsdp(model, init_device)
543
+
544
+ # This provides support for meta initialization when using FSDP
545
+ model.param_init_fn = lambda module: model._init_weights(module)
546
+ return model