from collections import deque

import faiss
import torch
import torch.nn.functional as F
import numpy as np
from torch import nn


class KNN:
    """
    KNN memory for one element in the batch. Handles all heads.

    Each head owns an exact inner-product faiss index (``IndexFlat`` with
    ``METRIC_INNER_PRODUCT``) over its keys, plus a FIFO ``deque`` holding the
    matching values in insertion order. When a head holds more than
    ``shrink_size`` entries, its oldest entries are dropped until exactly
    ``memories_size`` remain.
    """

    def __init__(self, num_heads, head_dim, memories_size=16000,
                 shrink_size=None, cache=None):
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.memories_size = memories_size
        # Shrink only after overshooting by ~10%, presumably so the cost of
        # faiss ``remove_ids`` is paid once per many ``add`` calls.
        self.shrink_size = shrink_size or memories_size * 1.1
        self.indexes = [
            faiss.IndexFlat(self.head_dim, faiss.METRIC_INNER_PRODUCT)
            for _ in range(self.num_heads)
        ]
        self.values = [deque() for _ in range(self.num_heads)]
        self.cache = cache

    def __del__(self):
        # Release the faiss indexes and value buffers eagerly; the hasattr
        # guard covers an __init__ that failed before they were created.
        if hasattr(self, 'indexes'):
            del self.indexes
            del self.values

    def clear(self):
        """Remove every stored key and value from all heads."""
        for index in self.indexes:
            index.reset()
        for head_values in self.values:
            head_values.clear()

    def shrink(self):
        """Shrink every head whose size exceeds ``shrink_size`` to
        ``memories_size`` by dropping its oldest entries."""
        for i, index in enumerate(self.indexes):
            if index.ntotal <= self.shrink_size:
                continue
            to_delete = index.ntotal - self.memories_size
            # IndexFlat ids are positions and remove_ids renumbers the
            # survivors, so the oldest keys always occupy [0, to_delete).
            index.remove_ids(np.arange(0, to_delete))
            for _ in range(to_delete):
                self.values[i].popleft()

    def add(self, key, value):
        """
        Append keys and values for every head, then shrink if needed.

        Args:
            key: per-head keys; ``key[i]`` is added to head ``i``'s index
                (presumably shaped (num_heads, n, head_dim) - TODO confirm).
            value: per-head values; every row of ``value[i]`` is appended to
                head ``i``'s value deque.

        Raises:
            RuntimeError: if a cache was configured (not implemented). This
                is raised before any key or value is stored.
        """
        if self.cache is not None:
            # Fail before mutating so the memory is not left half-updated.
            raise RuntimeError("Cache for KNN not implemented")
        for i, head_keys in enumerate(key):
            self.indexes[i].add(head_keys)
        for i, head_values in enumerate(value):
            self.values[i].extend(head_values)
        self.shrink()

    def search(self, query, k=32):
        """
        Search for ``query`` in every head's key index.

        Args:
            query: per-head queries; ``query[i]`` is searched in head ``i``.
                Indexed as (num_heads, n_queries, head_dim).
            k: maximum number of neighbours; clipped to the number of stored
                entries.

        Returns:
            ``(keys, values)`` numpy arrays stacked over heads, each of shape
            (num_heads, n_queries, k', head_dim) with k' = min(k, stored).
            When the memory is empty both are float32 arrays with k' = 0.
        """
        k = min(k, len(self.values[0]))
        if k <= 0:
            empty_shape = (query.shape[0], query.shape[1], 0, query.shape[2])
            # numpy (not torch) so both code paths return the same type.
            return (np.empty(empty_shape, dtype=np.float32),
                    np.empty(empty_shape, dtype=np.float32))
        head_keys, head_values = [], []
        for i, head_query in enumerate(query):
            _, ids, found_keys = self.indexes[i].search_and_reconstruct(
                head_query, k=k)
            # NOTE: np.take converts the whole deque to an array on every
            # call, which is O(stored entries) per head.
            found_values = np.take(self.values[i], indices=ids, axis=0)
            head_keys.append(found_keys)
            head_values.append(found_values)
        return np.stack(head_keys, axis=0), np.stack(head_values, axis=0)


def _pad_neighbours(array, length):
    """Zero-pad ``array`` along axis 2 (the neighbour axis) up to ``length``."""
    pad_shape = list(array.shape)
    pad_shape[2] = length - pad_shape[2]
    return np.concatenate(
        (array, np.zeros(pad_shape, dtype=np.float32)), axis=2)


class KNNLayer:
    """
    KNN attention memory layer.

    Handles KNN memories for a batch: either one ``KNN`` shared by every
    batch element (``share_memory=True``) or one ``KNN`` per batch element.
    """

    def __init__(self, config, share_memory=True, batch_size=None,
                 memory_size=16000, shrink_size=None, n_jobs=4, cache=None):
        if not share_memory and batch_size is None:
            raise RuntimeError(
                "If share_memory is False, batch_size should be passed")
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.share_memory = share_memory
        self.batch_size = batch_size
        self.memory_size = memory_size
        self.shrink_size = shrink_size or self.memory_size * 1.1
        # While closed, add / clear / clear_batches are no-ops.
        self.closed = False
        if not share_memory:
            self.knns = [
                KNN(self.num_heads, self.head_dim, memory_size,
                    self.shrink_size, cache=cache)
                for _ in range(self.batch_size)
            ]
        else:
            self.knn = KNN(self.num_heads, self.head_dim, memory_size,
                           self.shrink_size, cache=cache)
        # Module-level faiss setting: OpenMP threads used by faiss.
        faiss.omp_set_num_threads(n_jobs)

    def _knn_for(self, batch_idx):
        """Return the memory that serves batch element ``batch_idx``."""
        return self.knn if self.share_memory else self.knns[batch_idx]

    def clear_batches(self, batch_indexes):
        """Clear the memories of the given batch elements.

        Does nothing when the memory is shared or the layer is closed.
        """
        if self.closed:
            return
        if not self.share_memory:
            for idx in batch_indexes:
                self.knns[idx].clear()

    def clear(self):
        """Clear every memory (unless the layer is closed)."""
        if self.closed:
            return
        if self.share_memory:
            self.knn.clear()
        else:
            for knn in self.knns:
                knn.clear()

    def add(self, keys, values):
        """
        Store keys/values for each batch element (unless closed).

        ``keys`` and ``values`` are torch tensors iterated over the batch
        dimension first; each element is passed to ``KNN.add``
        (presumably (batch, heads, seq, head_dim) - TODO confirm).
        """
        if self.closed:
            return
        keys, values = keys.numpy(force=True), values.numpy(force=True)
        for batch_idx, (key, value) in enumerate(zip(keys, values)):
            self._knn_for(batch_idx).add(key, value)

    def search(self, queries, k=32):
        """
        Retrieve up to ``k`` neighbours for every batch element.

        Results are zero-padded along the neighbour axis to the largest
        count found in the batch.

        Returns:
            ``(keys, values, masks)`` torch tensors. ``keys``/``values`` stack
            the per-element ``KNN.search`` results. ``masks`` has shape
            (batch, max_len) and holds 1.0 for real neighbours and 0.0 for
            padding.
        """
        queries = queries.numpy(force=True)
        keys, values = [], []
        for batch_idx, query in enumerate(queries):
            key, value = self._knn_for(batch_idx).search(query, k)
            keys.append(key)
            values.append(value)
        max_len = max((key.shape[2] for key in keys), default=0)

        masks = np.ones((len(keys), max_len), dtype=np.float32)
        for i, (key, value) in enumerate(zip(keys, values)):
            found = key.shape[2]
            if found == max_len:
                continue
            keys[i] = _pad_neighbours(key, max_len)
            values[i] = _pad_neighbours(value, max_len)
            masks[i, found:] = 0
        return (torch.from_numpy(np.stack(keys, axis=0)),
                torch.from_numpy(np.stack(values, axis=0)),
                torch.from_numpy(masks))

    def close(self):
        """Make add / clear / clear_batches no-ops until ``open`` is called."""
        self.closed = True

    def open(self):
        """Re-enable add / clear / clear_batches."""
        self.closed = False

    def reset(self):
        """Re-open the layer and clear all memories."""
        self.open()
        self.clear()


class ClearMemoryLayer(nn.Module):
    """
    Wrap ``next_layer`` and reset KNN memories at sequence starts.

    Before each forward pass, every batch element whose first token equals
    ``bos_token`` has its memory cleared via ``knn_memory.clear_batches``.
    """

    def __init__(self, knn_memory, bos_token, eos_token, next_layer):
        super().__init__()
        self.knn_memory = knn_memory
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.next_layer = next_layer

    def _clear_where(self, batch_mask):
        """Clear the memories of batch elements where ``batch_mask`` is True."""
        # nonzero() on a 1-D mask yields an (n, 1) tensor; column 0 holds
        # all n batch indices.
        batches_to_clear = batch_mask.nonzero()
        if len(batches_to_clear) > 0:
            self.knn_memory.clear_batches(batches_to_clear[:, 0])

    def _clear_if_token(self, tokens, token):
        """Clear memories of batch elements containing ``token`` anywhere."""
        self._clear_where((tokens == token).any(dim=-1))

    def forward(self, tokens, *args, **kwargs):
        # NOTE: clearing on BOS/EOS anywhere in the sequence
        # (``_clear_if_token``) is currently not used here.
        self._clear_where(tokens[:, 0] == self.bos_token)
        return self.next_layer(tokens, *args, **kwargs)