Terra / modeling_llama_action.py
koukyo1994's picture
add LlamaActionV2
a69bb58 verified
raw
history blame
11.4 kB
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_llama_action import LlamaActionConfig
class LearnableFactorizedSpatioTemporalPositionalEmbedding(nn.Module):
    """Learned positional embedding factorized into a spatial table and a temporal table.

    A flat sequence position ``p`` is mapped to the spatial index
    ``p % num_spatio_embeddings`` and the temporal index
    ``p // num_spatio_embeddings``. The two looked-up vectors are summed.
    Hence the table covers at most ``num_spatio_embeddings * num_temporal_embeddings``
    positions.
    """

    def __init__(self, num_spatio_embeddings: int, num_temporal_embeddings: int, embedding_dim: int):
        super().__init__()
        self.spatio_embeddings = nn.Embedding(num_spatio_embeddings, embedding_dim)
        self.temporal_embeddings = nn.Embedding(num_temporal_embeddings, embedding_dim)
        self.num_spatio_embeddings = num_spatio_embeddings
        self.num_temporal_embeddings = num_temporal_embeddings

    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int):
        """Return positional embeddings for the ``seq_length`` positions after the cache.

        Args:
            attention_mask: Only its size along dim 0 (batch), its size along dim 1
                (sequence length) and its device are used. Any tensor with that
                leading shape works; the model in this file passes its input embeddings.
            past_key_values_length: Number of positions already cached. The returned
                embeddings cover positions ``past_key_values_length`` through
                ``past_key_values_length + seq_length - 1``.

        Returns:
            Tensor of shape ``(batch, seq_length, embedding_dim)``.

        Raises:
            ValueError: If the requested positions exceed
                ``num_spatio_embeddings * num_temporal_embeddings``.
        """
        batch_size = attention_mask.size(0)
        seq_length = attention_mask.size(1)
        end = past_key_values_length + seq_length
        max_positions = self.num_spatio_embeddings * self.num_temporal_embeddings
        if end > max_positions:
            raise ValueError(
                f"Positions up to {end} requested, but only {max_positions} "
                "(num_spatio_embeddings * num_temporal_embeddings) are available"
            )
        # Every new token gets its own flat position. This also covers incremental
        # steps that feed several tokens at once (e.g. action vectors + one image
        # token), which previously all received the single position past_key_values_length.
        positions = torch.arange(past_key_values_length, end, device=attention_mask.device)
        spatio_indices = (positions % self.num_spatio_embeddings).unsqueeze(0).repeat((batch_size, 1))
        temporal_indices = (positions // self.num_spatio_embeddings).unsqueeze(0).repeat((batch_size, 1))
        return self.spatio_embeddings(spatio_indices) + self.temporal_embeddings(temporal_indices)
class LlamaActionForCausalLM(LlamaForCausalLM):
    """Llama causal LM whose input interleaves image-token embeddings with projected actions.

    Image tokens are embedded with the base model's input embedding. The embeddings
    are split into chunks of ``num_image_patches`` (one chunk per frame). After each
    complete frame, ``num_action_embeddings`` action vectors projected by
    ``action_projection`` are inserted. A learned factorized spatio-temporal
    positional embedding is added to the interleaved sequence before it enters the
    decoder.
    """

    config_class = LlamaActionConfig

    def __init__(self, config: LlamaActionConfig):
        super().__init__(config)
        self.num_spatio_embeddings = config.num_spatio_embeddings
        self.num_temporal_embeddings = config.num_temporal_embeddings
        self.num_image_patches = config.num_image_patches
        self.num_action_embeddings = config.num_action_embeddings
        self.pos_embedding_spatio_temporal = LearnableFactorizedSpatioTemporalPositionalEmbedding(
            config.num_spatio_embeddings, config.num_temporal_embeddings, config.hidden_size,
        )
        self.action_projection = nn.Linear(config.action_dim, config.hidden_size)
        self.post_init()

    @staticmethod
    def _num_cached_positions(past_key_values) -> int:
        """Return how many positions ``past_key_values`` already holds (0 for None/empty)."""
        if past_key_values is None or len(past_key_values) == 0:
            return 0
        if isinstance(past_key_values, tuple):
            # legacy tuple-of-(key, value) format: length is read from dim 2 of the first key
            return past_key_values[0][0].size(2)
        return past_key_values.get_seq_length()

    def _interleave_image_and_action_embeds(self, inputs_embeds, actions):
        """Build [frame_0, action_0, frame_1, action_1, ...] pieces for a cache-free pass.

        Returns a list of tensors to concatenate along dim 1. The list is empty when
        there are more frames than action chunks; the input is then used without actions.
        """
        frame_embeds = torch.split(inputs_embeds, split_size_or_sections=self.num_image_patches, dim=1)
        action_chunks = torch.split(actions, split_size_or_sections=self.num_action_embeddings, dim=1)
        embeddings = []
        if len(frame_embeds) == len(action_chunks):
            # mostly used in training phase: every frame is followed by its action
            for frame, action in zip(frame_embeds, action_chunks):
                embeddings.append(frame)
                embeddings.append(self.action_projection(action))
        elif len(frame_embeds) < len(action_chunks):
            # used in inference phase (mostly): the last frame might still be generating
            # image tokens, so its action is only added once it has all num_image_patches tokens
            last = len(frame_embeds) - 1
            for i, frame in enumerate(frame_embeds):
                embeddings.append(frame)
                if i < last or frame.size(1) == self.num_image_patches:
                    embeddings.append(self.action_projection(action_chunks[i]))
        return embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        actions: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
        """Run the decoder on image tokens interleaved with projected actions.

        Args:
            input_ids: Image token ids. Mutually exclusive with ``inputs_embeds``.
            actions: Action vectors, split along dim 1 into chunks of
                ``num_action_embeddings`` (one chunk per frame) and passed through
                ``action_projection``. Presumably shaped
                ``(batch, n_frames * num_action_embeddings, action_dim)``; TODO confirm.
            inputs_embeds: Pre-computed image-token embeddings, used instead of
                ``input_ids``. This tensor is not modified in place.
            labels: If given, a shifted cross-entropy loss is computed against the
                logits, and ``use_cache`` is forced to False.
                NOTE(review): the logits include the interleaved action positions, so
                labels presumably must already be aligned to that length; confirm.
            past_key_values: Legacy tuple cache or a cache object. When the cache
                ends right after a frame's image tokens (cached length
                % num_spatio_embeddings == num_spatio_embeddings - num_action_embeddings),
                that frame's projected actions are prepended to the new embeddings.

        Returns:
            ``CausalLMOutputWithPast``, or a tuple ``(loss?, logits, *rest)`` when
            ``return_dict`` is False.

        Raises:
            ValueError: If both or neither of ``input_ids`` / ``inputs_embeds`` are given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        # only embed ids when they were given; otherwise keep the caller's inputs_embeds
        if input_ids is not None:
            inputs_embeds = self.model.get_input_embeddings()(input_ids)

        past_key_values_length = self._num_cached_positions(past_key_values)
        embeddings = []
        if past_key_values_length == 0:
            embeddings = self._interleave_image_and_action_embeds(inputs_embeds, actions)
        elif past_key_values_length % self.num_spatio_embeddings == (
            self.num_spatio_embeddings - self.num_action_embeddings
        ):
            # Only image tokens are generated, so action tokens for a frame are
            # inserted right after that frame's last image token has been cached.
            frame_index = past_key_values_length // self.num_spatio_embeddings
            action_chunks = torch.split(
                actions,
                split_size_or_sections=self.num_action_embeddings,
                dim=1
            )
            embeddings = [self.action_projection(action_chunks[frame_index]), inputs_embeds]
        if embeddings:
            inputs_embeds = torch.cat(embeddings, dim=1)

        # insert spatio-temporal positional embedding (out of place, so caller tensors are untouched)
        inputs_embeds = inputs_embeds + self.pos_embedding_spatio_temporal(inputs_embeds, past_key_values_length)

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        logits = self.lm_head(sequence_output).contiguous()

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
            self,
            input_ids,
            past_key_values=None,
            attention_mask=None,
            use_cache=None,
            progress_bar=None,
            **kwargs):
        """Build ``forward`` kwargs for one generation step.

        The incoming ``attention_mask`` is ignored. An all-ones mask is rebuilt to
        match the interleaved length that ``forward`` will feed to the decoder
        (image tokens plus inserted action tokens). ``progress_bar``, if given, is
        advanced by one on every call. ``actions`` is forwarded from ``kwargs``.
        """
        if progress_bar is not None:
            progress_bar.update()
        batch_size = input_ids.size(0)
        past_length = self._num_cached_positions(past_key_values)

        if past_length == 0:
            # prompt pass: each complete frame contributes image + action positions,
            # a trailing partial frame contributes only its image tokens
            n_frames, n_last_frame_tokens = divmod(input_ids.size(1), self.num_image_patches)
            mask_seq_length = (
                n_frames * (self.num_image_patches + self.num_action_embeddings) + n_last_frame_tokens
            )
        else:
            # cut input_ids to the not-yet-cached suffix
            if input_ids.size(1) > past_length:
                remove_prefix_length = past_length
            else:
                remove_prefix_length = input_ids.size(1) - 1
            input_ids = input_ids[:, remove_prefix_length:]
            mask_seq_length = input_ids.size(1) + past_length
            # forward() will prepend this frame's action tokens on this step
            if past_length % self.num_spatio_embeddings == (self.num_spatio_embeddings - self.num_action_embeddings):
                mask_seq_length += self.num_action_embeddings

        attention_mask = torch.ones((batch_size, mask_seq_length), device=input_ids.device, dtype=torch.long)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "actions": kwargs.get("actions"),
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }
class LlamaActionV2ForCausalLM(LlamaActionForCausalLM):
    """Variant of ``LlamaActionForCausalLM`` with a three-layer MLP action projection.

    The parent constructor creates a single ``nn.Linear`` projection. This class
    replaces it with Linear -> ReLU -> Linear -> ReLU -> Linear, mapping
    ``config.action_dim`` into ``config.hidden_size``.
    """

    config_class = LlamaActionConfig

    def __init__(self, config: LlamaActionConfig):
        super().__init__(config)
        width = config.hidden_size
        mlp_layers = [
            nn.Linear(config.action_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
        ]
        # Sequential(*layers) names the submodules "0".."4", the same as passing them inline
        self.action_projection = nn.Sequential(*mlp_layers)
        # post_init() runs again after the swap (presumably so the new layers are
        # initialized like the rest of the model)
        self.post_init()