LongNet: Scaling Transformers to 1,000,000,000 Tokens
Abstract
Scaling sequence length has become a critical demand in the era of large language models. However, existing methods struggle with either computational complexity or model expressivity, rendering the maximum sequence length restricted. In this work, we introduce LongNet, a Transformer variant that can scale sequence length to more than 1 billion tokens, without sacrificing the performance on shorter sequences. Specifically, we propose dilated attention, which expands the attentive field exponentially as the distance grows. LongNet has significant advantages: 1) it has a linear computation complexity and a logarithmic dependency between tokens; 2) it can serve as a distributed trainer for extremely long sequences; 3) its dilated attention is a drop-in replacement for standard attention, which can be seamlessly integrated with the existing Transformer-based optimization. Experimental results demonstrate that LongNet yields strong performance on both long-sequence modeling and general language tasks. Our work opens up new possibilities for modeling very long sequences, e.g., treating a whole corpus or even the entire Internet as a sequence.
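The core of dilated attention can be illustrated with a minimal, dependency-free sketch of a single attention pattern. It assumes one segment length w, one dilation rate r and a single head, and gives skipped positions a zero output; LongNet itself mixes several (w, r) pairs and shifts the dilation offset across heads, which this sketch leaves out. The function names are hypothetical and are not the TorchScale API.

```python
import math
from typing import List

Matrix = List[List[float]]


def dense_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Ordinary softmax(q k^T / sqrt(d)) v over the given rows."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        out.append([sum(e * vj[c] for e, vj in zip(exps, v)) / total
                    for c in range(len(v[0]))])
    return out


def dilated_attention(q: Matrix, k: Matrix, v: Matrix, w: int, r: int) -> Matrix:
    """One (segment length w, dilation rate r) pattern, hypothetical helper.

    The sequence is cut into segments of length w; inside each segment only
    every r-th position is kept and dense attention runs on those rows.
    Positions skipped by the dilation get a zero output in this sketch.
    """
    n = len(q)
    assert n % w == 0, "sequence length must be a multiple of w"
    out = [[0.0] * len(v[0]) for _ in range(n)]
    for start in range(0, n, w):
        idx = list(range(start, start + w, r))  # sparsified positions of this segment
        seg = dense_attention([q[i] for i in idx], [k[i] for i in idx], [v[i] for i in idx])
        for i, row in zip(idx, seg):
            out[i] = row
    return out


def attention_cost(n: int, w: int, r: int) -> int:
    """Query-key scores computed: (n / w) segments, each with (w / r)^2 scores."""
    per_segment = len(range(0, w, r))
    return (n // w) * per_segment ** 2


q = k = v = [[float(i), 1.0] for i in range(8)]
# with r = 1 and a single segment covering the sequence, this is dense attention
assert dilated_attention(q, k, v, w=8, r=1) == dense_attention(q, k, v)
print(attention_cost(1024, 64, 2), attention_cost(2048, 64, 2))  # 16384 32768
```

The score count is (N / w) · (w / r)^2 = N·w / r^2, so for fixed w and r the cost grows linearly in N (doubling N from 1024 to 2048 doubles it), whereas dense attention grows as N^2.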
Community
They are literally taking all the tricks that vision ppl used on ViT and re-publishing them. When are they going to publish something like Swin-LLM?
Good points. I think the next one will be Deformable Masked Attention
“Different from vanilla attention, both sizes of K and V are independent of the sequence length N, making the communication cost constant.”
This sentence is doubtful. Show the proof that K_i and V_i are independent of the sequence length N. These tensors' sizes are still related to the sequence length even after dilation.
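What the quoted claim depends on can be seen with a back-of-the-envelope count. This sketch assumes, as a simplification of the distributed scheme the quote refers to, that the sequence is split evenly across devices (local_len tokens each), that a segment of length w spans w / local_len devices, and that each device keeps every r-th key/value row before the all-gather. Both helper names are hypothetical.

```python
def sparsified_kv_rows(local_len: int, r: int) -> int:
    """Key (or value) rows one device keeps: every r-th of its local positions."""
    return len(range(0, local_len, r))


def gathered_kv_rows(w: int, local_len: int, r: int) -> int:
    """Key rows after all-gathering one segment of length w spread over w // local_len devices."""
    assert w % local_len == 0, "assume a segment covers a whole number of devices"
    return (w // local_len) * sparsified_kv_rows(local_len, r)


# Scenario A: 4 devices, fixed r = 4, the longest segment spans the whole sequence (w = N).
for n in (2**20, 2**22):
    print("A", n, gathered_kv_rows(w=n, local_len=n // 4, r=4))  # 262144, then 1048576

# Scenario B: fixed 65536 tokens per device (more devices for longer N), r grows with w (w / r = 2048).
for n in (2**20, 2**22):
    print("B", n, gathered_kv_rows(w=n, local_len=2**16, r=n // 2048))  # 2048 both times
```

Under these assumptions the gathered size is w / r and each device's share is local_len / r. Neither is independent of N by itself (scenario A), but both stay constant when the tokens per device and the ratio w / r are held fixed as N grows (scenario B).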
We should have a dislike button too... don't you think?
Is there any model available?
This has been happening for quite some time now: using NLP in CV and CV in NLP.
At the end of the day, it's the math.
Hey, I'm reviewing deep learning papers daily on Twitter in Hebrew under the hashtag #shorthebrewpapereviews (https://twitter.com/hashtag/shorthebrewpapereviews?src=hashtag_click). So far I've briefly reviewed about deep learning papers. You are invited to follow and comment.
This paper review can be found at: https://twitter.com/MikeE_3_14/status/1676988738377744388?s=20
They are literally taking all the tricks that vision ppl used on ViT and re-publishing them. When are they going to publish something like Swin-LLM?
Is this so bad, as long as they cite CV papers? It's... arguably... how science ought to work?
No matter the source of their inspiration (DeepMind always does this...), we want it on Hugging Face ASAP!
Please give me the model.
Is "LongNet: Scaling Transformers to 1,000,000,000 Tokens" something like this?
```python
import torch
import torch.nn as nn


class DilatedMHAttention(nn.Module):
    """Simplified multi-head dilated attention.

    For each (segment_length, dilation_rate) pair the sequence is split into
    segments, every dilation_rate-th position of each segment is kept, and
    dense attention runs inside each sparsified segment. The per-pair outputs
    are simply averaged here; the paper instead weights them by their softmax
    denominators and shifts the dilation offset per head (both omitted).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False,
                 segment_lengths=(4,), dilation_rates=(1,)):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        assert len(segment_lengths) == len(dilation_rates)
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.pairs = list(zip(segment_lengths, dilation_rates))

    def _attend(self, q, k, v, w, r):
        # q, k, v: (batch, heads, seq_len, head_dim)
        b, h, n, d = q.shape
        assert n % w == 0, "sequence length must be a multiple of every segment length"
        # split into segments and keep every r-th position: (b, h, n // w, ceil(w / r), d)
        qs, ks, vs = (t.reshape(b, h, n // w, w, d)[..., ::r, :] for t in (q, k, v))
        attn = (qs @ ks.transpose(-2, -1)) * d ** -0.5
        seg_out = attn.softmax(dim=-1) @ vs
        # scatter back to the original positions; skipped positions stay zero
        out = q.new_zeros((b, h, n // w, w, d))
        out[..., ::r, :] = seg_out
        return out.reshape(b, h, n, d)

    def forward(self, x):
        # x: (batch, seq_len, dim) -> q, k, v: (batch, heads, seq_len, head_dim)
        b, n, _ = x.shape
        q, k, v = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        out = torch.stack([self._attend(q, k, v, w, r) for w, r in self.pairs]).mean(dim=0)
        return self.proj(out.transpose(1, 2).reshape(b, n, -1))


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    # pre-norm residual block: x + attn(norm(x)), then x + ff(norm(x))
    def __init__(self, dim, heads, mlp_dim, segment_lengths, dilation_rates):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = DilatedMHAttention(dim, heads, segment_lengths=segment_lengths,
                                       dilation_rates=dilation_rates)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_dim)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        return x + self.ff(self.norm2(x))


class LongNet(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim, num_classes,
                 segment_lengths=(4,), dilation_rates=(1,)):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block(dim, heads, mlp_dim, segment_lengths, dilation_rates)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(dim)
        self.classifier = nn.Linear(dim, num_classes)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.norm(x).mean(dim=1)  # mean-pool over the sequence
        return self.classifier(x)


model = LongNet(dim=64, depth=2, heads=4, mlp_dim=128, num_classes=10,
                segment_lengths=(8, 16), dilation_rates=(1, 2))
print(model(torch.randn(2, 32, 64)).shape)  # torch.Size([2, 10])
```
The official Microsoft implementation is in the TorchScale repo (no pretrained checkpoints that I know of, you have to train it yourself). https://github.com/Microsoft/TorchScale
Transforming AI: How LongNet Handles A Billion Tokens Effortlessly!
Links:
Subscribe: https://www.youtube.com/@Arxflix
Twitter: https://x.com/arxflix
LMNT (Partner): https://lmnt.com/