amc-madalin committed on
Commit
54b6d58
1 Parent(s): 726a3fb

Upload 2 files

Files changed (2)
  1. configuration_olmo.py +44 -0
  2. modeling_olmo.py +145 -0
configuration_olmo.py ADDED
@@ -0,0 +1,44 @@
+"""
+OLMo configuration
+"""
+
+from transformers import AutoConfig, PretrainedConfig
+from transformers.utils import logging
+
+from olmo.config import ModelConfig
+
+logger = logging.get_logger(__name__)
+
+
+class OLMoConfig(PretrainedConfig):
+    model_type = "olmo"
+    keys_to_ignore_at_inference = ["past_key_values"]  # TODO: confirm
+
+    def __init__(self, use_cache: bool = False, **kwargs):
+        model_config = ModelConfig()
+        all_kwargs = model_config.asdict()
+        all_kwargs.update(kwargs)
+        all_kwargs.update({"use_cache": use_cache})
+        all_kwargs.update(
+            {
+                "architectures": all_kwargs.get("architectures", ["OlmoModelForCausalLM"])
+                or ["OlmoModelForCausalLM"]
+            }
+        )
+        super().__init__(**all_kwargs)
+
+    @property
+    def num_attention_heads(self):
+        return self.n_heads
+
+    @property
+    def num_hidden_layers(self):
+        return self.n_layers
+
+    @property
+    def hidden_size(self):
+        return self.d_model
+
+
+# Register the config class so that it is available for transformer pipelines, auto-loading etc.
+AutoConfig.register("olmo", OLMoConfig)
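
For reference, a minimal usage sketch of the config class registered above (a sketch only: it assumes the olmo package and configuration_olmo.py are importable; fields not passed in fall back to the defaults of olmo.config.ModelConfig):

from configuration_olmo import OLMoConfig

config = OLMoConfig(use_cache=True)   # unspecified fields use olmo.config.ModelConfig defaults
print(config.model_type)              # "olmo"
print(config.hidden_size)             # alias for config.d_model
print(config.num_hidden_layers)       # alias for config.n_layers
print(config.num_attention_heads)     # alias for config.n_heads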
modeling_olmo.py ADDED
@@ -0,0 +1,145 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+from transformers import PreTrainedModel
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.auto import AutoModelForCausalLM
+
+from olmo.config import ModelConfig
+from olmo.model import Olmo
+
+from .configuration_olmo import OLMoConfig
+
+
+def create_model_config_from_pretrained_config(config: OLMoConfig):
+    """
+    Utility function
+    """
+
+    kwargs = {}
+    for key in ModelConfig.__match_args__:
+        kwargs[key] = getattr(config, key)
+
+    model_config = ModelConfig(**kwargs)
+    return model_config
+
+
+class OLMoForCausalLM(PreTrainedModel):
+    """
+    Extremely barebones HF model wrapper.
+    """
+
+    config_class = OLMoConfig
+    base_model_prefix = "model"
+    _no_split_modules = ["OLMoBlock"]
+
+    def __init__(self, config: OLMoConfig, model: Optional[Olmo] = None, init_params: bool = False):
+        super().__init__(config)
+
+        if not model:
+            model_config = create_model_config_from_pretrained_config(config)
+            # Initialize model (always on CPU to start with so we don't run out of GPU memory).
+            model_config.init_device = "cpu"
+            self.model = Olmo(model_config, init_params=init_params)
+        else:
+            self.model = model
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        if use_cache is None:
+            use_cache = self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+        )
+
+        logits = outputs.logits
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = torch.nn.CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.embedding_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.attn_key_values,
+        )
+
+    def can_generate(self) -> bool:
+        return True
+
+    def prepare_inputs_for_generation(
+        self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
+    ):
+        if past_key_values:
+            # This is because we want the model to only process the last generated token.
+            input_ids = input_ids[:, -1:]
+        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
+
+        model_inputs.update(kwargs)
+        model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
+        return model_inputs
+
+    # TODO: these are required to make the implementation complete.
+    # def resize_position_embeddings(self, new_num_position_embeddings: int):
+    #     pass
+    #
+    # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
+    #     pass
+    #
+    # def _reorder_cache(self, past_key_values, beam_idx):
+    #     pass
+
+    def get_input_embeddings(self) -> torch.nn.Module:
+        return self.model.transformer.wte
+
+    def set_input_embeddings(self, value: torch.nn.Module):
+        self.model.transformer.wte = value
+
+    def get_output_embeddings(self):
+        if self.config.weight_tying:
+            return self.model.transformer.wte
+        else:
+            return self.model.transformer.ff_out
+
+    def set_output_embeddings(self, value: torch.nn.Module):
+        if self.config.weight_tying:
+            self.model.transformer.wte = value
+        else:
+            self.model.transformer.ff_out = value
+
+    def tie_weights(self):
+        if self.config.weight_tying:
+            self.model.transformer.ff_out = self.model.transformer.wte
+
+
+# Register the model so that it is available for transformer pipelines, auto-loading, etc.
+AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
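
For reference, a minimal usage sketch of the wrapper registered above (a sketch only: it assumes the olmo package is installed, that both files live inside an importable package since modeling_olmo.py uses a relative import, with the package name hf_olmo below purely illustrative, and that a default ModelConfig is small enough to build on CPU; the token ids are placeholders, not the output of a real tokenizer):

import torch
from hf_olmo.configuration_olmo import OLMoConfig
from hf_olmo.modeling_olmo import OLMoForCausalLM

config = OLMoConfig(use_cache=True)
model = OLMoForCausalLM(config, init_params=True)  # builds the underlying Olmo model on CPU
model.eval()

input_ids = torch.tensor([[1, 2, 3, 4]])           # placeholder token ids
with torch.no_grad():
    out = model(input_ids=input_ids)
print(out.logits.shape)                            # (batch, seq_len, embedding_size)

Because can_generate and prepare_inputs_for_generation are implemented, transformers' generate() can drive this wrapper once a compatible tokenizer is available.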