import math
from functools import reduce
import torch
import numpy as np
import os
from pathlib import Path
def factors(n):
    """Return all factors of n as a list, collecting the pairs (i, n // i) for every divisor i up to sqrt(n)."""
    return reduce(list.__add__,
                  ([i, n // i] for i in range(1, int(n ** 0.5) + 1) if n % i == 0))
def file_line_count(filename: str) -> int:
    """Count the number of lines in a file."""
    with open(filename, 'rb') as f:
        return sum(1 for _ in f)
def compute_attention(qkv, scale=None):
    """
    Compute the attention weights and output (same computation as PyTorch's scaled dot product attention,
    without masking or dropout).
    Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
    :param qkv: Query, key and value tensors, either stacked along the first dimension or given as a tuple/list
    :param scale: Scale factor for the attention computation; defaults to 1 / sqrt(head_dim)
    :return:
        attn_weight: Attention weights of shape (..., L, S)
        attn_out: Attention output of shape (..., L, head_dim)
    """
    if isinstance(qkv, torch.Tensor):
        query, key, value = qkv.unbind(0)
    else:
        query, key, value = qkv
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    L, S = query.size(-2), key.size(-2)
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_out = attn_weight @ value
    return attn_weight, attn_out
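# Illustrative usage sketch for compute_attention (not part of the original utilities).
# The batch size, number of heads, sequence length and head dimension below are
# assumptions chosen only to show the expected shapes.
def _example_compute_attention():
    qkv = torch.randn(3, 2, 8, 16, 64)  # (q/k/v, batch, heads, seq_len, head_dim)
    attn_weight, attn_out = compute_attention(qkv)
    # attn_weight: (2, 8, 16, 16), each row sums to 1; attn_out: (2, 8, 16, 64)
    return attn_weight, attn_out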
def compute_dot_product_similarity(a, b):
    """Compute pairwise dot-product similarity scores between the rows of a and b."""
    scores = a @ b.transpose(-1, -2)
    return scores
def compute_cross_entropy(p, q):
    """Compute the cross-entropy between a target distribution p and logits q, averaged over the batch."""
    q = torch.nn.functional.log_softmax(q, dim=-1)
    loss = torch.sum(p * q, dim=-1)
    return -loss.mean()
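# Illustrative usage sketch for compute_cross_entropy (not part of the original utilities).
# p is assumed to be a soft target distribution and q raw logits, since the function applies
# log_softmax to q internally; the class count of 10 is an arbitrary assumption.
def _example_compute_cross_entropy():
    p = torch.softmax(torch.randn(4, 10), dim=-1)  # assumed soft labels over 10 classes
    q = torch.randn(4, 10)                         # assumed raw logits
    return compute_cross_entropy(p, q)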
def rollout(attentions, discard_ratio=0.9, head_fusion="max", device=torch.device("cuda")):
    """
    Perform attention rollout.
    Ref: https://github.com/jacobgil/vit-explain/blob/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/vit_rollout.py#L9C1-L42C16
    Parameters
    ----------
    attentions : list
        List of attention matrices, one for each transformer layer
    discard_ratio : float
        Ratio of lowest attention values to discard
    head_fusion : str
        Type of fusion to use for attention heads. One of "mean", "max", "min"
    device : torch.device
        Device to use for computation
    Returns
    -------
    result : torch.Tensor
        Rollout matrix of shape (batch, num_tokens, num_tokens), with each row normalized by its maximum value
    """
    result = torch.eye(attentions[0].size(-1), device=device)
    attentions = [attention.to(device) for attention in attentions]
    with torch.no_grad():
        for attention in attentions:
            # Fuse the attention heads of each layer into a single attention map
            if head_fusion == "mean":
                attention_heads_fused = attention.mean(axis=1)
            elif head_fusion == "max":
                attention_heads_fused = attention.max(axis=1).values
            elif head_fusion == "min":
                attention_heads_fused = attention.min(axis=1).values
            else:
                raise ValueError("Attention head fusion type not supported")
            # Drop the lowest attentions, but
            # don't drop the class token
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            indices = indices[indices != 0]
            flat[0, indices] = 0
            # Account for residual connections by averaging with the identity,
            # then re-normalize each row to sum to 1
            I = torch.eye(attention_heads_fused.size(-1), device=device)
            a = (attention_heads_fused + 1.0 * I) / 2
            a = a / a.sum(dim=-1, keepdim=True)
            result = torch.matmul(a, result)
    # Normalize the result by the max value in each row
    result = result / result.max(dim=-1, keepdim=True)[0]
    return result
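# Illustrative usage sketch for rollout (not part of the original utilities). A real call
# would collect one attention matrix per transformer layer, e.g. via forward hooks on a ViT;
# here random row-stochastic tensors with an assumed shape (batch=1, heads=12, tokens=197)
# stand in, and the computation is run on CPU.
def _example_rollout():
    attentions = [torch.softmax(torch.randn(1, 12, 197, 197), dim=-1) for _ in range(12)]
    mask = rollout(attentions, discard_ratio=0.9, head_fusion="max",
                   device=torch.device("cpu"))
    # mask: (1, 197, 197) rollout matrix; row 0 gives the class token's attention to all tokens
    return mask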
def sync_bn_conversion(model: torch.nn.Module):
    """
    Convert BatchNorm layers to SyncBatchNorm (used for DDP).
    :param model: PyTorch model
    :return:
        model: PyTorch model with SyncBatchNorm layers
    """
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return model
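# Illustrative usage sketch for sync_bn_conversion (not part of the original utilities).
# The tiny Sequential model is an assumption used only to show the conversion; in a DDP
# setup the converted model would then be wrapped in DistributedDataParallel.
def _example_sync_bn_conversion():
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
    model = sync_bn_conversion(model)
    # Typically followed by something like:
    #   model = torch.nn.parallel.DistributedDataParallel(model.to(device), device_ids=[rank])
    return model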
def check_snapshot(args):
    """
    Create the directory to save training checkpoints, otherwise load the existing checkpoint.
    Additionally, if it is an array training job, create a new sub-directory for each training job.
    :param args: Arguments from the argument parser
    :return:
    """
    # Check if it is an array training job (i.e. training with multiple random seeds on the same settings)
    if args.array_training_job and not args.resume_training:
        args.snapshot_dir = os.path.join(args.snapshot_dir, str(args.seed))
        if not os.path.exists(args.snapshot_dir):
            save_dir = Path(args.snapshot_dir)
            save_dir.mkdir(parents=True, exist_ok=True)
    else:
        # Create the directory to save training checkpoints, otherwise load the existing checkpoint
        if not os.path.exists(args.snapshot_dir):
            # Only create the directory if the path is not a checkpoint file
            if ".pt" not in args.snapshot_dir and ".pth" not in args.snapshot_dir:
                save_dir = Path(args.snapshot_dir)
                save_dir.mkdir(parents=True, exist_ok=True)
            else:
                raise ValueError('Snapshot checkpoint does not exist.')
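# Illustrative usage sketch for check_snapshot (not part of the original utilities). The
# SimpleNamespace mirrors only the attributes the function accesses; the directory path and
# seed value are assumptions. Note that calling this will create directories on disk.
def _example_check_snapshot():
    from types import SimpleNamespace
    args = SimpleNamespace(snapshot_dir="snapshots/experiment", array_training_job=True,
                           resume_training=False, seed=42)
    check_snapshot(args)
    # With array_training_job=True, checkpoints would be saved under snapshots/experiment/42
    return args.snapshot_dir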