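# NOTE: the line below is residue from the hosting site's file-viewer page, not part of the module:
# oliverdk's picture | End of training | 446b3a0 verified | raw | history blame | 4.73 kB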
from typing import Optional, Tuple, Union
from abc import abstractmethod
import torch
from torch.nn import BCEWithLogitsLoss
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast
from .sensor_loc_reg import SENSOR_LOC_REGISTRY
from .sensor_loc_finder import SensorLocFinder
class MeasurementPredictorMixin(PreTrainedModel):
    """Base transformer plus linear probes that predict sensor measurements.

    One ``Linear(emb_dim, 1)`` probe is created per sensor, plus a single
    aggregate probe. In ``forward`` every probe reads the base model's final
    hidden state at a position supplied by the sensor-location finder, and
    the resulting logits are trained with binary cross-entropy.

    ``init_sensor_loc_finder`` must be called before ``forward``.
    """

    def __init__(self, config):
        """Create the probes and copy the probe-related settings from ``config``.

        Reads ``sensor_loc_type``, ``sensor_token``, ``n_sensors``, ``emb_dim``,
        ``sensors_weight`` and ``aggregate_weight`` from ``config``.
        """
        super().__init__(config)
        self.sensor_loc_type = config.sensor_loc_type
        self.sensor_token = config.sensor_token
        self.n_sensors = config.n_sensors
        self.sensor_probes = torch.nn.ModuleList([
            torch.nn.Linear(config.emb_dim, 1) for _ in range(config.n_sensors)
        ])
        self.aggregate_probe = torch.nn.Linear(config.emb_dim, 1)
        self.sensors_weight = config.sensors_weight
        self.aggregate_weight = config.aggregate_weight
        # Stays None until init_sensor_loc_finder() is called with a tokenizer.
        self.find_sensor_locs: Optional[SensorLocFinder] = None

    @abstractmethod
    def set_pad_token(self, tokenizer: PreTrainedTokenizerBase):
        """Hook for subclasses to configure the tokenizer's pad token."""
        pass

    def init_sensor_loc_finder(self, tokenizer: PreTrainedTokenizerBase):
        """Instantiate the sensor-location finder registered under ``sensor_loc_type``."""
        self.find_sensor_locs = SENSOR_LOC_REGISTRY[self.sensor_loc_type](
            tokenizer, sensor_token=self.sensor_token, n_sensors=self.n_sensors
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        Run the base model and apply the sensor and aggregate probes.

        Args:
            input_ids: Token ids. Required, because the sensor positions are
                located by calling the sensor-location finder on them.
            labels (`torch.Tensor` of shape `(batch_size, n_sensors + 1)`, *optional*):
                Targets for the binary cross-entropy loss. The first
                `n_sensors` columns are the per-sensor targets and the last
                column is the aggregate target. Non-floating-point labels are
                cast to the logits dtype before the loss is computed.
            return_dict: Falls back to `config.use_return_dict` when `None`.

            All remaining arguments are forwarded unchanged to the base model.

        Returns:
            A `SequenceClassifierOutputWithPast` whose `logits` have shape
            `(batch_size, n_sensors + 1)` (sensor logits followed by the
            aggregate logit), or the equivalent tuple when `return_dict` is
            false. `loss` is `sensors_weight * BCE(sensors) +
            aggregate_weight * BCE(aggregate)` when `labels` is given.

        Raises:
            RuntimeError: If `init_sensor_loc_finder` has not been called.
            ValueError: If `input_ids` is `None`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Fail early with a clear message instead of an opaque error after
        # the (expensive) base-model forward pass.
        if self.find_sensor_locs is None:
            raise RuntimeError(
                "Sensor location finder is not initialized; "
                "call init_sensor_loc_finder(tokenizer) before forward()."
            )
        if input_ids is None:
            raise ValueError("input_ids is required to locate the sensor positions.")

        base_model_output: BaseModelOutputWithPast = self.base_model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 0 is the last hidden state for both the ModelOutput and the
        # plain-tuple (return_dict=False) forms; attribute access would fail
        # on the tuple form.
        last_hidden_state = base_model_output[0]

        # get sensor embeddings (including aggregate)
        sensor_locs = self.find_sensor_locs(input_ids)
        gather_index = sensor_locs.unsqueeze(-1).expand(-1, -1, self.config.emb_dim)
        sensor_embs = last_hidden_state.gather(1, gather_index)
        assert sensor_embs.shape == (input_ids.shape[0], self.n_sensors + 1, self.config.emb_dim), sensor_embs.shape

        # get sensor and aggregate logits; the last located position feeds
        # the aggregate probe
        sensor_logits = torch.concat(
            [probe(sensor_embs[:, i, :]) for i, probe in enumerate(self.sensor_probes)],
            dim=-1,
        )
        aggregate_logits = self.aggregate_probe(sensor_embs[:, -1, :])
        logits = torch.concat([sensor_logits, aggregate_logits], dim=-1)

        # compute loss
        loss = None
        if labels is not None:
            # BCEWithLogitsLoss needs floating-point targets; integer/bool
            # labels would otherwise raise a dtype error.
            targets = labels if labels.is_floating_point() else labels.to(logits.dtype)
            loss_fct = BCEWithLogitsLoss()
            sensor_loss = loss_fct(sensor_logits, targets[:, :self.n_sensors]) * self.sensors_weight
            aggregate_loss = loss_fct(aggregate_logits, targets[:, -1:]) * self.aggregate_weight
            loss = sensor_loss + aggregate_loss

        if not return_dict:
            output = (logits,) + base_model_output[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=base_model_output.past_key_values,
            hidden_states=base_model_output.hidden_states,
            attentions=base_model_output.attentions,
        )