ammarnasr committed
Commit cc9ca31 · verified · 1 Parent(s): 671bc20

Upload model

Files changed (5)
  1. README.md +199 -0
  2. config.json +35 -0
  3. configuration_t5mimo.py +152 -0
  4. model.safetensors +3 -0
  5. modeling_t5mimo.py +1745 -0
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,35 @@
+ {
+   "architectures": [
+     "T5MIMOModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_t5mimo.T5MIMOConfig",
+     "AutoModel": "modeling_t5mimo.T5MIMOModel"
+   },
+   "classifier_dropout": 0.0,
+   "d_ff": 1024,
+   "d_kv": 64,
+   "d_model": 256,
+   "decoder_start_token_id": 0,
+   "dense_act_fn": "relu",
+   "dropout_rate": 0.1,
+   "eos_token_id": 1,
+   "feed_forward_proj": "relu",
+   "initializer_factor": 0.05,
+   "is_encoder_decoder": true,
+   "is_gated_act": false,
+   "layer_norm_epsilon": 1e-06,
+   "model_type": "t5mimo",
+   "num_decoder_layers": 4,
+   "num_filters": 64,
+   "num_heads": 4,
+   "num_layers": 4,
+   "num_seqs": 3,
+   "pad_token_id": 0,
+   "relative_attention_max_distance": 128,
+   "relative_attention_num_buckets": 32,
+   "torch_dtype": "float32",
+   "transformers_version": "4.41.1",
+   "use_cache": true,
+   "vocab_size": 4096
+ }
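As a quick sanity check on these values, the sketch below derives a few quantities from numbers copied out of config.json. The helper name `summarize_config` is hypothetical and not part of the commit, and the embedding figure assumes a single `vocab_size` × `d_model` table, which this diff does not confirm.

```python
import json

# Subset of config.json from this commit, copied verbatim.
CONFIG_TEXT = """
{
  "d_ff": 1024,
  "d_kv": 64,
  "d_model": 256,
  "num_heads": 4,
  "num_layers": 4,
  "num_decoder_layers": 4,
  "vocab_size": 4096
}
"""


def summarize_config(cfg):
    """Derive a few sizes from a T5-style config dict (illustrative helper)."""
    # The config docstring defines the attention inner_dim as num_heads * d_kv.
    inner_dim = cfg["num_heads"] * cfg["d_kv"]
    return {
        "inner_dim": inner_dim,
        "inner_dim_matches_d_model": inner_dim == cfg["d_model"],
        "ffn_expansion": cfg["d_ff"] / cfg["d_model"],
        # Assumption: one vocab_size x d_model embedding table.
        "embedding_params": cfg["vocab_size"] * cfg["d_model"],
    }


summary = summarize_config(json.loads(CONFIG_TEXT))
print(summary)
```

With these values the attention inner dimension equals `d_model`, and the feed-forward layer is four times as wide as the model dimension.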
configuration_t5mimo.py ADDED
@@ -0,0 +1,152 @@
+ from typing import Mapping
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.onnx import OnnxSeq2SeqConfigWithPast
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class T5MIMOConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`T5MIMOModel`]. It is used to
+     instantiate a T5MIMO model according to the specified arguments, defining the model architecture. Instantiating a
+     configuration with the defaults will yield a similar configuration to that of the T5
+     [google-t5/t5-small](https://huggingface.co/google-t5/t5-small) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Arguments:
+         vocab_size (`int`, *optional*, defaults to 32128):
+             Vocabulary size of the T5MIMO model. Defines the number of different tokens that can be represented by the
+             `input_ids` passed when calling [`T5MIMOModel`].
+         d_model (`int`, *optional*, defaults to 512):
+             Size of the encoder layers and the pooler layer.
+         d_kv (`int`, *optional*, defaults to 64):
+             Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will
+             be defined as `num_heads * d_kv`.
+         d_ff (`int`, *optional*, defaults to 2048):
+             Size of the intermediate feed forward layer in each `T5Block`.
+         num_layers (`int`, *optional*, defaults to 6):
+             Number of hidden layers in the Transformer encoder.
+         num_decoder_layers (`int`, *optional*):
+             Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
+         num_heads (`int`, *optional*, defaults to 8):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         relative_attention_num_buckets (`int`, *optional*, defaults to 32):
+             The number of buckets to use for each attention layer.
+         relative_attention_max_distance (`int`, *optional*, defaults to 128):
+             The maximum distance of the longer sequences for the bucket separation.
+         dropout_rate (`float`, *optional*, defaults to 0.1):
+             The ratio for all dropout layers.
+         classifier_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for classifier.
+         layer_norm_epsilon (`float`, *optional*, defaults to 1e-6):
+             The epsilon used by the layer normalization layers.
+         initializer_factor (`float`, *optional*, defaults to 1):
+             A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+             testing).
+         feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
+             Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. T5v1.1 uses the
+             `"gated-gelu"` feed forward projection. Original T5 uses `"relu"`.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions (not used by all models).
+     """
+
+     model_type = "t5mimo"
+     keys_to_ignore_at_inference = ["past_key_values"]
+     attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
+
+     def __init__(
+         self,
+         vocab_size=32128,
+         d_model=512,
+         d_kv=64,
+         d_ff=2048,
+         num_layers=6,
+         num_decoder_layers=None,
+         num_heads=8,
+         relative_attention_num_buckets=32,
+         relative_attention_max_distance=128,
+         dropout_rate=0.1,
+         layer_norm_epsilon=1e-6,
+         initializer_factor=1.0,
+         feed_forward_proj="relu",
+         is_encoder_decoder=True,
+         use_cache=True,
+         pad_token_id=0,
+         eos_token_id=1,
+         decoder_start_token_id=0,
+         classifier_dropout=0.0,
+         num_seqs=3,
+         num_filters=64,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.d_kv = d_kv
+         self.d_ff = d_ff
+         self.num_layers = num_layers
+         self.num_decoder_layers = (
+             num_decoder_layers if num_decoder_layers is not None else self.num_layers
+         )  # default = symmetry
+         self.num_heads = num_heads
+         self.relative_attention_num_buckets = relative_attention_num_buckets
+         self.relative_attention_max_distance = relative_attention_max_distance
+         self.dropout_rate = dropout_rate
+         self.classifier_dropout = classifier_dropout
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.initializer_factor = initializer_factor
+         self.feed_forward_proj = feed_forward_proj
+         self.use_cache = use_cache
+         self.num_seqs = num_seqs
+         self.num_filters = num_filters
+
+         act_info = self.feed_forward_proj.split("-")
+         self.dense_act_fn = act_info[-1]
+         self.is_gated_act = act_info[0] == "gated"
+
+         if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
+             raise ValueError(
+                 f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
+                 "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
+                 "'gated-gelu' or 'relu'"
+             )
+
+         # for backwards compatibility
+         if feed_forward_proj == "gated-gelu":
+             self.dense_act_fn = "gelu_new"
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             eos_token_id=eos_token_id,
+             decoder_start_token_id=decoder_start_token_id,
+             is_encoder_decoder=is_encoder_decoder,
+             **kwargs,
+         )
+
+
+ class T5MIMOOnnxConfig(OnnxSeq2SeqConfigWithPast):
+     @property
+     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+         common_inputs = {
+             "input_ids": {0: "batch", 1: "encoder_sequence"},
+             "attention_mask": {0: "batch", 1: "encoder_sequence"},
+         }
+         if self.use_past:
+             common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
+             common_inputs["decoder_input_ids"] = {0: "batch"}
+             common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
+         else:
+             common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
+             common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
+
+         if self.use_past:
+             self.fill_with_past_key_values_(common_inputs, direction="inputs")
+
+         return common_inputs
+
+     @property
+     def default_onnx_opset(self) -> int:
+         return 13
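The `feed_forward_proj` handling in `T5MIMOConfig.__init__` can be exercised on its own. The sketch below copies that parsing logic into a standalone function so the mapping from the string to `dense_act_fn` and `is_gated_act` is easy to inspect. The function name `parse_feed_forward_proj` is hypothetical and does not exist in the commit.

```python
def parse_feed_forward_proj(feed_forward_proj):
    """Return (dense_act_fn, is_gated_act) following the logic in T5MIMOConfig.__init__."""
    act_info = feed_forward_proj.split("-")
    dense_act_fn = act_info[-1]
    is_gated_act = act_info[0] == "gated"

    # Accept only "{ACT_FN}" or "gated-{ACT_FN}".
    if (len(act_info) > 1 and act_info[0] != "gated") or len(act_info) > 2:
        raise ValueError(
            f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer."
        )

    # Backwards-compatibility special case kept from the config class.
    if feed_forward_proj == "gated-gelu":
        dense_act_fn = "gelu_new"

    return dense_act_fn, is_gated_act


print(parse_feed_forward_proj("relu"))
print(parse_feed_forward_proj("gated-gelu"))
```

For the value stored in this commit's config.json (`"relu"`), the result matches the serialized `dense_act_fn` and `is_gated_act` fields.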
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f9d97da94a794f0b0aad1566b9e13205267ea4b3b70ae8c6cd147e6fe6e651cb
+ size 33588312
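The three lines above are a Git LFS pointer, not the weights themselves. As a rough illustration, the sketch below parses the pointer and turns the byte size into an upper bound on the parameter count. It assumes every tensor is stored as 4-byte float32 (per `torch_dtype` in config.json) and ignores the safetensors header, so the true count is slightly lower. `parse_lfs_pointer` is a hypothetical helper, not part of any library.

```python
POINTER = """version https://git-lfs.github.com/spec/v1
oid sha256:f9d97da94a794f0b0aad1566b9e13205267ea4b3b70ae8c6cd147e6fe6e651cb
size 33588312
"""


def parse_lfs_pointer(text):
    """Split each 'key value' line of a Git LFS pointer into a dict."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    fields["size"] = int(fields["size"])  # size is given in bytes
    return fields


pointer = parse_lfs_pointer(POINTER)

# Assumption: float32 weights, 4 bytes each; file header overhead is ignored.
approx_params = pointer["size"] // 4
print(pointer["oid"], approx_params)
```

Under those assumptions the checkpoint holds at most about 8.4 million parameters.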
modeling_t5mimo.py ADDED
@@ -0,0 +1,1745 @@
+ import copy
+ import math
+ import warnings
+ from typing import Optional, Tuple, Union
+ import torch
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+ from transformers.activations import ACT2FN
+ from transformers.modeling_outputs import (
+     BaseModelOutput,
+     BaseModelOutputWithPastAndCrossAttentions,
+     Seq2SeqLMOutput,
+     Seq2SeqModelOutput,
+ )
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
+ from transformers.utils import (
+     DUMMY_INPUTS,
+     DUMMY_MASK,
+     is_torch_fx_proxy,
+     logging,
+ )
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
+ from .configuration_t5mimo import T5MIMOConfig
+
+
+ logger = logging.get_logger(__name__)
+
+
+
+ class T5LayerNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-6):
+         """
+         Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
+         """
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
+         # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
+         # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
+         # half-precision inputs is done in fp32
+
+         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+         # convert into half-precision if necessary
+         if self.weight.dtype in [torch.float16, torch.bfloat16]:
+             hidden_states = hidden_states.to(self.weight.dtype)
+
+         return self.weight * hidden_states
+
+
+ ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
+
+
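To make the arithmetic in `T5LayerNorm.forward` concrete, here is a dependency-free sketch of the same root-mean-square normalization on a plain Python list. It is an illustration only: `rms_norm` is a hypothetical name, and the real layer operates on tensors along the last dimension.

```python
import math


def rms_norm(values, weight, eps=1e-6):
    """Scale `values` by the reciprocal of their root mean square, then by `weight`."""
    # Mean of squares, with no mean subtraction (this is what the layer calls "variance").
    mean_square = sum(v * v for v in values) / len(values)
    scale = 1.0 / math.sqrt(mean_square + eps)
    # No bias term is added, matching the T5-style layer.
    return [w * v * scale for v, w in zip(values, weight)]


out = rms_norm([3.0, 4.0], [1.0, 1.0])
print(out)
```

With unit weights, the output's mean square is approximately 1, which is the property the layer enforces before applying the learned scale.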
+ class T5DenseActDense(nn.Module):
+     def __init__(self, config: T5MIMOConfig):
+         super().__init__()
+         self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
+         self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
+         self.dropout = nn.Dropout(config.dropout_rate)
+         self.act = ACT2FN[config.dense_act_fn]
+
+     def forward(self, hidden_states):
+         hidden_states = self.wi(hidden_states)
+         hidden_states = self.act(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         if (
+             isinstance(self.wo.weight, torch.Tensor)
+             and hidden_states.dtype != self.wo.weight.dtype
+             and self.wo.weight.dtype != torch.int8
+         ):
+             hidden_states = hidden_states.to(self.wo.weight.dtype)
+         hidden_states = self.wo(hidden_states)
+         return hidden_states
+
+
+ class T5DenseGatedActDense(nn.Module):
+     def __init__(self, config: T5MIMOConfig):
+         super().__init__()
+         self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
+         self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
+         self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
+         self.dropout = nn.Dropout(config.dropout_rate)
+         self.act = ACT2FN[config.dense_act_fn]
+
+     def forward(self, hidden_states):
+         hidden_gelu = self.act(self.wi_0(hidden_states))
+         hidden_linear = self.wi_1(hidden_states)
+         hidden_states = hidden_gelu * hidden_linear
+         hidden_states = self.dropout(hidden_states)
+
+         # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
+         # See https://github.com/huggingface/transformers/issues/20287
+         # We also make sure the weights are not in `int8` in case users force `_keep_in_fp32_modules` to be `None`
+         if (
+             isinstance(self.wo.weight, torch.Tensor)
+             and hidden_states.dtype != self.wo.weight.dtype
+             and self.wo.weight.dtype != torch.int8
+         ):
+             hidden_states = hidden_states.to(self.wo.weight.dtype)
+
+         hidden_states = self.wo(hidden_states)
+         return hidden_states
+
+
+ class T5LayerFF(nn.Module):
+     def __init__(self, config: T5MIMOConfig):
+         super().__init__()
+         if config.is_gated_act:
+             self.DenseReluDense = T5DenseGatedActDense(config)
+         else:
+             self.DenseReluDense = T5DenseActDense(config)
+
+         self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+         self.dropout = nn.Dropout(config.dropout_rate)
+
+     def forward(self, hidden_states):
+         forwarded_states = self.layer_norm(hidden_states)
+         forwarded_states = self.DenseReluDense(forwarded_states)
+         hidden_states = hidden_states + self.dropout(forwarded_states)
+         return hidden_states
+
+
+
+ class MultivariateConvBlock(nn.Module):
+     def __init__(self, config: T5MIMOConfig, kernel_size=3, stride=1, padding=1):
+         super().__init__()
+         # 2D convolution that mixes the num_seqs channels over the (model_dim, seq_len) plane
+         self.conv1 = nn.Conv2d(
+             in_channels=config.num_seqs,
+             out_channels=config.num_filters,
+             kernel_size=kernel_size,  # Square kernel over both the model_dim and seq_len axes
+             stride=1,  # Fixed stride of 1 on both axes (the `stride` argument is not used here)
+             padding=1  # Fixed padding of 1 on both axes (the `padding` argument is not used here)
+         )
+
+         # Batch normalization for stabilization and faster convergence
+         self.bn1 = nn.BatchNorm2d(config.num_filters)
+
+         # Second convolution layer to further model interactions between the filters
+         self.conv2 = nn.Conv2d(
+             in_channels=config.num_filters,
+             out_channels=config.num_filters,
+             kernel_size=(kernel_size, 1),  # Spans the model_dim axis only (dim 2 after the permute in forward)
+             stride=(stride, 1),
+             padding=(padding, 0)
+         )
+
+         # Batch normalization after second convolution
+         self.bn2 = nn.BatchNorm2d(config.num_filters)
+
+         # 1x1 Convolution to reduce the channel dimension back to num_seqs
+         self.conv3 = nn.Conv2d(
+             in_channels=config.num_filters,
+             out_channels=config.num_seqs,  # Back to the original number of sequences (channels)
+             kernel_size=(1, 1)
+         )
+
+     def forward(self, x):
+         """
+         Forward pass of the multivariate convolutional block.
+
+         Args:
+             x (torch.Tensor): Input tensor of shape [batch_size, num_seqs, seq_len, model_dim].
+
+         Returns:
+             torch.Tensor: Output tensor of shape [batch_size, num_seqs, seq_len, model_dim].
+         """
+         # Permute [batch_size, num_seqs, seq_len, model_dim] -> [batch_size, num_seqs, model_dim, seq_len]
+         x = x.permute(0, 1, 3, 2)
+
+         # Apply first convolution and activation
+         x = nn.functional.relu(self.bn1(self.conv1(x)))
+         # Apply second convolution and activation
+         x = nn.functional.relu(self.bn2(self.conv2(x)))
+
+         # Reduce channel dimension back to num_seqs
+         x = self.conv3(x)
+
+         # Permute back to original shape [batch_size, num_seqs, seq_len, model_dim]
+         x = x.permute(0, 1, 3, 2)
+
+         return x
+
+
+
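The docstring of `MultivariateConvBlock.forward` promises an output with the same shape as the input. The sketch below checks that claim with the standard convolution output-size formula (dilation 1), without needing PyTorch. `conv_out` and `block_output_shape` are hypothetical helpers written for this illustration, and they model only shapes, not values.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def block_output_shape(batch, num_seqs, seq_len, model_dim, kernel_size=3, stride=1, padding=1):
    """Trace tensor shapes through conv1 -> conv2 -> conv3 of the block."""
    # forward() permutes to (batch, num_seqs, model_dim, seq_len): height=model_dim, width=seq_len.
    h, w = model_dim, seq_len
    # conv1: kernel_size x kernel_size, stride and padding hardcoded to 1.
    h, w = conv_out(h, kernel_size, 1, 1), conv_out(w, kernel_size, 1, 1)
    # conv2: kernel (kernel_size, 1), stride (stride, 1), padding (padding, 0).
    h, w = conv_out(h, kernel_size, stride, padding), conv_out(w, 1, 1, 0)
    # conv3 is 1x1 and maps channels back to num_seqs; forward() then permutes back.
    return (batch, num_seqs, w, h)


default_shape = block_output_shape(2, 3, 16, 256)
wide_kernel_shape = block_output_shape(2, 3, 16, 256, kernel_size=5)
print(default_shape, wide_kernel_shape)
```

With the default `kernel_size=3` the shape is preserved. Because `conv1` hardcodes its padding to 1, a larger kernel such as 5 would shrink both axes, so the shape guarantee in the docstring appears to hold only for the default kernel size.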
+ class T5Attention(nn.Module):
+     def __init__(self, config: T5MIMOConfig, has_relative_attention_bias=False):
+         super().__init__()
+         self.is_decoder = config.is_decoder
+         self.has_relative_attention_bias = has_relative_attention_bias
+         self.relative_attention_num_buckets = config.relative_attention_num_buckets
+         self.relative_attention_max_distance = config.relative_attention_max_distance
+         self.d_model = config.d_model
+         self.key_value_proj_dim = config.d_kv
+         self.n_heads = config.num_heads
+         self.dropout = config.dropout_rate
+         self.inner_dim = self.n_heads * self.key_value_proj_dim
+
+         # Mesh TensorFlow initialization to avoid scaling before softmax
+         self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
+         self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
+         self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
+         self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
+
+         if self.has_relative_attention_bias:
+             self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+         self.pruned_heads = set()
+         self.gradient_checkpointing = False
+
+     def prune_heads(self, heads):
+         if len(heads) == 0:
+             return
+         heads, index = find_pruneable_heads_and_indices(
+             heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
+         )
+         # Prune linear layers
+         self.q = prune_linear_layer(self.q, index)
+         self.k = prune_linear_layer(self.k, index)
+         self.v = prune_linear_layer(self.v, index)
+         self.o = prune_linear_layer(self.o, index, dim=1)
+         # Update hyper params
+         self.n_heads = self.n_heads - len(heads)
+         self.inner_dim = self.key_value_proj_dim * self.n_heads
+         self.pruned_heads = self.pruned_heads.union(heads)
+
+     @staticmethod
+     def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+         """
+         Adapted from Mesh Tensorflow:
+         https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+         Translate relative position to a bucket number for relative attention. The relative position is defined as
+         memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+         position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+         small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+         positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+         This should allow for more graceful generalization to longer sequences than the model has been trained on
+
+         Args:
+             relative_position: an int32 Tensor
+             bidirectional: a boolean - whether the attention is bidirectional
+             num_buckets: an integer
+             max_distance: an integer
+
+         Returns:
+             a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+         """
+         relative_buckets = 0
+         if bidirectional:
+             num_buckets //= 2
+             relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+             relative_position = torch.abs(relative_position)
+         else:
+             relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+         # now relative_position is in the range [0, inf)
+
+         # half of the buckets are for exact increments in positions
+         max_exact = num_buckets // 2
+         is_small = relative_position < max_exact
+
+         # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+         relative_position_if_large = max_exact + (
+             torch.log(relative_position.float() / max_exact)
+             / math.log(max_distance / max_exact)
+             * (num_buckets - max_exact)
+         ).to(torch.long)
+         relative_position_if_large = torch.min(
+             relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+         )
+
+         relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
+         return relative_buckets
+
+     def compute_bias(self, query_length, key_length, multivar_dim=-1, device=None):
+         """Compute binned relative position bias"""
+         if device is None:
+             device = self.relative_attention_bias.weight.device
+         context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+         memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
+         relative_position = memory_position - context_position  # shape (query_length, key_length)
+         relative_position_bucket = self._relative_position_bucket(
+             relative_position,  # shape (query_length, key_length)
+             bidirectional=(not self.is_decoder),
+             num_buckets=self.relative_attention_num_buckets,
+             max_distance=self.relative_attention_max_distance,
+         )
+         values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
+         if multivar_dim != -1:  # shape (1, multivar_dim, num_heads, query_length, key_length) (copy across)
+             values = values.expand(1, multivar_dim, -1, -1, -1)
+
+         return values
+
299
+ def forward(
300
+ self,
301
+ hidden_states,
302
+ mask=None,
303
+ key_value_states=None,
304
+ position_bias=None,
305
+ past_key_value=None,
306
+ layer_head_mask=None,
307
+ query_length=None,
308
+ use_cache=False,
309
+ output_attentions=False,
310
+ ):
311
+ """
312
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
313
+ """
314
+ # Input is (batch_size, seq_length, dim)
315
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
316
+ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
317
+ if len(hidden_states.shape) == 3:
318
+ batch_size, seq_length = hidden_states.shape[:2]
319
+ else:
320
+ batch_size, seq_length = hidden_states.shape[0],hidden_states.shape[2]
321
+ multivar_dim = hidden_states.shape[1]
322
+ real_seq_length = seq_length
323
+
324
+ if past_key_value is not None:
325
+ if len(past_key_value) != 2:
326
+ raise ValueError(
327
+ f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
328
+ )
329
+ real_seq_length += past_key_value[0].shape[-2] if query_length is None else query_length # sequence axis is second-to-last for both 4D and 5D caches
330
+
331
+ if len(hidden_states.shape) == 3:
332
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
333
+ else:
334
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[2]
335
+
336
+
337
+ def shape(states):
338
+ """Split projected states into heads: (..., seq_length, inner_dim) -> (..., n_heads, seq_length, dim_per_head)"""
339
+ # states: torch.Size([3, 16, 512]) -> query_states: torch.Size([3, 8, 16, 64])
340
+ # states: torch.Size([3, 6, 16, 512]) -> query_states: torch.Size([3, 6, 8, 16, 64])
341
+ if len(states.shape) == 3:
342
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
343
+ else:
344
+ return states.view(batch_size, multivar_dim, -1, self.n_heads, self.key_value_proj_dim).transpose(2, 3)
345
+
346
+
347
+ def unshape(states):
348
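+ """Merge heads back (inverse of `shape`): (..., n_heads, seq_length, dim_per_head) -> (..., seq_length, inner_dim)"""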
349
+ if len(states.shape) == 4:
350
+ return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
351
+ else:
352
+ return states.transpose(2, 3).contiguous().view(batch_size, multivar_dim, -1, self.inner_dim)
353
+
354
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
355
+ """Projects hidden states to key/value states, reusing cached states when they are provided"""
356
+ if key_value_states is None:
357
+ # self-attn
358
+ # (batch_size, n_heads, seq_length, dim_per_head)
359
+ hidden_states = shape(proj_layer(hidden_states))
360
+ elif past_key_value is None:
361
+ # cross-attn
362
+ # (batch_size, n_heads, seq_length, dim_per_head)
363
+ hidden_states = shape(proj_layer(key_value_states))
364
+
365
+ if past_key_value is not None:
366
+ if key_value_states is None:
367
+ # self-attn
368
+ # (batch_size, n_heads, key_length, dim_per_head)
369
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=-2) # concatenate along the sequence axis (4D and 5D states)
370
+ elif past_key_value.shape[-2] != key_value_states.shape[-2]:
371
+ # checking that the `sequence_length` of the `past_key_value` is the same as
372
+ # the provided `key_value_states` to support prefix tuning
373
+ # cross-attn
374
+ # (batch_size, n_heads, seq_length, dim_per_head)
375
+ hidden_states = shape(proj_layer(key_value_states))
376
+ else:
377
+ # cross-attn
378
+ hidden_states = past_key_value
379
+ return hidden_states
380
+
381
+ # get query states
382
+ query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
383
+
384
+
385
+ # get key/value states
386
+ key_states = project(
387
+ hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
388
+ )
389
+ value_states = project(
390
+ hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
391
+ )
392
+
393
+
394
+
395
+ # compute scores
396
+ if len(hidden_states.shape) == 3:
397
+ scores = torch.matmul(
398
+ query_states, key_states.transpose(3, 2)
399
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
400
+ else:
401
+ scores = torch.matmul(
402
+ query_states, key_states.transpose(4, 3)
403
+ )
404
+
405
+
406
+
407
+
408
+
409
+ if position_bias is None:
410
+ if not self.has_relative_attention_bias:
411
+
412
+ if len(hidden_states.shape) == 3:
413
+ position_bias = torch.zeros(
414
+ (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
415
+ )
416
+ else:
417
+ position_bias = torch.zeros(
418
+ (1, multivar_dim, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
419
+ )
420
+ if self.gradient_checkpointing and self.training:
421
+ position_bias.requires_grad = True
422
+ else:
423
+
424
+ if len(hidden_states.shape) == 3:
425
+ position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
426
+ else:
427
+ position_bias = self.compute_bias(real_seq_length, key_length, multivar_dim=multivar_dim, device=scores.device)
428
+
429
+ # if key and values are already calculated
430
+ # we want only the last query position bias
431
+ if past_key_value is not None:
432
+ position_bias = position_bias[..., -seq_length:, :] # slice the query axis for both 4D and 5D biases
433
+
434
+ if mask is not None:
435
+ position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
436
+
437
+
438
+
439
+ if self.pruned_heads:
440
+ mask = torch.ones(position_bias.shape[1])
441
+ mask[list(self.pruned_heads)] = 0
442
+ position_bias_masked = position_bias[:, mask.bool()]
443
+ else:
444
+ position_bias_masked = position_bias
445
+
446
+
447
+ scores += position_bias_masked
448
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
449
+ scores
450
+ ) # (batch_size, n_heads, seq_length, key_length)
451
+ attn_weights = nn.functional.dropout(
452
+ attn_weights, p=self.dropout, training=self.training
453
+ ) # (batch_size, n_heads, seq_length, key_length)
454
+
455
+ # Mask heads if we want to
456
+ if layer_head_mask is not None:
457
+ attn_weights = attn_weights * layer_head_mask
458
+
459
+
460
+ if len(hidden_states.shape) == 3:
461
+ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
462
+ else:
463
+ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, multivar_dim, seq_length, dim)
464
+ attn_output = self.o(attn_output)
465
+
466
+
467
+ present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
468
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
469
+
470
+
471
+ if output_attentions:
472
+ outputs = outputs + (attn_weights,)
473
+
474
+ return outputs
475
+
476
+
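The `shape` and `unshape` helpers inside `T5Attention.forward` above handle both the usual 3D input and the 4D multivariate input. The following sketch only tracks the resulting shapes, with no tensors involved; `split_heads_shape` and `merge_heads_shape` are hypothetical names introduced here for illustration, under the assumption that `inner_dim == n_heads * dim_per_head`.

```python
def split_heads_shape(states_shape, n_heads, dim_per_head):
    """Shape produced by `shape`: view into heads, then swap the seq and head axes."""
    *leading, seq_length, inner_dim = states_shape
    if inner_dim != n_heads * dim_per_head:
        raise ValueError("inner_dim must equal n_heads * dim_per_head")
    # leading is (batch,) for 3D input and (batch, multivar_dim) for 4D input
    return (*leading, n_heads, seq_length, dim_per_head)


def merge_heads_shape(states_shape):
    """Shape produced by `unshape`: swap the axes back and flatten the heads."""
    *leading, n_heads, seq_length, dim_per_head = states_shape
    return (*leading, seq_length, n_heads * dim_per_head)


print(split_heads_shape((3, 16, 512), n_heads=8, dim_per_head=64))
print(split_heads_shape((3, 6, 16, 512), n_heads=8, dim_per_head=64))
```

The two calls reproduce the shapes quoted in the comments of `shape`; applying `merge_heads_shape` to either result gives back the input shape.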
477
+ class T5LayerSelfAttention(nn.Module):
478
+ def __init__(self, config, has_relative_attention_bias=False):
479
+ super().__init__()
480
+ self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
481
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
482
+ self.dropout = nn.Dropout(config.dropout_rate)
483
+
484
+ def forward(
485
+ self,
486
+ hidden_states,
487
+ attention_mask=None,
488
+ position_bias=None,
489
+ layer_head_mask=None,
490
+ past_key_value=None,
491
+ use_cache=False,
492
+ output_attentions=False,
493
+ ):
494
+ normed_hidden_states = self.layer_norm(hidden_states)
495
+ attention_output = self.SelfAttention(
496
+ normed_hidden_states,
497
+ mask=attention_mask,
498
+ position_bias=position_bias,
499
+ layer_head_mask=layer_head_mask,
500
+ past_key_value=past_key_value,
501
+ use_cache=use_cache,
502
+ output_attentions=output_attentions,
503
+ )
504
+
505
+ hidden_states = hidden_states + self.dropout(attention_output[0])
506
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
507
+ return outputs
508
+
509
+
510
+ class T5LayerCrossAttention(nn.Module):
511
+ def __init__(self, config):
512
+ super().__init__()
513
+ self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
514
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
515
+ self.dropout = nn.Dropout(config.dropout_rate)
516
+
517
+ def forward(
518
+ self,
519
+ hidden_states,
520
+ key_value_states,
521
+ attention_mask=None,
522
+ position_bias=None,
523
+ layer_head_mask=None,
524
+ past_key_value=None,
525
+ use_cache=False,
526
+ query_length=None,
527
+ output_attentions=False,
528
+ ):
529
+
530
+ normed_hidden_states = self.layer_norm(hidden_states)
531
+ attention_output = self.EncDecAttention(
532
+ normed_hidden_states,
533
+ mask=attention_mask,
534
+ key_value_states=key_value_states,
535
+ position_bias=position_bias,
536
+ layer_head_mask=layer_head_mask,
537
+ past_key_value=past_key_value,
538
+ use_cache=use_cache,
539
+ query_length=query_length,
540
+ output_attentions=output_attentions,
541
+ )
542
+ layer_output = hidden_states + self.dropout(attention_output[0])
543
+ outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
544
+ return outputs
545
+
546
+
547
+ class T5Block(nn.Module):
548
+ def __init__(self, config, has_relative_attention_bias=False):
549
+ super().__init__()
550
+ self.is_decoder = config.is_decoder
551
+ self.layer = nn.ModuleList()
552
+ self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
553
+ if self.is_decoder:
554
+ self.layer.append(T5LayerCrossAttention(config))
555
+
556
+ self.layer.append(T5LayerFF(config))
557
+
558
+ def forward(
559
+ self,
560
+ hidden_states,
561
+ attention_mask=None,
562
+ position_bias=None,
563
+ encoder_hidden_states=None,
564
+ encoder_attention_mask=None,
565
+ encoder_decoder_position_bias=None,
566
+ layer_head_mask=None,
567
+ cross_attn_layer_head_mask=None,
568
+ past_key_value=None,
569
+ use_cache=False,
570
+ output_attentions=False,
571
+ return_dict=True,
572
+ ):
573
+ if past_key_value is not None:
574
+ if not self.is_decoder:
575
+ logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
576
+ expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
577
+
578
+ if len(past_key_value) != expected_num_past_key_values:
579
+ raise ValueError(
580
+ f"There should be {expected_num_past_key_values} past states. "
581
+ f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
582
+ f"Got {len(past_key_value)} past key / value states"
583
+ )
584
+
585
+ self_attn_past_key_value = past_key_value[:2]
586
+ cross_attn_past_key_value = past_key_value[2:]
587
+ else:
588
+ self_attn_past_key_value, cross_attn_past_key_value = None, None
589
+
590
+ self_attention_outputs = self.layer[0](
591
+ hidden_states,
592
+ attention_mask=attention_mask,
593
+ position_bias=position_bias,
594
+ layer_head_mask=layer_head_mask,
595
+ past_key_value=self_attn_past_key_value,
596
+ use_cache=use_cache,
597
+ output_attentions=output_attentions,
598
+ )
599
+ hidden_states, present_key_value_state = self_attention_outputs[:2]
600
+ attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
601
+
602
+ # clamp inf values to enable fp16 training
603
+ if hidden_states.dtype == torch.float16:
604
+ clamp_value = torch.where(
605
+ torch.isinf(hidden_states).any(),
606
+ torch.finfo(hidden_states.dtype).max - 1000,
607
+ torch.finfo(hidden_states.dtype).max,
608
+ )
609
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
610
+
611
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
612
+ if do_cross_attention:
613
+ # the actual query length is unknown for cross attention
614
+ # if using past key value states. Need to inject it here
615
+ if present_key_value_state is not None:
616
+ query_length = present_key_value_state[0].shape[-2]
617
+ else:
618
+ query_length = None
619
+
620
+ cross_attention_outputs = self.layer[1](
621
+ hidden_states,
622
+ key_value_states=encoder_hidden_states,
623
+ attention_mask=encoder_attention_mask,
624
+ position_bias=encoder_decoder_position_bias,
625
+ layer_head_mask=cross_attn_layer_head_mask,
626
+ past_key_value=cross_attn_past_key_value,
627
+ query_length=query_length,
628
+ use_cache=use_cache,
629
+ output_attentions=output_attentions,
630
+ )
631
+ hidden_states = cross_attention_outputs[0]
632
+
633
+ # clamp inf values to enable fp16 training
634
+ if hidden_states.dtype == torch.float16:
635
+ clamp_value = torch.where(
636
+ torch.isinf(hidden_states).any(),
637
+ torch.finfo(hidden_states.dtype).max - 1000,
638
+ torch.finfo(hidden_states.dtype).max,
639
+ )
640
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
641
+
642
+ # Combine self attn and cross attn key value states
643
+ if present_key_value_state is not None:
644
+ present_key_value_state = present_key_value_state + cross_attention_outputs[1]
645
+
646
+ # Keep cross-attention outputs and relative position weights
647
+ attention_outputs = attention_outputs + cross_attention_outputs[2:]
648
+
649
+ # Apply Feed Forward layer
650
+ hidden_states = self.layer[-1](hidden_states)
651
+
652
+ # clamp inf values to enable fp16 training
653
+ if hidden_states.dtype == torch.float16:
654
+ clamp_value = torch.where(
655
+ torch.isinf(hidden_states).any(),
656
+ torch.finfo(hidden_states.dtype).max - 1000,
657
+ torch.finfo(hidden_states.dtype).max,
658
+ )
659
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
660
+
661
+ outputs = (hidden_states,)
662
+
663
+ if use_cache:
664
+ outputs = outputs + (present_key_value_state,) + attention_outputs
665
+ else:
666
+ outputs = outputs + attention_outputs
667
+
668
+ return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
669
+
670
+
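`T5Block.forward` above repeats the same fp16 guard three times: if any value is infinite, clamp to slightly below the float16 maximum, otherwise clamp to the maximum itself. A minimal plain-Python sketch of that rule is shown below. The name `clamp_like_t5block` is hypothetical, and the constant 65504 is the largest finite float16 value, which is what `torch.finfo(torch.float16).max` reports.

```python
import math

FP16_MAX = 65504.0  # largest finite float16 value


def clamp_like_t5block(values):
    """Clamp a list of floats the way the fp16 branch of T5Block does."""
    has_inf = any(math.isinf(v) for v in values)
    # Leave a margin of 1000 only when an infinity is actually present.
    clamp_value = FP16_MAX - 1000 if has_inf else FP16_MAX
    return [max(-clamp_value, min(clamp_value, v)) for v in values]


print(clamp_like_t5block([1.0, math.inf, -math.inf]))
```

This is only a scalar analogue; the model applies the same bound to the whole `hidden_states` tensor with `torch.clamp`.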
671
+ class T5ClassificationHead(nn.Module):
672
+ """Head for sentence-level classification tasks."""
673
+
674
+ def __init__(self, config: T5MIMOConfig):
675
+ super().__init__()
676
+ self.dense = nn.Linear(config.d_model, config.d_model)
677
+ self.dropout = nn.Dropout(p=config.classifier_dropout)
678
+ self.out_proj = nn.Linear(config.d_model, config.num_labels)
679
+
680
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
681
+ hidden_states = self.dropout(hidden_states)
682
+ hidden_states = self.dense(hidden_states)
683
+ hidden_states = torch.tanh(hidden_states)
684
+ hidden_states = self.dropout(hidden_states)
685
+ hidden_states = self.out_proj(hidden_states)
686
+ return hidden_states
687
+
688
+
689
+ class T5PreTrainedModel(PreTrainedModel):
690
+ """
691
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
692
+ models.
693
+ """
694
+
695
+ config_class = T5MIMOConfig
696
+ base_model_prefix = "transformer"
697
+ is_parallelizable = True
698
+ supports_gradient_checkpointing = True
699
+ _no_split_modules = ["T5Block"]
700
+ _keep_in_fp32_modules = ["wo"]
701
+
702
+ @property
703
+ def dummy_inputs(self):
704
+ input_ids = torch.tensor(DUMMY_INPUTS)
705
+ input_mask = torch.tensor(DUMMY_MASK)
706
+ dummy_inputs = {
707
+ "decoder_input_ids": input_ids,
708
+ "input_ids": input_ids,
709
+ "decoder_attention_mask": input_mask,
710
+ }
711
+ return dummy_inputs
712
+
713
+ def _init_weights(self, module):
714
+ """Initialize the weights"""
715
+ factor = self.config.initializer_factor # Used for testing weights initialization
716
+ if isinstance(module, T5LayerNorm):
717
+ module.weight.data.fill_(factor * 1.0)
718
+ elif isinstance(
719
+ module,
720
+ (T5MIMOModel, T5MIMOForConditionalGeneration, T5MIMOEncoderModel),
721
+ ):
722
+ # Mesh TensorFlow embeddings initialization
723
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
724
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
725
+ if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
726
+ module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
727
+ if hasattr(module, "qa_outputs"):
728
+ module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
729
+ module.qa_outputs.bias.data.zero_()
730
+ elif isinstance(module, T5ClassificationHead):
731
+ module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
732
+ if hasattr(module.dense, "bias") and module.dense.bias is not None:
733
+ module.dense.bias.data.zero_()
734
+ module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
735
+ if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
736
+ module.out_proj.bias.data.zero_()
737
+ elif isinstance(module, T5DenseActDense):
738
+ # Mesh TensorFlow FF initialization
739
+ # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
740
+ # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
741
+ module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
742
+ if hasattr(module.wi, "bias") and module.wi.bias is not None:
743
+ module.wi.bias.data.zero_()
744
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
745
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
746
+ module.wo.bias.data.zero_()
747
+ elif isinstance(module, T5DenseGatedActDense):
748
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
749
+ if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
750
+ module.wi_0.bias.data.zero_()
751
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
752
+ if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
753
+ module.wi_1.bias.data.zero_()
754
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
755
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
756
+ module.wo.bias.data.zero_()
757
+ elif isinstance(module, T5Attention):
758
+ # Mesh TensorFlow attention initialization to avoid scaling before softmax
759
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
760
+ d_model = self.config.d_model
761
+ key_value_proj_dim = self.config.d_kv
762
+ n_heads = self.config.num_heads
763
+ module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
764
+ module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
765
+ module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
766
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
767
+ if module.has_relative_attention_bias:
768
+ module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
769
+
770
+ def _shift_right(self, input_ids):
771
+ decoder_start_token_id = self.config.decoder_start_token_id
772
+ pad_token_id = self.config.pad_token_id
773
+
774
+ if decoder_start_token_id is None:
775
+ raise ValueError(
776
+ "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
777
+ "See T5 docs for more information."
778
+ )
779
+
780
+ # shift inputs to the right
781
+ if is_torch_fx_proxy(input_ids):
782
+ # Item assignment is not supported natively for proxies.
783
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
784
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
785
+ else:
786
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
787
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
788
+ shifted_input_ids[..., 0] = decoder_start_token_id
789
+
790
+ if pad_token_id is None:
791
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
792
+ # replace possible -100 values in labels by `pad_token_id`
793
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
794
+
795
+ return shifted_input_ids
796
+
797
+
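The `_shift_right` method above turns label ids into decoder input ids. The sketch below mirrors that logic for a single sequence held in a Python list; `shift_right` is a hypothetical standalone helper, and the token ids in the example are made up.

```python
def shift_right(label_ids, decoder_start_token_id, pad_token_id):
    """Prepend the start token, drop the last label, and replace -100 with the pad id."""
    shifted = [decoder_start_token_id] + list(label_ids[:-1])
    # -100 marks positions ignored by the loss; they must not reach the embedding layer.
    return [pad_token_id if token == -100 else token for token in shifted]


print(shift_right([5, 6, -100, -100], decoder_start_token_id=0, pad_token_id=0))
```

The output has the same length as the input, as in the tensor version, which writes into a zero-initialized tensor of the original shape.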
798
+ class T5Stack(T5PreTrainedModel):
799
+ def __init__(self, config, embed_tokens=None):
800
+ super().__init__(config)
801
+
802
+ self.embed_tokens = embed_tokens
803
+ self.is_decoder = config.is_decoder
804
+
805
+ self.block = nn.ModuleList(
806
+ [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
807
+ )
808
+ self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
809
+ self.dropout = nn.Dropout(config.dropout_rate)
810
+
811
+ # Initialize weights and apply final processing
812
+ self.post_init()
813
+ # Model parallel
814
+ self.model_parallel = False
815
+ self.device_map = None
816
+ self.gradient_checkpointing = False
817
+
818
+ def parallelize(self, device_map=None):
819
+ warnings.warn(
820
+ "`T5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
821
+ " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
822
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
823
+ " 'block.1': 1, ...}",
824
+ FutureWarning,
825
+ )
826
+ # Check validity of device_map
827
+ self.device_map = (
828
+ get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
829
+ )
830
+ assert_device_map(self.device_map, len(self.block))
831
+ self.model_parallel = True
832
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
833
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
834
+ # Load onto devices
835
+ for k, v in self.device_map.items():
836
+ for layer in v:
837
+ cuda_device = "cuda:" + str(k)
838
+ self.block[layer] = self.block[layer].to(cuda_device)
839
+
840
+ # Set embed_tokens to first layer
841
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
842
+ # Set final layer norm to last device
843
+ self.final_layer_norm = self.final_layer_norm.to(self.last_device)
844
+
845
+
846
+ def deparallelize(self):
847
+ warnings.warn(
848
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
849
+ FutureWarning,
850
+ )
851
+ self.model_parallel = False
852
+ self.device_map = None
853
+ self.first_device = "cpu"
854
+ self.last_device = "cpu"
855
+ for i in range(len(self.block)):
856
+ self.block[i] = self.block[i].to("cpu")
857
+ self.embed_tokens = self.embed_tokens.to("cpu")
858
+ self.final_layer_norm = self.final_layer_norm.to("cpu")
859
+ torch.cuda.empty_cache()
860
+
861
+ def get_input_embeddings(self):
862
+ return self.embed_tokens
863
+
864
+ def set_input_embeddings(self, new_embeddings):
865
+ self.embed_tokens = new_embeddings
866
+
867
+ def forward(
868
+ self,
869
+ input_ids=None,
870
+ attention_mask=None,
871
+ encoder_hidden_states=None,
872
+ encoder_attention_mask=None,
873
+ inputs_embeds=None,
874
+ head_mask=None,
875
+ cross_attn_head_mask=None,
876
+ past_key_values=None,
877
+ use_cache=None,
878
+ output_attentions=None,
879
+ output_hidden_states=None,
880
+ return_dict=None,
881
+ ):
882
+ # Model parallel
883
+ if self.model_parallel:
884
+ torch.cuda.set_device(self.first_device)
885
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
886
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
887
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
888
+ output_hidden_states = (
889
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
890
+ )
891
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
892
+
893
+ if input_ids is not None and inputs_embeds is not None:
894
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
895
+ raise ValueError(
896
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
897
+ )
898
+ elif input_ids is not None:
899
+ input_shape = input_ids.size()
900
+ # input_ids = input_ids.view(-1, input_shape[-1])
901
+ elif inputs_embeds is not None:
902
+ input_shape = inputs_embeds.size()[:-1]
903
+ else:
904
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
905
+ raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
906
+
907
+ if inputs_embeds is None:
908
+ if self.embed_tokens is None:
909
+ raise ValueError("You have to initialize the model with valid token embeddings")
910
+ inputs_embeds = self.embed_tokens(input_ids)
911
+
912
+ if len(input_shape) == 3:
913
+ batch_size, multivar_seqs, seq_length = input_shape
914
+ else:
915
+ batch_size, seq_length = input_shape
916
+
917
+ # required mask seq length can be calculated via length of past
918
+ mask_seq_length = past_key_values[0][0].shape[-2] + seq_length if past_key_values is not None else seq_length
919
+
920
+ if use_cache is True:
921
+ if not self.is_decoder:
922
+ raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
923
+
924
+ # initialize past_key_values with `None` if past does not exist
925
+ if past_key_values is None:
926
+ past_key_values = [None] * len(self.block)
927
+
928
+ if attention_mask is None:
929
+ if len(input_shape) == 2:
930
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
931
+ else:
932
+ attention_mask = torch.ones(batch_size, multivar_seqs, mask_seq_length, device=inputs_embeds.device)
933
+
934
+
935
+
936
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
937
+ # ourselves in which case we just need to make it broadcastable to all heads.
938
+ if len(input_shape) == 2:
939
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
940
+ else:
941
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
942
+ # permute from [batch_size, 1, multivar_seqs, seq_length] to [batch_size, multivar_seqs, 1, seq_length]
943
+ extended_attention_mask = extended_attention_mask.permute(0, 2, 1, 3)
944
+ # Now make it [batch_size, multivar_seqs, 1, 1, seq_length]
945
+ extended_attention_mask = extended_attention_mask.unsqueeze(3)
946
+
947
+ # If a 2D or 3D attention mask is provided for the cross-attention
948
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
949
+ if self.is_decoder and encoder_hidden_states is not None:
950
+
951
+ if len(encoder_hidden_states.size()) == 3 :
952
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
953
+ else:
954
+ encoder_batch_size, multivar_dim, encoder_sequence_length, _ = encoder_hidden_states.size()
955
+
956
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
957
+ if encoder_attention_mask is None:
958
+ encoder_attention_mask = torch.ones(
959
+ encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
960
+ )
961
+ if len(input_shape) == 2:
962
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
963
+ else:
964
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
965
+ multivar_dim = extended_attention_mask.shape[1]
966
+ encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(1)
967
+ encoder_extended_attention_mask = encoder_extended_attention_mask.permute(0, 3, 1, 2, 4)
968
+
969
+ else:
970
+ encoder_extended_attention_mask = None
971
+
972
+
973
+
974
+ if self.gradient_checkpointing and self.training:
975
+ if use_cache:
976
+ logger.warning_once(
977
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
978
+ )
979
+ use_cache = False
980
+
981
+ # Prepare head mask if needed
982
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
983
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
984
+ present_key_value_states = () if use_cache else None
985
+ all_hidden_states = () if output_hidden_states else None
986
+ all_attentions = () if output_attentions else None
987
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
988
+ position_bias = None
989
+ encoder_decoder_position_bias = None
990
+
991
+ hidden_states = self.dropout(inputs_embeds)
992
+
993
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
994
+ layer_head_mask = head_mask[i]
995
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
996
+ # Model parallel
997
+ if self.model_parallel:
998
+ torch.cuda.set_device(hidden_states.device)
999
+ # Ensure that attention_mask is always on the same device as hidden_states
1000
+ if attention_mask is not None:
1001
+ attention_mask = attention_mask.to(hidden_states.device)
1002
+ if position_bias is not None:
1003
+ position_bias = position_bias.to(hidden_states.device)
1004
+ if encoder_hidden_states is not None:
1005
+ encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
1006
+ if encoder_extended_attention_mask is not None:
1007
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
1008
+ if encoder_decoder_position_bias is not None:
1009
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
1010
+ if layer_head_mask is not None:
1011
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
1012
+ if cross_attn_layer_head_mask is not None:
1013
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
1014
+ if output_hidden_states:
1015
+ all_hidden_states = all_hidden_states + (hidden_states,)
1016
+
1017
+ if self.gradient_checkpointing and self.training:
1018
+ layer_outputs = self._gradient_checkpointing_func(
1019
+ layer_module.forward,
1020
+ hidden_states,
1021
+ extended_attention_mask,
1022
+ position_bias,
1023
+ encoder_hidden_states,
1024
+ encoder_extended_attention_mask,
1025
+ encoder_decoder_position_bias,
1026
+ layer_head_mask,
1027
+ cross_attn_layer_head_mask,
1028
+ None, # past_key_value is always None with gradient checkpointing
1029
+ use_cache,
1030
+ output_attentions,
1031
+ )
1032
+ else:
1033
+ layer_outputs = layer_module(
1034
+ hidden_states,
1035
+ attention_mask=extended_attention_mask,
1036
+ position_bias=position_bias,
1037
+ encoder_hidden_states=encoder_hidden_states,
1038
+ encoder_attention_mask=encoder_extended_attention_mask,
1039
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
1040
+ layer_head_mask=layer_head_mask,
1041
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
1042
+ past_key_value=past_key_value,
1043
+ use_cache=use_cache,
1044
+ output_attentions=output_attentions,
1045
+ )
1046
+
1047
+ # layer_outputs is a tuple with:
1048
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
1049
+ if use_cache is False:
1050
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
1051
+
1052
+ hidden_states, present_key_value_state = layer_outputs[:2]
1053
+
1054
+ # We share the position biases between the layers - the first layer stores them
1055
+ # layer_outputs = hidden-states, key-value-states, (self-attention position bias), (self-attention weights),
1056
+ # (cross-attention position bias), (cross-attention weights)
1057
+ position_bias = layer_outputs[2]
1058
+ if self.is_decoder and encoder_hidden_states is not None:
1059
+ encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
1060
+ # append next layer key value states
1061
+ if use_cache:
1062
+ present_key_value_states = present_key_value_states + (present_key_value_state,)
1063
+
1064
+ if output_attentions:
1065
+ all_attentions = all_attentions + (layer_outputs[3],)
1066
+ if self.is_decoder:
1067
+ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
1068
+
1069
+ # Model Parallel: If it's the last layer for that device, put things on the next device
1070
+ if self.model_parallel:
1071
+ for k, v in self.device_map.items():
1072
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
1073
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
1074
+
1075
+ hidden_states = self.final_layer_norm(hidden_states)
1076
+ hidden_states = self.dropout(hidden_states)
1077
+
1078
+ # Add last layer
1079
+ if output_hidden_states:
1080
+ all_hidden_states = all_hidden_states + (hidden_states,)
1081
+
1082
+ if not return_dict:
1083
+ return tuple(
1084
+ v
1085
+ for v in [
1086
+ hidden_states,
1087
+ present_key_value_states,
1088
+ all_hidden_states,
1089
+ all_attentions,
1090
+ all_cross_attentions,
1091
+ ]
1092
+ if v is not None
1093
+ )
1094
+ return BaseModelOutputWithPastAndCrossAttentions(
1095
+ last_hidden_state=hidden_states,
1096
+ past_key_values=present_key_value_states,
1097
+ hidden_states=all_hidden_states,
1098
+ attentions=all_attentions,
1099
+ cross_attentions=all_cross_attentions,
1100
+ )
1101
+
1102
+
1103
+
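When `return_dict` is false, the stack drops every `None` entry from its output tuple, so the index of a given output depends on which optional outputs were requested. The following is a minimal pure-Python sketch of that filtering, with placeholder strings in place of tensors; the helper name `pack_outputs` is hypothetical and not part of this file.

```python
def pack_outputs(
    hidden_states,
    present_key_value_states=None,
    all_hidden_states=None,
    all_attentions=None,
    all_cross_attentions=None,
):
    """Mimic the tuple-return branch: keep only the outputs that are not None."""
    candidates = [
        hidden_states,
        present_key_value_states,
        all_hidden_states,
        all_attentions,
        all_cross_attentions,
    ]
    return tuple(v for v in candidates if v is not None)


# With a cache, attentions sit at index 2; without one they move to index 1.
with_cache = pack_outputs("h", present_key_value_states="kv", all_attentions="attn")
without_cache = pack_outputs("h", all_attentions="attn")
print(with_cache)     # ('h', 'kv', 'attn')
print(without_cache)  # ('h', 'attn')
```

Callers that index into the tuple therefore need to know which flags were set; the `return_dict=True` path avoids this by using named fields.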
1104
+ class T5MIMOModel(T5PreTrainedModel):
1105
+ config_class = T5MIMOConfig
1106
+
1107
+ _keys_to_ignore_on_load_unexpected = [
1108
+ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
1109
+ ]
1110
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
1111
+
1112
+ def __init__(self, config: T5MIMOConfig):
1113
+ super().__init__(config)
1114
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1115
+
1116
+ encoder_config = copy.deepcopy(config)
1117
+ encoder_config.is_decoder = False
1118
+ encoder_config.use_cache = False
1119
+ encoder_config.is_encoder_decoder = False
1120
+ self.encoder = T5Stack(encoder_config, self.shared)
1121
+
1122
+ decoder_config = copy.deepcopy(config)
1123
+ decoder_config.is_decoder = True
1124
+ decoder_config.is_encoder_decoder = False
1125
+ decoder_config.num_layers = config.num_decoder_layers
1126
+ self.decoder = T5Stack(decoder_config, self.shared)
1127
+
1128
+ # Initialize weights and apply final processing
1129
+ self.post_init()
1130
+
1131
+ # Model parallel
1132
+ self.model_parallel = False
1133
+ self.device_map = None
1134
+
1135
+
1136
+ def parallelize(self, device_map=None):
1137
+ warnings.warn(
1138
+ "`T5MIMOModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
1139
+ " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
1140
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':"
1141
+ " 0, 'encoder.block.1': 1, ...}",
1142
+ FutureWarning,
1143
+ )
1144
+ self.device_map = (
1145
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1146
+ if device_map is None
1147
+ else device_map
1148
+ )
1149
+ assert_device_map(self.device_map, len(self.encoder.block))
1150
+ self.encoder.parallelize(self.device_map)
1151
+ self.decoder.parallelize(self.device_map)
1152
+ self.model_parallel = True
1153
+
1154
+
1155
+ def deparallelize(self):
1156
+ warnings.warn(
1157
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1158
+ FutureWarning,
1159
+ )
1160
+ self.encoder.deparallelize()
1161
+ self.decoder.deparallelize()
1162
+ self.encoder = self.encoder.to("cpu")
1163
+ self.decoder = self.decoder.to("cpu")
1164
+ self.model_parallel = False
1165
+ self.device_map = None
1166
+ torch.cuda.empty_cache()
1167
+
1168
+ def get_input_embeddings(self):
1169
+ return self.shared
1170
+
1171
+ def set_input_embeddings(self, new_embeddings):
1172
+ self.shared = new_embeddings
1173
+ self.encoder.set_input_embeddings(new_embeddings)
1174
+ self.decoder.set_input_embeddings(new_embeddings)
1175
+
1176
+ def _tie_weights(self):
1177
+ if self.config.tie_word_embeddings:
1178
+ self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
1179
+ self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
1180
+
1181
+ def get_encoder(self):
1182
+ return self.encoder
1183
+
1184
+ def get_decoder(self):
1185
+ return self.decoder
1186
+
1187
+ def _prune_heads(self, heads_to_prune):
1188
+ """
1189
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1190
+ class PreTrainedModel
1191
+ """
1192
+ for layer, heads in heads_to_prune.items():
1193
+ self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
1194
+
1195
+ def forward(
1196
+ self,
1197
+ input_ids: Optional[torch.LongTensor] = None,
1198
+ attention_mask: Optional[torch.FloatTensor] = None,
1199
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1200
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
1201
+ head_mask: Optional[torch.FloatTensor] = None,
1202
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
1203
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1204
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1205
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1206
+ inputs_embeds: Optional[torch.Tensor] = None,
1207
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
1208
+ use_cache: Optional[bool] = None,
1209
+ output_attentions: Optional[bool] = None,
1210
+ output_hidden_states: Optional[bool] = None,
1211
+ return_dict: Optional[bool] = None,
1212
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
1213
+ r"""
1214
+ Returns:
1215
+
1216
+ Example:
1217
+
1218
+ ```python
1219
+ >>> from transformers import AutoTokenizer, T5Model
1220
+
1221
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
1222
+ >>> model = T5Model.from_pretrained("google-t5/t5-small")
1223
+
1224
+ >>> input_ids = tokenizer(
1225
+ ... "Studies have shown that owning a dog is good for you", return_tensors="pt"
1226
+ ... ).input_ids # Batch size 1
1227
+ >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
1228
+
1229
+ >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
1230
+ >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
1231
+ >>> decoder_input_ids = model._shift_right(decoder_input_ids)
1232
+
1233
+ >>> # forward pass
1234
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
1235
+ >>> last_hidden_states = outputs.last_hidden_state
1236
+ ```"""
1237
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1238
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1239
+
1240
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1241
+ if head_mask is not None and decoder_head_mask is None:
1242
+ if self.config.num_layers == self.config.num_decoder_layers:
1243
+ decoder_head_mask = head_mask
1244
+
1245
+ # Encode if needed (training, first prediction pass)
1246
+ if encoder_outputs is None:
1247
+ encoder_outputs = self.encoder(
1248
+ input_ids=input_ids,
1249
+ attention_mask=attention_mask,
1250
+ inputs_embeds=inputs_embeds,
1251
+ head_mask=head_mask,
1252
+ output_attentions=output_attentions,
1253
+ output_hidden_states=output_hidden_states,
1254
+ return_dict=return_dict,
1255
+ )
1256
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1257
+ encoder_outputs = BaseModelOutput(
1258
+ last_hidden_state=encoder_outputs[0],
1259
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1260
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1261
+ )
1262
+
1263
+ hidden_states = encoder_outputs[0]
1264
+
1265
+ # Set device for model parallelism
1266
+ if self.model_parallel:
1267
+ torch.cuda.set_device(self.decoder.first_device)
1268
+ hidden_states = hidden_states.to(self.decoder.first_device)
1269
+ if decoder_input_ids is not None:
1270
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1271
+ if attention_mask is not None:
1272
+ attention_mask = attention_mask.to(self.decoder.first_device)
1273
+ if decoder_attention_mask is not None:
1274
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
1275
+
1276
+ # Decode
1277
+ decoder_outputs = self.decoder(
1278
+ input_ids=decoder_input_ids,
1279
+ attention_mask=decoder_attention_mask,
1280
+ inputs_embeds=decoder_inputs_embeds,
1281
+ past_key_values=past_key_values,
1282
+ encoder_hidden_states=hidden_states,
1283
+ encoder_attention_mask=attention_mask,
1284
+ head_mask=decoder_head_mask,
1285
+ cross_attn_head_mask=cross_attn_head_mask,
1286
+ use_cache=use_cache,
1287
+ output_attentions=output_attentions,
1288
+ output_hidden_states=output_hidden_states,
1289
+ return_dict=return_dict,
1290
+ )
1291
+
1292
+ if not return_dict:
1293
+ return decoder_outputs + encoder_outputs
1294
+
1295
+ return Seq2SeqModelOutput(
1296
+ last_hidden_state=decoder_outputs.last_hidden_state,
1297
+ past_key_values=decoder_outputs.past_key_values,
1298
+ decoder_hidden_states=decoder_outputs.hidden_states,
1299
+ decoder_attentions=decoder_outputs.attentions,
1300
+ cross_attentions=decoder_outputs.cross_attentions,
1301
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1302
+ encoder_hidden_states=encoder_outputs.hidden_states,
1303
+ encoder_attentions=encoder_outputs.attentions,
1304
+ )
1305
+
1306
+
1307
+
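The `elif return_dict and not isinstance(encoder_outputs, BaseModelOutput)` branch in `forward` accepts precomputed encoder outputs passed as a plain tuple and wraps them into a named structure, filling missing entries with `None`. A small sketch of that wrapping logic follows; `EncoderOutputSketch` is a hypothetical stand-in for `BaseModelOutput` so the example runs without `transformers` installed.

```python
from dataclasses import dataclass
from typing import Any, Optional, Sequence


@dataclass
class EncoderOutputSketch:
    """Hypothetical stand-in for transformers' BaseModelOutput."""

    last_hidden_state: Any
    hidden_states: Optional[Any] = None
    attentions: Optional[Any] = None


def wrap_encoder_outputs(encoder_outputs: Sequence[Any]) -> EncoderOutputSketch:
    """Wrap a 1-, 2- or 3-element tuple, defaulting absent entries to None."""
    return EncoderOutputSketch(
        last_hidden_state=encoder_outputs[0],
        hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
        attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
    )


wrapped = wrap_encoder_outputs(("last",))
print(wrapped.last_hidden_state, wrapped.hidden_states, wrapped.attentions)
```

This is what lets a caller run the encoder once and reuse its output across several decoder passes, whether it kept the result as a tuple or as an output object.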
1308
+ class T5MIMOForConditionalGeneration(T5PreTrainedModel):
1309
+ config_class = T5MIMOConfig
1310
+
1311
+ _keys_to_ignore_on_load_unexpected = [
1312
+ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
1313
+ ]
1314
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
1315
+
1316
+ def __init__(self, config: T5MIMOConfig):
1317
+ super().__init__(config)
1318
+ self.model_dim = config.d_model
1319
+
1320
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1321
+
1322
+ encoder_config = copy.deepcopy(config)
1323
+ encoder_config.is_decoder = False
1324
+ encoder_config.use_cache = False
1325
+ encoder_config.is_encoder_decoder = False
1326
+ self.encoder = T5Stack(encoder_config, self.shared)
1327
+
1328
+ decoder_config = copy.deepcopy(config)
1329
+ decoder_config.is_decoder = True
1330
+ decoder_config.is_encoder_decoder = False
1331
+ decoder_config.num_layers = config.num_decoder_layers
1332
+ self.decoder = T5Stack(decoder_config, self.shared)
1333
+
1334
+
1335
+ self.conv_block = MultivariateConvBlock(config)
1336
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
1337
+
1338
+ # Initialize weights and apply final processing
1339
+ self.post_init()
1340
+
1341
+ # Model parallel
1342
+ self.model_parallel = False
1343
+ self.device_map = None
1344
+
1345
+
1346
+ def parallelize(self, device_map=None):
1347
+ warnings.warn(
1348
+ "`T5MIMOForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you"
1349
+ " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also"
1350
+ " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance"
1351
+ " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}",
1352
+ FutureWarning,
1353
+ )
1354
+ self.device_map = (
1355
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1356
+ if device_map is None
1357
+ else device_map
1358
+ )
1359
+ assert_device_map(self.device_map, len(self.encoder.block))
1360
+ self.encoder.parallelize(self.device_map)
1361
+ self.decoder.parallelize(self.device_map)
1362
+ self.lm_head = self.lm_head.to(self.decoder.first_device)
+ self.conv_block = self.conv_block.to(self.decoder.first_device)
1363
+ self.model_parallel = True
1364
+
1365
+
1366
+ def deparallelize(self):
1367
+ warnings.warn(
1368
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1369
+ FutureWarning,
1370
+ )
1371
+ self.encoder.deparallelize()
1372
+ self.decoder.deparallelize()
1373
+ self.encoder = self.encoder.to("cpu")
1374
+ self.decoder = self.decoder.to("cpu")
1375
+ self.lm_head = self.lm_head.to("cpu")
+ self.conv_block = self.conv_block.to("cpu")
1376
+ self.model_parallel = False
1377
+ self.device_map = None
1378
+ torch.cuda.empty_cache()
1379
+
1380
+ def get_input_embeddings(self):
1381
+ return self.shared
1382
+
1383
+ def set_input_embeddings(self, new_embeddings):
1384
+ self.shared = new_embeddings
1385
+ self.encoder.set_input_embeddings(new_embeddings)
1386
+ self.decoder.set_input_embeddings(new_embeddings)
1387
+
1388
+ def _tie_weights(self):
1389
+ if self.config.tie_word_embeddings:
1390
+ self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
1391
+ self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
1392
+
1393
+ def set_output_embeddings(self, new_embeddings):
1394
+ self.lm_head = new_embeddings
1395
+
1396
+ def get_output_embeddings(self):
1397
+ return self.lm_head
1398
+
1399
+ def get_encoder(self):
1400
+ return self.encoder
1401
+
1402
+ def get_decoder(self):
1403
+ return self.decoder
1404
+
1405
+ def forward(
1406
+ self,
1407
+ input_ids: Optional[torch.LongTensor] = None,
1408
+ attention_mask: Optional[torch.FloatTensor] = None,
1409
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1410
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
1411
+ head_mask: Optional[torch.FloatTensor] = None,
1412
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
1413
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1414
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1415
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1416
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1417
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1418
+ labels: Optional[torch.LongTensor] = None,
1419
+ use_cache: Optional[bool] = None,
1420
+ output_attentions: Optional[bool] = None,
1421
+ output_hidden_states: Optional[bool] = None,
1422
+ return_dict: Optional[bool] = None,
1423
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
1424
+ r"""
1425
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1426
+ Labels for computing the language modeling loss. Indices should be in `[-100, 0, ...,
1427
+ config.vocab_size - 1]`. All labels set to `-100` are ignored (masked); the loss is only computed for
1428
+ labels in `[0, ..., config.vocab_size - 1]`. Labels with more than two dimensions are flattened before the loss is computed.
1429
+
1430
+ Returns:
1431
+
1432
+ Examples:
1433
+
1434
+ ```python
1435
+ >>> from transformers import AutoTokenizer, T5ForConditionalGeneration
1436
+
1437
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
1438
+ >>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
1439
+
1440
+ >>> # training
1441
+ >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
1442
+ >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
1443
+ >>> outputs = model(input_ids=input_ids, labels=labels)
1444
+ >>> loss = outputs.loss
1445
+ >>> logits = outputs.logits
1446
+
1447
+ >>> # inference
1448
+ >>> input_ids = tokenizer(
1449
+ ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
1450
+ ... ).input_ids # Batch size 1
1451
+ >>> outputs = model.generate(input_ids)
1452
+ >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
1453
+ >>> # studies have shown that owning a dog is good for you.
1454
+ ```"""
1455
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1456
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1457
+
1458
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1459
+ if head_mask is not None and decoder_head_mask is None:
1460
+ if self.config.num_layers == self.config.num_decoder_layers:
1461
+ decoder_head_mask = head_mask
1462
+
1463
+ # Encode if needed (training, first prediction pass)
1464
+ if encoder_outputs is None:
1465
+ # Convert encoder inputs into embeddings if needed
1466
+ encoder_outputs = self.encoder(
1467
+ input_ids=input_ids,
1468
+ attention_mask=attention_mask,
1469
+ inputs_embeds=inputs_embeds,
1470
+ head_mask=head_mask,
1471
+ output_attentions=output_attentions,
1472
+ output_hidden_states=output_hidden_states,
1473
+ return_dict=return_dict,
1474
+ )
1475
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1476
+ encoder_outputs = BaseModelOutput(
1477
+ last_hidden_state=encoder_outputs[0],
1478
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1479
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1480
+ )
1481
+
1482
+ hidden_states = encoder_outputs[0]
1483
+
1484
+ if self.model_parallel:
1485
+ torch.cuda.set_device(self.decoder.first_device)
1486
+
1487
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
1488
+ # get decoder inputs from shifting lm labels to the right
1489
+ decoder_input_ids = self._shift_right(labels)
1490
+
1491
+ # Set device for model parallelism
1492
+ if self.model_parallel:
1493
+ torch.cuda.set_device(self.decoder.first_device)
1494
+ hidden_states = hidden_states.to(self.decoder.first_device)
1495
+ if decoder_input_ids is not None:
1496
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1497
+ if attention_mask is not None:
1498
+ attention_mask = attention_mask.to(self.decoder.first_device)
1499
+ if decoder_attention_mask is not None:
1500
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
1501
+
1502
+ # Decode
1503
+ decoder_outputs = self.decoder(
1504
+ input_ids=decoder_input_ids,
1505
+ attention_mask=decoder_attention_mask,
1506
+ inputs_embeds=decoder_inputs_embeds,
1507
+ past_key_values=past_key_values,
1508
+ encoder_hidden_states=hidden_states,
1509
+ encoder_attention_mask=attention_mask,
1510
+ head_mask=decoder_head_mask,
1511
+ cross_attn_head_mask=cross_attn_head_mask,
1512
+ use_cache=use_cache,
1513
+ output_attentions=output_attentions,
1514
+ output_hidden_states=output_hidden_states,
1515
+ return_dict=return_dict,
1516
+ )
1517
+
1518
+ sequence_output = decoder_outputs[0]
1519
+
1520
+ # Set device for model parallelism
1521
+ if self.model_parallel:
1522
+ torch.cuda.set_device(self.encoder.first_device)
1523
+ self.lm_head = self.lm_head.to(self.encoder.first_device)
+ self.conv_block = self.conv_block.to(self.encoder.first_device)
1524
+ sequence_output = sequence_output.to(self.lm_head.weight.device)
1525
+
1526
+ if self.config.tie_word_embeddings:
1527
+ # Rescale output before projecting on vocab
1528
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
1529
+ sequence_output = sequence_output * (self.model_dim**-0.5)
1530
+
1531
+ sequence_output = self.conv_block(sequence_output)
1532
+ lm_logits = self.lm_head(sequence_output)
1533
+
1534
+ loss = None
1535
+ if labels is not None:
1536
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
1537
+ # move labels to the logits' device so the loss also works under pipeline parallelism (PP)
1538
+ labels = labels.to(lm_logits.device)
1539
+ if len(labels.shape) == 2:
1540
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
1541
+ else:
1542
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.reshape(-1))
1543
+ # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
1544
+
1545
+ if not return_dict:
1546
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
1547
+ return ((loss,) + output) if loss is not None else output
1548
+
1549
+ return Seq2SeqLMOutput(
1550
+ loss=loss,
1551
+ logits=lm_logits,
1552
+ past_key_values=decoder_outputs.past_key_values,
1553
+ decoder_hidden_states=decoder_outputs.hidden_states,
1554
+ decoder_attentions=decoder_outputs.attentions,
1555
+ cross_attentions=decoder_outputs.cross_attentions,
1556
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1557
+ encoder_hidden_states=encoder_outputs.hidden_states,
1558
+ encoder_attentions=encoder_outputs.attentions,
1559
+ )
1560
+
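The loss computed in `forward` flattens the logits to one row per token and the labels to one integer per token, then averages the cross-entropy over the positions whose label is not `-100`. The sketch below reproduces that arithmetic in pure Python, assuming the default mean reduction of `CrossEntropyLoss`; the helper names `flatten` and `masked_cross_entropy` are hypothetical and nested lists stand in for tensors.

```python
import math

IGNORE_INDEX = -100  # same sentinel the model passes as ignore_index


def flatten(nested):
    """Flatten arbitrarily nested lists of labels into one flat list."""
    flat = []
    for item in nested:
        if isinstance(item, list):
            flat.extend(flatten(item))
        else:
            flat.append(item)
    return flat


def masked_cross_entropy(logits, labels):
    """Mean cross-entropy over positions whose label is not IGNORE_INDEX.

    logits: one list of vocabulary scores per token; labels: one int per token.
    """
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # masked position contributes nothing
        peak = max(row)
        log_z = peak + math.log(sum(math.exp(x - peak) for x in row))
        total += log_z - row[label]
        count += 1
    return total / count if count else float("nan")


logits = [[0.0, 0.0], [5.0, -5.0]]    # two tokens, vocabulary of size 2
labels = flatten([[1, IGNORE_INDEX]])  # second token is masked
loss = masked_cross_entropy(logits, labels)
print(loss)  # only the first token counts, so this equals ln(2)
```

Because masked positions are excluded from the average, padding the label sequences with `-100` does not change the loss.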
1561
+ def prepare_inputs_for_generation(
1562
+ self,
1563
+ input_ids,
1564
+ past_key_values=None,
1565
+ attention_mask=None,
1566
+ head_mask=None,
1567
+ decoder_head_mask=None,
1568
+ decoder_attention_mask=None,
1569
+ cross_attn_head_mask=None,
1570
+ use_cache=None,
1571
+ encoder_outputs=None,
1572
+ **kwargs,
1573
+ ):
1574
+ # cut decoder_input_ids if past_key_values is used
1575
+ if past_key_values is not None:
1576
+ past_length = past_key_values[0][0].shape[2]
1577
+
1578
+ # Some generation methods already pass only the last input ID
1579
+ if input_ids.shape[1] > past_length:
1580
+ remove_prefix_length = past_length
1581
+ else:
1582
+ # Default to old behavior: keep only final ID
1583
+ remove_prefix_length = input_ids.shape[1] - 1
1584
+
1585
+ input_ids = input_ids[:, remove_prefix_length:]
1586
+
1587
+ return {
1588
+ "decoder_input_ids": input_ids,
1589
+ "past_key_values": past_key_values,
1590
+ "encoder_outputs": encoder_outputs,
1591
+ "attention_mask": attention_mask,
1592
+ "head_mask": head_mask,
1593
+ "decoder_head_mask": decoder_head_mask,
1594
+ "decoder_attention_mask": decoder_attention_mask,
1595
+ "cross_attn_head_mask": cross_attn_head_mask,
1596
+ "use_cache": use_cache,
1597
+ }
1598
+
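The prefix-trimming rule in `prepare_inputs_for_generation` can be checked in isolation. The sketch below uses nested lists instead of tensors and a hypothetical helper name, `trim_decoder_input_ids`; it keeps only the token positions that the key/value cache does not already cover.

```python
from typing import List


def trim_decoder_input_ids(input_ids: List[List[int]], past_length: int) -> List[List[int]]:
    """Drop the prefix of each row that the cache already covers."""
    seq_len = len(input_ids[0])
    if seq_len > past_length:
        # The caller passed the full history: cut everything already cached.
        remove_prefix_length = past_length
    else:
        # The caller already passed a short suffix: keep only the final ID.
        remove_prefix_length = seq_len - 1
    return [row[remove_prefix_length:] for row in input_ids]


full_history = [[0, 11, 12, 13]]
print(trim_decoder_input_ids(full_history, past_length=3))  # [[13]]
print(trim_decoder_input_ids([[13]], past_length=3))        # [[13]]
```

Both call styles end up feeding the decoder the same single new token, which is why the method can serve generation loops that pass either the full sequence or only the latest ID.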
1599
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1600
+ return self._shift_right(labels)
1601
+
1602
+ def _reorder_cache(self, past_key_values, beam_idx):
1603
+ # if decoder past is not included in output
1604
+ # speedy decoding is disabled and no need to reorder
1605
+ if past_key_values is None:
1606
+ logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
1607
+ return past_key_values
1608
+
1609
+ reordered_decoder_past = ()
1610
+ for layer_past_states in past_key_values:
1611
+ # get the correct batch idx from layer past batch dim
1612
+ # the batch dim of each cached state is dim 0, which is the dim `index_select` reorders below
1613
+ reordered_layer_past_states = ()
1614
+ for layer_past_state in layer_past_states:
1615
+ # need to set correct `past` for each of the four key / value states
1616
+ reordered_layer_past_states = reordered_layer_past_states + (
1617
+ layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
1618
+ )
1619
+
1620
+ if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
1621
+ raise ValueError(
1622
+ f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
1623
+ )
1624
+ if len(reordered_layer_past_states) != len(layer_past_states):
1625
+ raise ValueError(
1626
+ f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
1627
+ )
1628
+
1629
+ reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
1630
+ return reordered_decoder_past
1631
+
1632
+
1633
+
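During beam search, `_reorder_cache` reindexes every cached key/value state along the batch dimension so that each surviving beam keeps the cache of the hypothesis it was expanded from. A minimal sketch of the same traversal follows, with lists indexed by batch position standing in for tensors; the function name `reorder_cache` and the string placeholders are illustrative only.

```python
def reorder_cache(past_key_values, beam_idx):
    """Select, for every cached state in every layer, the rows named by beam_idx."""
    reordered_decoder_past = ()
    for layer_past_states in past_key_values:
        # Equivalent of index_select(0, beam_idx) on each state of this layer.
        reordered_layer = tuple(
            [state[i] for i in beam_idx] for state in layer_past_states
        )
        reordered_decoder_past += (reordered_layer,)
    return reordered_decoder_past


# One layer, two cached states (key and value), three beams.
past = ((["k0", "k1", "k2"], ["v0", "v1", "v2"]),)
print(reorder_cache(past, beam_idx=[2, 2, 0]))
# ((['k2', 'k2', 'k0'], ['v2', 'v2', 'v0']),)
```

An index may repeat in `beam_idx`, as `2` does here, when two surviving beams descend from the same parent hypothesis.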
1634
+ class T5MIMOEncoderModel(T5PreTrainedModel):
1635
+ _tied_weights_keys = ["encoder.embed_tokens.weight"]
1636
+ _keys_to_ignore_on_load_unexpected = [r"decoder"]
1637
+
1638
+ def __init__(self, config: T5MIMOConfig):
1639
+ super().__init__(config)
1640
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1641
+
1642
+ encoder_config = copy.deepcopy(config)
1643
+ encoder_config.use_cache = False
1644
+ encoder_config.is_encoder_decoder = False
1645
+ self.encoder = T5Stack(encoder_config, self.shared)
1646
+
1647
+ # Initialize weights and apply final processing
1648
+ self.post_init()
1649
+
1650
+ # Model parallel
1651
+ self.model_parallel = False
1652
+ self.device_map = None
1653
+
1654
+ def parallelize(self, device_map=None):
1655
+ warnings.warn(
1656
+ "`T5MIMOEncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
1657
+ " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
1658
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
1659
+ " 'block.1': 1, ...}",
1660
+ FutureWarning,
1661
+ )
1662
+ self.device_map = (
1663
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1664
+ if device_map is None
1665
+ else device_map
1666
+ )
1667
+ assert_device_map(self.device_map, len(self.encoder.block))
1668
+ self.encoder.parallelize(self.device_map)
1669
+ self.model_parallel = True
1670
+
1671
+ def deparallelize(self):
1672
+ warnings.warn(
1673
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1674
+ FutureWarning,
1675
+ )
1676
+ self.encoder.deparallelize()
1677
+ self.encoder = self.encoder.to("cpu")
1678
+ self.model_parallel = False
1679
+ self.device_map = None
1680
+ torch.cuda.empty_cache()
1681
+
1682
+ def get_input_embeddings(self):
1683
+ return self.shared
1684
+
1685
+ def set_input_embeddings(self, new_embeddings):
1686
+ self.shared = new_embeddings
1687
+ self.encoder.set_input_embeddings(new_embeddings)
1688
+
1689
+ def _tie_weights(self):
1690
+ if self.config.tie_word_embeddings:
1691
+ self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
1692
+
1693
+ def get_encoder(self):
1694
+ return self.encoder
1695
+
1696
+ def _prune_heads(self, heads_to_prune):
1697
+ """
1698
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1699
+ class PreTrainedModel
1700
+ """
1701
+ for layer, heads in heads_to_prune.items():
1702
+ self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
1703
+
1704
+ def forward(
1705
+ self,
1706
+ input_ids: Optional[torch.LongTensor] = None,
1707
+ attention_mask: Optional[torch.FloatTensor] = None,
1708
+ head_mask: Optional[torch.FloatTensor] = None,
1709
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1710
+ output_attentions: Optional[bool] = None,
1711
+ output_hidden_states: Optional[bool] = None,
1712
+ return_dict: Optional[bool] = None,
1713
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
1714
+ r"""
1715
+ Returns:
1716
+
1717
+ Example:
1718
+
1719
+ ```python
1720
+ >>> from transformers import AutoTokenizer, T5EncoderModel
1721
+
1722
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
1723
+ >>> model = T5EncoderModel.from_pretrained("google-t5/t5-small")
1724
+ >>> input_ids = tokenizer(
1725
+ ... "Studies have shown that owning a dog is good for you", return_tensors="pt"
1726
+ ... ).input_ids # Batch size 1
1727
+ >>> outputs = model(input_ids=input_ids)
1728
+ >>> last_hidden_states = outputs.last_hidden_state
1729
+ ```"""
1730
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1731
+
1732
+ encoder_outputs = self.encoder(
1733
+ input_ids=input_ids,
1734
+ attention_mask=attention_mask,
1735
+ inputs_embeds=inputs_embeds,
1736
+ head_mask=head_mask,
1737
+ output_attentions=output_attentions,
1738
+ output_hidden_states=output_hidden_states,
1739
+ return_dict=return_dict,
1740
+ )
1741
+
1742
+ return encoder_outputs
1743
+
1744
+
1745
+