# Attention and Block variants with an option to return the q, k, v tensors from
# attention (e.g. so the mean of k over heads can be computed downstream).
from typing import Tuple, Union

import torch
import torch.nn.functional as F
from timm.models.vision_transformer import Attention, Block
class AttentionWQKVReturn(Attention):
    """
    Modifications:
    - Return the q, k, v tensors from the attention in addition to the output.
    """

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        B, N, C = x.shape
        # qkv: (3, B, num_heads, N, head_dim) after the permute; unbind yields q, k, v.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale  # note: the returned q is already scaled in this branch
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        # Stack q, k, v into a single tensor of shape (3, B, num_heads, N, head_dim).
        return x, torch.stack((q, k, v), dim=0)
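

# Usage sketch (an added illustration, not part of the original module): it assumes a
# recent timm (>= 0.9), where Attention exposes head_dim, q_norm/k_norm and fused_attn,
# which the forward above relies on. The helper name below is hypothetical.
def _example_attention_qkv_shapes() -> None:
    attn = AttentionWQKVReturn(dim=64, num_heads=4)
    x = torch.randn(2, 16, 64)                   # (B, N, C)
    out, qkv = attn(x)
    assert out.shape == (2, 16, 64)              # same shape as the input tokens
    assert qkv.shape == (3, 2, 4, 16, 16)        # (q/k/v, B, num_heads, N, head_dim)
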
class BlockWQKVReturn(Block):
    """
    Modifications:
    - Use AttentionWQKVReturn instead of Attention
    - Optionally return the q, k, v tensors from the attention
    """

    def forward(self, x: torch.Tensor, return_qkv: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # Note: copied from timm.models.vision_transformer.Block with modifications.
        x_attn, qkv = self.attn(self.norm1(x))
        x = x + self.drop_path1(self.ls1(x_attn))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        if return_qkv:
            return x, qkv
        return x
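

# Usage sketch (added illustration, not from the original file): swap the stock timm
# classes for the variants above so q, k, v can be pulled from any block. The
# __class__ reassignment keeps the existing weights and only changes forward; it
# assumes timm >= 0.9 attribute names (model.blocks, blk.attn, ls1/ls2, drop_path1/2).
if __name__ == "__main__":
    import timm

    model = timm.create_model("vit_small_patch16_224", pretrained=False)
    for blk in model.blocks:
        blk.__class__ = BlockWQKVReturn            # plain forward(x) still works; qkv is opt-in
        blk.attn.__class__ = AttentionWQKVReturn

    tokens = torch.randn(1, 197, 384)              # (B, N, embed_dim): 196 patches + cls token
    x_out, qkv = model.blocks[-1](tokens, return_qkv=True)
    print(x_out.shape)                             # torch.Size([1, 197, 384])
    print(qkv.shape)                               # torch.Size([3, 1, 6, 197, 64])
    k_mean = qkv[1].mean(dim=1)                    # mean of k over heads: (B, N, head_dim)
    print(k_mean.shape)                            # torch.Size([1, 197, 64])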