import torch

from typing import Union, Tuple, List


def _to_tuple(x, dim=2):
    # Broadcast an int to a `dim`-tuple; pass through sequences of length `dim`.
    if isinstance(x, int):
        return (x,) * dim
    elif len(x) == dim:
        return x
    else:
        raise ValueError(f"Expected length {dim} or int, but got {x}")


def get_meshgrid_nd(start, *args, dim=2):
    """
    Get an n-D meshgrid from start, stop and num.

    Args:
        start (int or tuple): If len(args) == 0, start is num; if len(args) == 1, start is start and args[0]
            is stop, with step 1; if len(args) == 2, start is start, args[0] is stop and args[1] is num.
            For n dims, start/stop/num should each be an int or an n-tuple. If n-tuples are provided, the
            meshgrid is stacked following the dim order of the n-tuples.
        *args: See above.
        dim (int): Dimension of the meshgrid. Defaults to 2.

    Returns:
        grid (torch.Tensor): [dim, ...]
    """
    if len(args) == 0:
        # Only `num` was given: sample num[i] integer points in [0, num[i]) per axis.
        num = _to_tuple(start, dim=dim)
        start = (0,) * dim
        stop = num
    elif len(args) == 1:
        # `start` and `stop` were given: one point per integer step.
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = [stop[i] - start[i] for i in range(dim)]
    elif len(args) == 2:
        # `start`, `stop` and `num` were all given.
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = _to_tuple(args[1], dim=dim)
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    axis_grid = []
    for i in range(dim):
        a, b, n = start[i], stop[i], num[i]
        # n evenly spaced samples in [a, b), excluding the endpoint b.
        g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
        axis_grid.append(g)
    grid = torch.meshgrid(*axis_grid, indexing="ij")  # dim tensors, each [num[0], ..., num[dim-1]]
    grid = torch.stack(grid, dim=0)  # [dim, num[0], ..., num[dim-1]]

    return grid


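# Illustrative sketch (not part of the original module): the three calling
# conventions of get_meshgrid_nd, with the resulting [dim, ...] shapes noted.
def _demo_meshgrid_nd():
    g = get_meshgrid_nd((2, 3))           # num only -> [2, 2, 3]
    g = get_meshgrid_nd((1, 0), (3, 4))   # start/stop -> [2, 2, 4]
    g = get_meshgrid_nd(0, 1, 8, dim=2)   # start/stop/num -> [2, 8, 8]
    return g

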
def reshape_for_broadcast(
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    x: torch.Tensor,
    head_first=False,
):
    """
    Reshape a frequency tensor for broadcasting with another tensor.

    This function reshapes the frequency tensor to have the same number of dimensions as the target
    tensor 'x' so that the frequency tensor can be broadcast during element-wise operations.

    Notes:
        When using FlashMHAModified, head_first should be False.
        When using Attention, head_first should be True.

    Args:
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.
        head_first (bool): Whether the head dimension comes first (after the batch dim) or not.

    Returns:
        torch.Tensor or Tuple[torch.Tensor, torch.Tensor]: Reshaped frequency tensor(s).

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert ndim > 1

    if isinstance(freqs_cis, tuple):
        # (cos, sin) pair: reshape both tensors identically.
        if head_first:
            # x: [B, H, S, D]; freqs must match the last two dims.
            assert freqs_cis[0].shape == (
                x.shape[-2],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            shape = [
                d if i == ndim - 2 or i == ndim - 1 else 1
                for i, d in enumerate(x.shape)
            ]
        else:
            # x: [B, S, H, D]; freqs must match the sequence and feature dims.
            assert freqs_cis[0].shape == (
                x.shape[1],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
    else:
        # Single (complex) frequency tensor.
        if head_first:
            assert freqs_cis.shape == (
                x.shape[-2],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
            shape = [
                d if i == ndim - 2 or i == ndim - 1 else 1
                for i, d in enumerate(x.shape)
            ]
        else:
            assert freqs_cis.shape == (
                x.shape[1],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)


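# Illustrative sketch (assumed shapes, not from the original module): with
# head_first=False and x of shape [B, S, H, D], a [S, D] frequency tensor is
# reshaped to [1, S, 1, D] so it broadcasts over batch and heads.
def _demo_reshape_for_broadcast():
    x = torch.randn(2, 16, 4, 64)    # [B, S, H, D]
    freqs = torch.randn(16, 64)      # [S, D]
    out = reshape_for_broadcast(freqs, x, head_first=False)
    assert out.shape == (1, 16, 1, 64)

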
def rotate_half(x):
    # Treat the last dim as interleaved (real, imag) pairs and rotate each pair
    # (x0, x1) -> (-x1, x0); the flatten(3) assumes a 4-D [B, S, H, D] input.
    x_real, x_imag = (
        x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
    )
    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)


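# Illustrative sketch: on a tiny [1, 1, 1, 4] tensor, rotate_half maps
# (1, 2, 3, 4) to (-2, 1, -4, 3), i.e. multiplication by i when adjacent
# pairs are read as complex numbers.
def _demo_rotate_half():
    x = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])
    assert torch.equal(rotate_half(x), torch.tensor([[[[-2.0, 1.0, -4.0, 3.0]]]]))

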
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
        xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
        freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponentials.
        head_first (bool): Whether the head dimension comes first (after the batch dim) or not.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    if isinstance(freqs_cis, tuple):
        # Real-valued path: freqs_cis is a (cos, sin) pair, each [S, D].
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        # Rotation in real form: x * cos + rotate_half(x) * sin.
        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
        xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
    else:
        # Complex path: view the last dim as D/2 complex numbers and multiply
        # by the unit-modulus frequencies.
        xq_ = torch.view_as_complex(
            xq.float().reshape(*xq.shape[:-1], -1, 2)
        )  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
            xq.device
        )  # [S, D//2] -> [1, S, 1, D//2]
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        xk_ = torch.view_as_complex(
            xk.float().reshape(*xk.shape[:-1], -1, 2)
        )  # [B, S, H, D//2]
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    return xq_out, xk_out


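# Illustrative sketch (assumed sizes, not from the original module): the real
# (cos, sin) path and the complex path should produce the same rotated tensors.
def _demo_apply_rotary_emb():
    B, S, H, D = 2, 16, 4, 64
    xq, xk = torch.randn(B, S, H, D), torch.randn(B, S, H, D)
    freqs_cis = get_1d_rotary_pos_embed(D, S)                # [S, D/2] complex
    q1, k1 = apply_rotary_emb(xq, xk, freqs_cis)
    cos_sin = get_1d_rotary_pos_embed(D, S, use_real=True)   # ([S, D], [S, D])
    q2, k2 = apply_rotary_emb(xq, xk, cos_sin)
    assert torch.allclose(q1, q2, atol=1e-5) and torch.allclose(k1, k2, atol=1e-5)

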
def get_nd_rotary_pos_embed(
    rope_dim_list,
    start,
    *args,
    theta=10000.0,
    use_real=False,
    theta_rescale_factor: Union[float, List[float]] = 1.0,
    interpolation_factor: Union[float, List[float]] = 1.0,
):
    """
    This is an n-d version of precompute_freqs_cis, i.e. RoPE for tokens with an n-d structure.

    Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal n.
            sum(rope_dim_list) should equal the head_dim of the attention layer.
        start (int | tuple of int | list of int): If len(args) == 0, start is num; if len(args) == 1, start
            is start and args[0] is stop, with step 1; if len(args) == 2, start is start, args[0] is stop
            and args[1] is num.
        *args: See above.
        theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool): If True, return the real and imaginary parts separately. Otherwise, return complex
            numbers. Some libraries such as TensorRT do not support the complex64 data type, so it is useful
            to be able to provide the real and imaginary parts separately.
        theta_rescale_factor (float or list of float): Rescale factor for theta, per axis. Defaults to 1.0.
        interpolation_factor (float or list of float): Factor by which positions are scaled (position
            interpolation), per axis. Defaults to 1.0.

    Returns:
        pos_embed (torch.Tensor): [HW, D/2] if use_real is False, else a (cos, sin) pair, each of shape [HW, D].
    """
    grid = get_meshgrid_nd(
        start, *args, dim=len(rope_dim_list)
    )  # [n, *size]

    # Broadcast scalar rescale/interpolation factors to one value per axis.
    if isinstance(theta_rescale_factor, (int, float)):
        theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
    assert len(theta_rescale_factor) == len(
        rope_dim_list
    ), "len(theta_rescale_factor) should equal len(rope_dim_list)"

    if isinstance(interpolation_factor, (int, float)):
        interpolation_factor = [interpolation_factor] * len(rope_dim_list)
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
    assert len(interpolation_factor) == len(
        rope_dim_list
    ), "len(interpolation_factor) should equal len(rope_dim_list)"

    # Build a 1-D RoPE per axis over the flattened grid, then concatenate
    # along the feature dimension.
    embs = []
    for i in range(len(rope_dim_list)):
        emb = get_1d_rotary_pos_embed(
            rope_dim_list[i],
            grid[i].reshape(-1),
            theta,
            use_real=use_real,
            theta_rescale_factor=theta_rescale_factor[i],
            interpolation_factor=interpolation_factor[i],
        )
        embs.append(emb)

    if use_real:
        cos = torch.cat([emb[0] for emb in embs], dim=1)  # [HW, D]
        sin = torch.cat([emb[1] for emb in embs], dim=1)  # [HW, D]
        return cos, sin
    else:
        emb = torch.cat(embs, dim=1)  # [HW, D/2]
        return emb


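# Illustrative sketch (assumed sizes, not from the original module): a 3-D RoPE
# over a (T, H, W) = (4, 8, 8) token grid with head_dim = 128 split as
# 32/48/48 across the temporal and two spatial axes.
def _demo_nd_rope():
    rope_dim_list = [32, 48, 48]  # sums to head_dim = 128
    cos, sin = get_nd_rotary_pos_embed(rope_dim_list, (4, 8, 8), use_real=True)
    assert cos.shape == (4 * 8 * 8, 128) and sin.shape == cos.shape
    return cos, sin

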
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.FloatTensor, int],
    theta: float = 10000.0,
    use_real: bool = False,
    theta_rescale_factor: float = 1.0,
    interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Precompute the frequency tensor for complex exponentials (cis) with the given dimensions.
    (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and position indices 'pos'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in the complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool, optional): If True, return the real and imaginary parts separately.
            Otherwise, return complex numbers.
        theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
        interpolation_factor (float, optional): Factor by which positions are scaled (position
            interpolation). Defaults to 1.0.

    Returns:
        freqs_cis: Precomputed frequency tensor with complex exponentials. [S, D/2]
        freqs_cos, freqs_sin: Precomputed frequency tensors with real and imaginary parts separately. [S, D]
    """
    if isinstance(pos, int):
        pos = torch.arange(pos).float()

    # Rescale theta (akin to NTK-aware RoPE scaling) to stretch the usable
    # position range without fine-tuning.
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor ** (dim / (dim - 2))

    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
    )  # [D/2]
    # Outer product of (optionally interpolated) positions and per-channel frequencies.
    freqs = torch.outer(pos * interpolation_factor, freqs)  # [S, D/2]
    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
        return freqs_cos, freqs_sin
    else:
        freqs_cis = torch.polar(
            torch.ones_like(freqs), freqs
        )  # complex64, [S, D/2]
        return freqs_cis


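# Illustrative sketch: the complex and real outputs encode the same angles;
# the real variant repeats each cos/sin value twice along the feature dim.
def _demo_1d_rope():
    freqs_cis = get_1d_rotary_pos_embed(8, 4)                # [4, 4] complex64
    cos, sin = get_1d_rotary_pos_embed(8, 4, use_real=True)  # each [4, 8]
    assert torch.allclose(freqs_cis.real.repeat_interleave(2, dim=1), cos)
    assert torch.allclose(freqs_cis.imag.repeat_interleave(2, dim=1), sin)
    return freqs_cis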