# pdiscoformer/utils/misc_utils.py
import math
from functools import reduce
import torch
import numpy as np
import os
from pathlib import Path
def factors(n):
    """Return all divisors of n, unsorted, in (i, n // i) pairs."""
    return reduce(list.__add__,
                  ([i, n // i] for i in range(1, int(n ** 0.5) + 1) if n % i == 0))
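
# Illustrative usage (a sketch, not part of the original module): the divisors come back
# unsorted, as (i, n // i) pairs in discovery order, and a perfect square repeats its root:
#   factors(12) -> [1, 12, 2, 6, 3, 4]
#   factors(16) -> [1, 16, 2, 8, 4, 4]
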
def file_line_count(filename: str) -> int:
"""Count the number of lines in a file"""
with open(filename, 'rb') as f:
return sum(1 for _ in f)
def compute_attention(qkv, scale=None):
"""
Compute attention matrix (same as in the pytorch scaled dot product attention)
Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
    :param qkv: Query, key and value tensors, either stacked along the first dimension or passed as a (query, key, value) tuple/list
    :param scale: Scale factor for the attention computation; defaults to 1 / sqrt(head_dim)
    :return: Tuple of (attention weights, attention output)
    """
if isinstance(qkv, torch.Tensor):
query, key, value = qkv.unbind(0)
else:
query, key, value = qkv
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
L, S = query.size(-2), key.size(-2)
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_out = attn_weight @ value
return attn_weight, attn_out
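
# Illustrative sketch (not part of the original module): exercise compute_attention on random
# tensors and compare against torch.nn.functional.scaled_dot_product_attention, which performs
# the same computation. The shapes below are demo assumptions, not requirements of the function.
def _demo_compute_attention():
    q = torch.randn(2, 4, 16, 64)  # (batch, heads, query tokens, head dim)
    k = torch.randn(2, 4, 16, 64)
    v = torch.randn(2, 4, 16, 64)
    attn_weight, attn_out = compute_attention((q, k, v))
    # Softmax rows of the attention matrix should sum to 1
    assert torch.allclose(attn_weight.sum(dim=-1), torch.ones(2, 4, 16), atol=1e-5)
    # The output should agree with the built-in implementation up to numerical tolerance
    reference = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    print("max |difference| vs. scaled_dot_product_attention:",
          (attn_out - reference).abs().max().item())
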
def compute_dot_product_similarity(a, b):
    """Pairwise dot-product similarity between the rows of a and b (a @ b^T)."""
    scores = a @ b.transpose(-1, -2)
    return scores
def compute_cross_entropy(p, q):
    """Mean cross-entropy between target probabilities p and logits q."""
    q = torch.nn.functional.log_softmax(q, dim=-1)
    loss = torch.sum(p * q, dim=-1)
    return -loss.mean()
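
# Illustrative sketch (not part of the original module): with one-hot targets,
# compute_cross_entropy reduces to the standard hard-label cross-entropy loss.
def _demo_compute_cross_entropy():
    logits = torch.randn(8, 5)
    labels = torch.randint(0, 5, (8,))
    one_hot = torch.nn.functional.one_hot(labels, num_classes=5).float()
    soft_loss = compute_cross_entropy(one_hot, logits)
    hard_loss = torch.nn.functional.cross_entropy(logits, labels)
    assert torch.allclose(soft_loss, hard_loss, atol=1e-5)
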
def rollout(attentions, discard_ratio=0.9, head_fusion="max", device=torch.device("cuda")):
"""
    Perform attention rollout.
    Ref: https://github.com/jacobgil/vit-explain/blob/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/vit_rollout.py#L9C1-L42C16
Parameters
----------
attentions : list
List of attention matrices, one for each transformer layer
discard_ratio : float
Ratio of lowest attention values to discard
head_fusion : str
Type of fusion to use for attention heads. One of "mean", "max", "min"
device : torch.device
Device to use for computation
    Returns
    -------
    result : torch.Tensor
        Rollout attention matrix of shape (batch_size, num_tokens, num_tokens); the
        class-token row can be reshaped into a (width, width) patch mask, where width
        is the square root of the number of patches
    """
result = torch.eye(attentions[0].size(-1), device=device)
attentions = [attention.to(device) for attention in attentions]
with torch.no_grad():
for attention in attentions:
            if head_fusion == "mean":
                attention_heads_fused = attention.mean(dim=1)
            elif head_fusion == "max":
                attention_heads_fused = attention.max(dim=1).values
            elif head_fusion == "min":
                attention_heads_fused = attention.min(dim=1).values
            else:
                raise ValueError("Attention head fusion type not supported")
# Drop the lowest attentions, but
# don't drop the class token
flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
_, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
indices = indices[indices != 0]
flat[0, indices] = 0
I = torch.eye(attention_heads_fused.size(-1), device=device)
a = (attention_heads_fused + 1.0 * I) / 2
            # Re-normalize so that each row of the attention matrix sums to 1
            a = a / a.sum(dim=-1, keepdim=True)
result = torch.matmul(a, result)
# Normalize the result by max value in each row
result = result / result.max(dim=-1, keepdim=True)[0]
return result
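
# Illustrative sketch (not part of the original module): run attention rollout on random,
# row-normalized attention maps for a toy ViT with 12 layers, 4 heads and 196 + 1 tokens
# (class token first). The CPU device is a demo assumption so no GPU is needed.
def _demo_rollout():
    num_tokens = 197
    attentions = [torch.softmax(torch.randn(1, 4, num_tokens, num_tokens), dim=-1)
                  for _ in range(12)]
    result = rollout(attentions, discard_ratio=0.9, head_fusion="max",
                     device=torch.device("cpu"))
    assert result.shape == (1, num_tokens, num_tokens)
    # The class-token row (minus the class token itself) is what is typically reshaped
    # into a 14 x 14 patch mask for visualization.
    mask = result[0, 0, 1:].reshape(14, 14)
    return mask
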
def sync_bn_conversion(model: torch.nn.Module):
"""
Convert BatchNorm to SyncBatchNorm (used for DDP)
:param model: PyTorch model
:return:
model: PyTorch model with SyncBatchNorm layers
"""
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
return model
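
# Illustrative usage (a sketch under an assumed DDP setup, not part of the original module):
# convert BatchNorm layers before wrapping the model in DistributedDataParallel, e.g.
#   model = sync_bn_conversion(model)
#   model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
# where `local_rank` is a placeholder for the process-local GPU index and
# torch.distributed is assumed to have been initialised already.
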
def check_snapshot(args):
"""
Create directory to save training checkpoints, otherwise load the existing checkpoint.
Additionally, if it is an array training job, create a new directory for each training job.
    :param args: Arguments from the argument parser
    :return: None; args.snapshot_dir is updated in place (the seed is appended for array training jobs)
    """
# Check if it is an array training job (i.e. training with multiple random seeds on the same settings)
if args.array_training_job and not args.resume_training:
args.snapshot_dir = os.path.join(args.snapshot_dir, str(args.seed))
if not os.path.exists(args.snapshot_dir):
save_dir = Path(args.snapshot_dir)
save_dir.mkdir(parents=True, exist_ok=True)
else:
# Create directory to save training checkpoints, otherwise load the existing checkpoint
if not os.path.exists(args.snapshot_dir):
            # Only create the directory if the path does not point to a checkpoint file
            if ".pt" not in args.snapshot_dir and ".pth" not in args.snapshot_dir:
save_dir = Path(args.snapshot_dir)
save_dir.mkdir(parents=True, exist_ok=True)
else:
raise ValueError('Snapshot checkpoint does not exist.')
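
# Illustrative sketch (not part of the original module): check_snapshot only needs an object
# with the attributes used above, so a SimpleNamespace stands in for the argparse result here.
# The temporary directory and the attribute values are placeholders for the demo.
def _demo_check_snapshot():
    import tempfile
    from types import SimpleNamespace
    args = SimpleNamespace(snapshot_dir=os.path.join(tempfile.mkdtemp(), "run_1"),
                           array_training_job=True,
                           resume_training=False,
                           seed=42)
    check_snapshot(args)
    # With array_training_job=True and resume_training=False the seed is appended,
    # so checkpoints land in "<tmp>/run_1/42".
    return args.snapshot_dir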