|
""" |
|
Originally taken verbatim from the xformers library:
|
https://github.com/facebookresearch/xformers/blob/bcb707576c6a80eaf850aa80e8643d3497ec2bc4/xformers/components/positional_embedding/rotary.py |
|
|
|
The difference is that xformers seems to assume the inputs to be |
|
(bs, head, seq_len, dim) while we assume (bs, seq_len, head, dim) |
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
from typing import List, Optional, Tuple, Dict, Union |
|
|
|
import torch |
|
import dataclasses |
|
from transformers.utils import logging |
|
|
|
from transformers import PretrainedConfig |
|
|
|
is_dacite_available = False |
|
try: |
|
import dacite |
|
is_dacite_available = True |
|
except ImportError: |
|
pass |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
@dataclasses.dataclass
class LongRopeConfig:
    """Settings for LongRoPE-style rescaling of rotary frequencies.

    ``short_factor`` and ``long_factor`` are lists of per-frequency rescale
    factors, ``original_max_position_embeddings`` is an integer length, and
    ``type`` must be either ``"longrope"`` or ``"su"``.  The two ``*_mscale``
    fields default to ``-1``.
    """

    short_factor: List[float]
    long_factor: List[float]
    original_max_position_embeddings: int
    type: str = "longrope"
    short_mscale: float = -1
    long_mscale: float = -1

    def __post_init__(self):
        """Validate ``type`` right after the dataclass-generated ``__init__``."""
        allowed_types = ("longrope", "su")
        assert self.type in allowed_types, f"Invalid type {self.type} for LongRopeConfig. Expected longrope / su"

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Union[float, List[float], int]]) -> "LongRopeConfig":
        """Build a config from a plain dictionary.

        When ``dacite`` is importable the conversion is delegated to
        ``dacite.from_dict``.  Otherwise the dictionary is matched against the
        dataclass fields by hand.

        Raises:
            ValueError: (manual path only) if a supplied field is not an init
                field, a field without a default is missing, or the dictionary
                contains keys that are not fields of this dataclass.
        """
        if is_dacite_available:
            return dacite.from_dict(data_class=cls, data=config_dict)

        init_kwargs = {}
        for spec in dataclasses.fields(cls):
            name = spec.name
            if name not in config_dict:
                # An absent key is acceptable only when the field has a default.
                if spec.default is dataclasses.MISSING:
                    raise ValueError(f"Field {name} is required")
                continue
            if not spec.init:
                raise ValueError(f"Field {name} is not initiable")
            init_kwargs[name] = config_dict[name]

        # Whatever was supplied but not consumed above is an unknown key.
        extra_keys = set(config_dict.keys()) - set(init_kwargs.keys())
        if extra_keys:
            for unknown in extra_keys:
                logger.error(f"Unrecognized key {unknown} in config_dict")
            raise ValueError("Unrecognized keys in config_dict")
        return cls(**init_kwargs)
|
|
|
def rotate_half(x):
    """Return ``x`` with the halves of its last dimension swapped and the second half negated.

    The last dimension is split at ``x.shape[-1] // 2``; the result is the
    concatenation ``(-second_part, first_part)`` along the last dimension.
    """
    split = x.shape[-1] // 2
    front = x[..., :split]
    back = x[..., split:]
    return torch.cat((-back, front), dim=front.ndim - 1)
|
|
|
|
|
|
|
@torch.jit.script
def apply_rotary_pos_emb(x, cos, sin, seq_dimension: int):
    """Apply rotary position embedding to ``x``.

    ``cos`` and ``sin`` are indexed as 2-D tables whose first axis is the
    position.  They are truncated to the length of ``x`` along
    ``seq_dimension`` and given singleton axes so that they broadcast against
    a 4-D ``x``.  For a ``seq_dimension`` other than 0, 1 or 2 the tables are
    used as passed in.

    Returns ``x * cos + rotate_half(x) * sin``.
    """
    if seq_dimension == 0:
        length = x.shape[0]
        cos = cos[:length, None, None, :]
        sin = sin[:length, None, None, :]
    elif seq_dimension == 1:
        length = x.shape[1]
        cos = cos[None, :length, None, :]
        sin = sin[None, :length, None, :]
    elif seq_dimension == 2:
        length = x.shape[2]
        cos = cos[None, None, :length, :]
        sin = sin[None, None, :length, :]

    rotated = rotate_half(x)
    return (x * cos) + (rotated * sin)
|
|
|
|
|
|
|
class RotaryEmbedding(torch.nn.Module):
    """
    Adapted from the xformers library

    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.
    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
        (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis

    Two modes are implemented:

    * Plain RoPE (``longrope_config is None``): cos/sin tables are cached as
      non-persistent buffers and regrown in ``forward`` when a longer sequence
      is seen.
    * LongRoPE (``longrope_config`` given): cos/sin tables are recomputed on
      every ``forward`` call from the short or long rescale factors, depending
      on whether the sequence exceeds ``original_max_position_embeddings``.
      ``position_scale`` is not used on this path.

    # Arguments
    :param dim_model: head dimension
    :param max_seq_len: maximum sequence length. Required for LongRoPE. For
        plain RoPE, ``None`` starts with an empty cache that grows on demand.
    :param dtype: cos/sin dtype of the initial plain-RoPE cache
    :param base: base of the inverse-frequency geometric progression
    :param position_scale: multiplier applied to position indices (plain RoPE)
    :param device: device on which the initial buffers are created
    :param longrope_config: LongRoPE settings; enables LongRoPE mode when set
    """

    def __init__(
        self,
        dim_model: int,
        *,
        max_seq_len: Optional[int] = None,
        dtype: Optional[torch.dtype] = None,
        base=10000,
        position_scale=1,
        device: Optional[torch.device] = None,
        longrope_config: Optional["LongRopeConfig"] = None,
    ):
        super().__init__()
        self.base = base
        self.dim_model = dim_model
        self.max_seq_len = max_seq_len
        self.longrope_config = longrope_config
        self.position_scale = position_scale

        if self.is_longrope:
            if max_seq_len is None:
                # The position buffer below needs a concrete length; fail with
                # a readable message instead of an error inside torch.arange.
                raise ValueError("max_seq_len must be provided when longrope_config is set")
            # Position indices 0 .. max_seq_len - 1, sliced per call in
            # get_range_vector.
            self.register_buffer(
                "range_vector",
                torch.arange(max_seq_len, device=device, dtype=torch.float32),
                persistent=False,
            )
            self.register_buffer(
                "short_factors",
                torch.tensor(self.longrope_config.short_factor, dtype=torch.float32),
                persistent=False,
            )
            self.register_buffer(
                "long_factors",
                torch.tensor(self.longrope_config.long_factor, dtype=torch.float32),
                persistent=False,
            )
        else:
            inv_freq = 1.0 / (base ** (torch.arange(0, dim_model, 2).float().to(device) / self.dim_model))
            self.register_buffer("inv_freq", inv_freq)

            dtype = dtype or torch.get_default_dtype()
            self._set_cos_sin_cache(
                # With no max_seq_len start from an empty cache; forward()
                # rebuilds it as soon as a sequence is seen.
                seq_len=max_seq_len if max_seq_len is not None else 0,
                device=self.inv_freq.device,
                dtype=dtype,
            )

    @property
    def is_longrope(self):
        """Whether a LongRoPE configuration was supplied."""
        return self.longrope_config is not None

    @property
    def original_max_seq_len(self):
        """Original maximum length from the LongRoPE config.

        Falls back to ``max_seq_len`` (after logging a one-time warning) when
        no LongRoPE config is set.
        """
        if self.longrope_config is not None:
            return self.longrope_config.original_max_position_embeddings
        logger.warning_once(
            (
                "``original_max_seq_len'' is being accessed, but longrope_config has not been set. "
                "Please only do this if you are sure about the context."
            )
        )
        return self.max_seq_len

    def get_range_vector(self, seq_len: int, device: torch.device):
        """Return float32 position indices ``0 .. seq_len - 1`` on ``device``.

        In LongRoPE mode the indices are a slice of the pre-built
        ``range_vector`` buffer, which is moved to ``device`` if necessary;
        otherwise a fresh ``torch.arange`` is created.

        Raises:
            AssertionError: in LongRoPE mode, if ``seq_len`` exceeds the length
                of the pre-built buffer.
        """
        if self.is_longrope:
            # The buffer holds exactly max_seq_len positions, so a request for
            # seq_len == max_seq_len is valid; only longer requests are errors.
            assert seq_len <= self.range_vector.shape[0], f"Found seq_len {seq_len} greater than max_seq_len {self.range_vector.shape[0]}"
            if self.range_vector.device != device:
                self.range_vector = self.range_vector.to(device)
            return self.range_vector[:seq_len]
        return torch.arange(seq_len, device=device, dtype=torch.float32)

    def _calc_mscale(self, scale: float) -> float:
        """Return the magnitude scale for a context-extension ratio ``scale``.

        Yields ``1.0`` for ``scale <= 1`` and
        ``sqrt(1 + log(scale) / log(original_max_seq_len))`` otherwise.
        """
        if scale <= 1.0:
            return 1.0
        return math.sqrt(1 + math.log(scale) / math.log(self.original_max_seq_len))

    @staticmethod
    def _autocast_device_type(device: Optional[torch.device]) -> str:
        """Pick the ``device_type`` string passed to ``torch.autocast``.

        Uses ``device.type``, falling back to ``"cpu"`` when the device is
        ``None``, its type is not a string, or it is ``"mps"``.
        """
        device_type = device.type if device is not None else "cpu"
        if isinstance(device_type, str) and device_type != "mps":
            return device_type
        return "cpu"

    def _set_cos_sin_cache(
        self,
        seq_len: int,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """(Re)build the plain-RoPE cos/sin tables for ``seq_len`` positions.

        Stores ``cos_cached`` / ``sin_cached`` as non-persistent buffers of
        shape ``(seq_len, 2 * len(inv_freq))`` in ``dtype`` and records
        ``max_seq_len_cached``.
        """
        dtype = dtype or torch.get_default_dtype()
        self.max_seq_len_cached = seq_len
        t = (torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32) * self.position_scale).type_as(self.inv_freq)
        # Autocast is disabled so the angle tables are computed in the dtype
        # of inv_freq before the final cast below.
        with torch.autocast(device_type=self._autocast_device_type(device), enabled=False):
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        self.register_buffer("cos_cached", cos.to(dtype), persistent=False)
        self.register_buffer("sin_cached", sin.to(dtype), persistent=False)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_dimension: int = 1,
        seqlen_offset: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """q, k does not include `seqlen_offset`
        q: Either (bs, seq_len, num_heads, head_dim) or (seq_len, bs, num_heads, head_dim)
        k: Either (bs, seq_len, num_heads, head_dim) or (seq_len, bs, num_heads, head_dim)

        :param seq_dimension: axis of ``q``/``k`` holding the sequence; negative
            values count from the end; must resolve to 0, 1 or 2
        :param seqlen_offset: number of positions preceding ``q``/``k``; the
            cos/sin rows ``[seqlen_offset : seqlen_offset + seq_len)`` are used
        :return: the rotated ``(q, k)``, each cast back to its input dtype
        """
        if seq_dimension < 0:
            seq_dimension = k.ndim + seq_dimension
        assert seq_dimension in (0, 1, 2)
        seq_len = k.shape[seq_dimension] + seqlen_offset

        if self.is_longrope:
            if seq_len > self.original_max_seq_len:
                # Beyond the original context: long factors, positions up to seq_len.
                t = self.get_range_vector(seq_len, device=q.device)
                rescale_factors = self.long_factors.to(q.device)
                long_mscale = self.longrope_config.long_mscale
                mscale = long_mscale if long_mscale > 0 else self._calc_mscale(self.max_seq_len / self.original_max_seq_len)
            else:
                # Within the original context: short factors over the full original range.
                t = self.get_range_vector(self.original_max_seq_len, device=q.device)
                rescale_factors = self.short_factors.to(q.device)
                short_mscale = self.longrope_config.short_mscale
                mscale = short_mscale if short_mscale > 0 else 1.0
            assert rescale_factors.shape == (self.dim_model // 2, ), (
                f"misaligned shape for LongRoPE rescale factors:\n"
                f"\tExpected {(self.dim_model // 2, )}, got {rescale_factors.shape}."
            )
            inv_freq = 1.0 / (rescale_factors * (self.base ** (torch.arange(0, self.dim_model, 2).float().to(q.device) / self.dim_model)))
            with torch.autocast(device_type=self._autocast_device_type(q.device), enabled=False):
                freqs = torch.outer(t, inv_freq)
                emb = torch.cat((freqs, freqs), dim=-1)
                cos = emb.cos() * mscale
                sin = emb.sin() * mscale
            cos_cached = cos.to(q.dtype)
            sin_cached = sin.to(q.dtype)
        else:
            if seq_len > self.max_seq_len_cached:
                # Grow the cache to cover the longer sequence.
                self._set_cos_sin_cache(
                    seq_len=seq_len,
                    device=k.device,
                    dtype=k.dtype,
                )
            cos_cached = self.cos_cached
            sin_cached = self.sin_cached

        # Both q and k use the same window of positions.
        cos_window = cos_cached[seqlen_offset:seq_len]
        sin_window = sin_cached[seqlen_offset:seq_len]
        return (
            apply_rotary_pos_emb(
                q, cos_window, sin_window, seq_dimension=seq_dimension
            ).to(q.dtype),
            apply_rotary_pos_emb(
                k, cos_window, sin_window, seq_dimension=seq_dimension
            ).to(k.dtype),
        )

    @classmethod
    def from_config(cls, config: "PretrainedConfig") -> "RotaryEmbedding":
        """Build the embedding from a model config.

        Reads ``hidden_size``, ``num_attention_heads``,
        ``max_position_embeddings``, ``rope_embedding_base``,
        ``rope_position_scale`` and ``rope_scaling`` from ``config``; a
        non-``None`` ``rope_scaling`` dict is parsed into a ``LongRopeConfig``.
        """
        kwargs = dict(
            dim_model=config.hidden_size // config.num_attention_heads,
            max_seq_len=config.max_position_embeddings,
            base=config.rope_embedding_base,
            position_scale=config.rope_position_scale,
        )
        if config.rope_scaling is not None:
            kwargs["longrope_config"] = LongRopeConfig.from_dict(config.rope_scaling)
        return cls(**kwargs)