Text Classification
Safetensors
gemma2
custom_code
Ray2333 commited on
Commit
b65f8a1
·
verified ·
1 Parent(s): 075964b

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +210 -0
model.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from Huggingface trl package AutoModelForCausalLMWithValueHead class
2
+ # Enabling better customization for generalizable reward modeling
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoModelForCausalLM
6
+ from trl import PreTrainedModelWrapper
7
+
8
+
9
+ class ValueHead(nn.Module):
10
+ def __init__(self, config, **kwargs):
11
+ super().__init__()
12
+ if not hasattr(config, "summary_dropout_prob"):
13
+ summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1)
14
+ else:
15
+ summary_dropout_prob = config.summary_dropout_prob
16
+
17
+ self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity()
18
+
19
+ # some models such as OPT have a projection layer before the word embeddings - e.g. OPT-350m
20
+ if hasattr(config, "hidden_size"):
21
+ hidden_size = config.hidden_size
22
+ if hasattr(config, "word_embed_proj_dim"):
23
+ hidden_size = config.word_embed_proj_dim
24
+ elif hasattr(config, "is_encoder_decoder"):
25
+ if config.is_encoder_decoder and hasattr(config, "decoder"):
26
+ if hasattr(config.decoder, "hidden_size"):
27
+ hidden_size = config.decoder.hidden_size
28
+
29
+ # get vhead config
30
+ if hasattr(config, "vhead_layer_type"): # config from json first
31
+ self.layer_type = config.vhead_layer_type
32
+ else:
33
+ self.layer_type = kwargs.pop("vhead_layer_type", 'mlp')
34
+ if hasattr(config, 'vhead_num_neurons'):
35
+ num_neurons = config.vhead_num_neurons
36
+ else:
37
+ num_neurons = kwargs.pop("vhead_num_neurons", 1024)
38
+ if hasattr(config, 'vhead_num_layers'):
39
+ num_layers = config.vhead_num_layers
40
+ else:
41
+ num_layers = kwargs.pop("vhead_num_layers", 1)
42
+
43
+ if self.layer_type == 'linear':
44
+ self.summary = nn.Linear(hidden_size, 1)
45
+ else:
46
+ module_lis = []
47
+ input_neurons = hidden_size
48
+ for i in range(num_layers):
49
+ module_lis.extend([nn.Linear(input_neurons, num_neurons), nn.ReLU()])
50
+ input_neurons = num_neurons
51
+
52
+ module_lis.append(nn.Linear(num_neurons, 1))
53
+ self.summary = nn.Sequential(*module_lis)
54
+ self.flatten = nn.Flatten()
55
+
56
+ def forward(self, hidden_states):
57
+ output = self.dropout(hidden_states)
58
+ if (self.layer_type == 'linear' and output.dtype != self.summary.weight.dtype):
59
+ output = output.to(self.summary.weight.dtype)
60
+ elif (self.layer_type != 'linear' and output.dtype != self.summary[0].weight.dtype):
61
+ output = output.to(self.summary[0].weight.dtype)
62
+
63
+ output = self.summary(output)
64
+ return output
65
+
66
+
67
+ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
68
+ transformers_parent_class = AutoModelForCausalLM
69
+ lm_head_namings = ["lm_head", "embed_out"]
70
+ supported_args = (
71
+ "summary_dropout_prob",
72
+ "v_head_initializer_range",
73
+ "v_head_init_strategy",
74
+ "layer_type",
75
+ 'num_neurons',
76
+ 'num_layers',
77
+ )
78
+
79
+ def __init__(self, pretrained_model, **kwargs):
80
+ r"""
81
+ Initializes the model.
82
+ """
83
+ super().__init__(pretrained_model, **kwargs)
84
+ v_head_kwargs, _, _ = self._split_kwargs(kwargs)
85
+
86
+ if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings):
87
+ raise ValueError("The model does not have a language model head, please use a model that has one.")
88
+
89
+ self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)
90
+ self._init_weights(**v_head_kwargs)
91
+
92
+ def _init_weights(self, **kwargs):
93
+ r"""
94
+ Initializes the weights of the value head.
95
+ """
96
+ initializer_range = kwargs.pop("v_head_initializer_range", 0.2)
97
+ # random init by default
98
+ init_strategy = kwargs.pop("v_head_init_strategy", None)
99
+ if init_strategy is None:
100
+ # do nothing
101
+ pass
102
+ elif init_strategy == "normal":
103
+ self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range)
104
+ self.v_head.summary.bias.data.zero_()
105
+
106
+ def forward(
107
+ self,
108
+ input_ids=None,
109
+ past_key_values=None,
110
+ attention_mask=None,
111
+ **kwargs,
112
+ ):
113
+ kwargs["output_hidden_states"] = True # this had already been set in the LORA / PEFT examples
114
+ kwargs["past_key_values"] = past_key_values
115
+
116
+ if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
117
+ kwargs.pop("past_key_values")
118
+
119
+ base_model_output = self.pretrained_model(
120
+ input_ids=input_ids,
121
+ attention_mask=attention_mask,
122
+ **kwargs,
123
+ )
124
+
125
+ last_hidden_state = base_model_output.hidden_states[-1]
126
+ lm_logits = base_model_output.logits
127
+ loss = base_model_output.loss
128
+
129
+ if (hasattr(self.v_head.summary, 'weight') and last_hidden_state.device != self.v_head.summary.weight.device):
130
+ last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device)
131
+ elif not hasattr(self.v_head.summary, 'weight') and (last_hidden_state.device != self.v_head.summary[0].weight.device):
132
+ last_hidden_state = last_hidden_state.to(self.v_head.summary[0].weight.device)
133
+
134
+ # use the last token value as reward
135
+ if torch.any(attention_mask[:, 0] == 0):
136
+ # left padding
137
+ last_index = attention_mask.shape[-1] - 1
138
+ else:
139
+ # right padding
140
+ last_index = attention_mask.sum(dim=-1) - 1
141
+ value = self.v_head(last_hidden_state).squeeze(-1)[torch.arange(len(last_hidden_state)), last_index]
142
+
143
+ # force upcast in fp32 if logits are in half-precision
144
+ if lm_logits.dtype != torch.float32:
145
+ lm_logits = lm_logits.float()
146
+
147
+ return (lm_logits, loss, value)
148
+
149
+ def generate(self, *args, **kwargs):
150
+ return self.pretrained_model.generate(*args, **kwargs)
151
+
152
+ def state_dict(self, *args, **kwargs):
153
+ pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
154
+
155
+ v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
156
+ for k, v in v_head_state_dict.items():
157
+ pretrained_model_state_dict[f"v_head.{k}"] = v
158
+ return pretrained_model_state_dict
159
+
160
+ def push_to_hub(self, *args, **kwargs):
161
+ setattr(self.pretrained_model, "v_head", self.v_head)
162
+ return self.pretrained_model.push_to_hub(*args, **kwargs)
163
+
164
+
165
+
166
+ def post_init(self, state_dict):
167
+ for k in list(state_dict.keys()):
168
+ if "v_head." in k:
169
+ state_dict[k.replace("v_head.", "")] = state_dict.pop(k)
170
+ self.v_head.load_state_dict(state_dict, strict=False)
171
+ del state_dict
172
+
173
+ if hasattr(self.pretrained_model, "hf_device_map"):
174
+ if (
175
+ "cpu" in self.pretrained_model.hf_device_map.values()
176
+ or "disk" in self.pretrained_model.hf_device_map.values()
177
+ ):
178
+ raise ValueError(
179
+ "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models."
180
+ )
181
+
182
+ first_device = list(set(self.pretrained_model.hf_device_map.values()))[0]
183
+
184
+ self.v_head = self.v_head.to(first_device)
185
+
186
+ def set_device_hook(module, input, outputs):
187
+ new_output = ()
188
+ for output in outputs:
189
+ if isinstance(output, torch.Tensor):
190
+ new_output += (output.to(first_device),)
191
+ else:
192
+ new_output += (output,)
193
+ return new_output
194
+
195
+ self.register_forward_hook(set_device_hook)
196
+
197
+ self.is_sequential_parallel = True
198
+
199
+ @classmethod
200
+ def register_for_auto_class(cls, auto_class="AutoModel"):
201
+ if not isinstance(auto_class, str):
202
+ auto_class = auto_class.__name__
203
+
204
+ import transformers.models.auto as auto_module
205
+
206
+ if not hasattr(auto_module, auto_class):
207
+ raise ValueError(f"{auto_class} is not a valid auto class.")
208
+
209
+ cls._auto_class = auto_class
210
+