lavawolfiee
commited on
Commit
·
6bc49a9
1
Parent(s):
076e659
Finally
Browse files- .gitignore +2 -0
- app.py +70 -0
- batched_dataloader.py +48 -0
- gpt2_knn_attention.pt +3 -0
- gpt2_knn_attention.py +151 -0
- knn_memory.py +220 -0
- requirements.txt +5 -0
- vector_cache.py +57 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
flagged/
|
app.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
|
4 |
+
|
5 |
+
from gpt2_knn_attention import GPT2KNNAttention
|
6 |
+
from knn_memory import KNNLayer, ClearMemoryLayer
|
7 |
+
|
8 |
+
|
9 |
+
def inject_knn_in_gpt2(model, knn_memory, bos_token_id, eos_token_id, device, layer_ind=8):
|
10 |
+
layer = model.transformer.h[layer_ind].attn
|
11 |
+
state = layer.state_dict()
|
12 |
+
knn_layer = GPT2KNNAttention(
|
13 |
+
config, knn_memory, device, is_cross_attention=False, layer_idx=layer.layer_idx)
|
14 |
+
knn_state = knn_layer.state_dict()
|
15 |
+
|
16 |
+
for k, v in state.items():
|
17 |
+
knn_state[k] = v
|
18 |
+
|
19 |
+
knn_layer.load_state_dict(knn_state)
|
20 |
+
|
21 |
+
model.transformer.h[8].attn = knn_layer
|
22 |
+
model.transformer = ClearMemoryLayer(
|
23 |
+
knn_memory, bos_token_id, eos_token_id, model.transformer)
|
24 |
+
model.eval()
|
25 |
+
|
26 |
+
|
27 |
+
model_name = "gpt2"
|
28 |
+
tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
|
29 |
+
model = GPT2LMHeadModel.from_pretrained(model_name)
|
30 |
+
config = model.config
|
31 |
+
model.eval()
|
32 |
+
|
33 |
+
knn_memory = KNNLayer(config, share_memory=False, batch_size=1)
|
34 |
+
bos_token_id, eos_token_id = tokenizer.bos_token_id, tokenizer.eos_token_id
|
35 |
+
bos_token, eos_token = tokenizer.bos_token, tokenizer.eos_token
|
36 |
+
|
37 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
38 |
+
knn_model = inject_knn_in_gpt2(
|
39 |
+
model, knn_memory, bos_token_id, eos_token_id, device, layer_ind=8)
|
40 |
+
knn_model.load_state_dict(torch.load('gpt2_knn_attention.pt'))
|
41 |
+
|
42 |
+
|
43 |
+
def generate(text, temperature, max_new_tokens, top_p):
|
44 |
+
encoded_input = tokenizer(text, return_tensors='pt')
|
45 |
+
output = model.generate(**encoded_input, do_sample=True,
|
46 |
+
max_new_tokens=int(max_new_tokens), temperature=temperature, top_p=top_p)
|
47 |
+
return tokenizer.decode(output[0])
|
48 |
+
|
49 |
+
|
50 |
+
desc = "Попытка повторить статью от Google (Memorizing Transformers)[https://arxiv.org/abs/2203.08913]. "\
|
51 |
+
"В ней вводиться новый слой **KNNAttention**, который использует approximate kNN в базе с (key, value), чтобы делать attention по большому контексту. Это позволяет расширить контекст трансформера до размера книг и статей, несильно замедляя его.\n\n"\
|
52 |
+
"Я написал свои **KNNAttention**, переписал слой **GPT2Attention**, чтобы он использовал KNNAttention, а также написал несколько вспомогательный классов для всего этого.\n\n"\
|
53 |
+
"Я сел писать это за **3 недели** до дедлайна, но все равно не довел до результата, которого изначально хотел. Но я доволен проделанной работой :)"
|
54 |
+
|
55 |
+
|
56 |
+
demo = gr.Interface(
|
57 |
+
fn=generate,
|
58 |
+
inputs=[gr.inputs.Textbox(lines=5, label="Input Text"),
|
59 |
+
gr.Slider(0.001, 2.0, step=0.05, value=0.8, label='temperature'),
|
60 |
+
gr.Slider(1, 512, step=1, value=32, label='max_new_tokens'),
|
61 |
+
gr.Slider(0.1, 1.0, step=0.02, value=0.92, label='top_p')],
|
62 |
+
outputs=gr.outputs.Textbox(label="Generated Text"),
|
63 |
+
description=desc,
|
64 |
+
title="Memorizing Transformers",
|
65 |
+
examples=[
|
66 |
+
["My name is Lewis and I like to", 0.8, 32, 0.92]
|
67 |
+
]
|
68 |
+
)
|
69 |
+
|
70 |
+
demo.launch()
|
batched_dataloader.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import DataLoader
|
2 |
+
from transformers import DataCollatorWithPadding
|
3 |
+
|
4 |
+
# some utils for training
|
5 |
+
|
6 |
+
|
7 |
+
class BooksBatcherIter:
|
8 |
+
def __init__(self, data_iter, batch_size, tokenizer, chunk_size=1024):
|
9 |
+
self.data_iter = data_iter
|
10 |
+
self.batch_size = batch_size
|
11 |
+
self.chunk_size = chunk_size
|
12 |
+
self.batch_fns = [self._batch_fn()]
|
13 |
+
self.collate_fn = DataCollatorWithPadding(tokenizer)
|
14 |
+
|
15 |
+
def _batch_fn(self):
|
16 |
+
for book in self.data_iter:
|
17 |
+
for i in range(0, len(book), self.chunk_size):
|
18 |
+
yield book[i:i+self.chunk_size]
|
19 |
+
|
20 |
+
def __iter__(self) -> 'BooksBatcherIter':
|
21 |
+
return self
|
22 |
+
|
23 |
+
def __next__(self) -> Any:
|
24 |
+
batch = []
|
25 |
+
|
26 |
+
try:
|
27 |
+
for b in self.batch_fns:
|
28 |
+
batch.append(next(b))
|
29 |
+
except StopIteration:
|
30 |
+
raise StopIteration
|
31 |
+
|
32 |
+
return self.collate_fn(batch)
|
33 |
+
|
34 |
+
|
35 |
+
class BooksBatcher:
|
36 |
+
def __init__(self, dataset, batch_size, tokenizer) -> None:
|
37 |
+
self.batch_size = batch_size
|
38 |
+
self.tokenizer = tokenizer
|
39 |
+
self.dataloader = DataLoader(
|
40 |
+
dataset=dataset,
|
41 |
+
batch_size=None, # return raw samples
|
42 |
+
shuffle=True,
|
43 |
+
num_workers=2,
|
44 |
+
prefetch_factor=4
|
45 |
+
)
|
46 |
+
|
47 |
+
def __iter__(self) -> 'BooksBatcherIter':
|
48 |
+
return BooksBatcherIter(iter(self.dataloader), self.batch_size, self.tokenizer)
|
gpt2_knn_attention.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:610319abc523c64f595d4cb8fec2ff7faeab331b05722206c6c42656eae0bdff
|
3 |
+
size 510408492
|
gpt2_knn_attention.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import nn
|
6 |
+
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
|
7 |
+
|
8 |
+
|
9 |
+
class GPT2KNNAttention(GPT2Attention):
|
10 |
+
def __init__(self, config, knn_memory, device, is_cross_attention=False, layer_idx=None, num_retrieve_memories=32):
|
11 |
+
super().__init__(config, is_cross_attention, layer_idx)
|
12 |
+
|
13 |
+
self.knn_memory = knn_memory
|
14 |
+
self.device = device
|
15 |
+
self.num_retrieve_memories = num_retrieve_memories
|
16 |
+
self.knn_attn_dropout = nn.Dropout(config.attn_pdrop)
|
17 |
+
self.attn_comb_bias = nn.Parameter(torch.empty(self.num_heads,))
|
18 |
+
nn.init.normal_(self.attn_comb_bias, mean=0.0, std=1.0)
|
19 |
+
# self.attn_comb_bias = nn.Parameter(torch.full((self.num_heads,), 1.0))
|
20 |
+
|
21 |
+
def _knn_attn(self, query, key, value, mask, head_mask=None):
|
22 |
+
query = query.unsqueeze(-2)
|
23 |
+
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
24 |
+
|
25 |
+
if self.scale_attn_weights:
|
26 |
+
attn_weights = attn_weights / torch.full(
|
27 |
+
[], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
|
28 |
+
)
|
29 |
+
|
30 |
+
# Layer-wise attention scaling
|
31 |
+
if self.scale_attn_by_inverse_layer_idx:
|
32 |
+
attn_weights = attn_weights / float(self.layer_idx + 1)
|
33 |
+
|
34 |
+
# if not self.is_cross_attention:
|
35 |
+
# raise RuntimeError("KNN attention is not yet implemented for !cross_attention")
|
36 |
+
# # if only "normal" attention layer implements causal mask
|
37 |
+
# query_length, key_length = query.size(-3), key.size(-3)
|
38 |
+
# causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
39 |
+
# mask_value = torch.finfo(attn_weights.dtype).min
|
40 |
+
# # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
41 |
+
# # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
42 |
+
# mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
43 |
+
# attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
|
44 |
+
|
45 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
46 |
+
|
47 |
+
# Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
|
48 |
+
attn_weights = attn_weights.type(value.dtype)
|
49 |
+
attn_weights = self.knn_attn_dropout(attn_weights)
|
50 |
+
|
51 |
+
# masking missing keys
|
52 |
+
sh = mask.size()
|
53 |
+
attn_weights = attn_weights * mask.view((sh[0], 1, 1, 1, sh[1]))
|
54 |
+
|
55 |
+
# Mask heads if we want to
|
56 |
+
if head_mask is not None:
|
57 |
+
attn_weights = attn_weights * head_mask
|
58 |
+
|
59 |
+
attn_output = torch.matmul(attn_weights, value)
|
60 |
+
attn_output.squeeze_(dim=-2)
|
61 |
+
|
62 |
+
return attn_output
|
63 |
+
|
64 |
+
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
65 |
+
attn_output, attn_weights = super()._attn(
|
66 |
+
query, key, value, attention_mask, head_mask)
|
67 |
+
knn_key, knn_value, knn_mask = self.knn_memory.search(
|
68 |
+
query, self.num_retrieve_memories)
|
69 |
+
g = torch.sigmoid(self.attn_comb_bias)[:, None, None]
|
70 |
+
|
71 |
+
if knn_key.numel() == 0:
|
72 |
+
return attn_output * (1 - g), attn_weights
|
73 |
+
|
74 |
+
knn_key, knn_value, knn_mask = knn_key.to(
|
75 |
+
self.device), knn_value.to(self.device), knn_mask.to(self.device)
|
76 |
+
knn_attn_output = self._knn_attn(
|
77 |
+
query, knn_key, knn_value, knn_mask, head_mask)
|
78 |
+
|
79 |
+
# combining two attentions
|
80 |
+
attn = knn_attn_output * g + attn_output * (1 - g)
|
81 |
+
|
82 |
+
return attn, attn_weights
|
83 |
+
|
84 |
+
def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
|
85 |
+
raise RuntimeError(
|
86 |
+
"KNN attention is not yet implemented for _upcast_and_reordered_attn")
|
87 |
+
|
88 |
+
def forward(
|
89 |
+
self,
|
90 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
91 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
92 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
93 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
94 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
95 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
96 |
+
use_cache: Optional[bool] = False,
|
97 |
+
output_attentions: Optional[bool] = False,
|
98 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
99 |
+
if encoder_hidden_states is not None:
|
100 |
+
if not hasattr(self, "q_attn"):
|
101 |
+
raise ValueError(
|
102 |
+
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
103 |
+
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
|
104 |
+
)
|
105 |
+
|
106 |
+
query = self.q_attn(hidden_states)
|
107 |
+
key, value = self.c_attn(encoder_hidden_states).split(
|
108 |
+
self.split_size, dim=2)
|
109 |
+
attention_mask = encoder_attention_mask
|
110 |
+
else:
|
111 |
+
query, key, value = self.c_attn(
|
112 |
+
hidden_states).split(self.split_size, dim=2)
|
113 |
+
|
114 |
+
query = self._split_heads(query, self.num_heads, self.head_dim)
|
115 |
+
key = self._split_heads(key, self.num_heads, self.head_dim)
|
116 |
+
value = self._split_heads(value, self.num_heads, self.head_dim)
|
117 |
+
|
118 |
+
# normalization of queries and keys reduces the effect of staleness
|
119 |
+
query, key = F.normalize(query, dim=-1), F.normalize(key, dim=-1)
|
120 |
+
new_memories = (key, value)
|
121 |
+
|
122 |
+
if layer_past is not None:
|
123 |
+
past_key, past_value = layer_past
|
124 |
+
key = torch.cat((past_key, key), dim=-2)
|
125 |
+
value = torch.cat((past_value, value), dim=-2)
|
126 |
+
|
127 |
+
if use_cache is True:
|
128 |
+
present = (key, value)
|
129 |
+
else:
|
130 |
+
present = None
|
131 |
+
|
132 |
+
if self.reorder_and_upcast_attn:
|
133 |
+
raise RuntimeError("Not implemented")
|
134 |
+
attn_output, attn_weights = self._upcast_and_reordered_attn(
|
135 |
+
query, key, value, attention_mask, head_mask)
|
136 |
+
else:
|
137 |
+
attn_output, attn_weights = self._attn(
|
138 |
+
query, key, value, attention_mask, head_mask)
|
139 |
+
|
140 |
+
attn_output = self._merge_heads(
|
141 |
+
attn_output, self.num_heads, self.head_dim)
|
142 |
+
attn_output = self.c_proj(attn_output)
|
143 |
+
attn_output = self.resid_dropout(attn_output)
|
144 |
+
|
145 |
+
outputs = (attn_output, present)
|
146 |
+
if output_attentions:
|
147 |
+
outputs += (attn_weights,)
|
148 |
+
|
149 |
+
self.knn_memory.add(*new_memories)
|
150 |
+
|
151 |
+
return outputs # a, present, (attentions)
|
knn_memory.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import deque
|
2 |
+
|
3 |
+
import faiss
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import numpy as np
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
|
10 |
+
class KNN:
|
11 |
+
"""
|
12 |
+
KNN for one element in batch. Handles all heads
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(self, num_heads, head_dim, memories_size=16000, shrink_size=None, cache=None):
|
16 |
+
self.num_heads = num_heads
|
17 |
+
self.head_dim = head_dim
|
18 |
+
self.memories_size = memories_size
|
19 |
+
self.shrink_size = shrink_size or memories_size * 1.1
|
20 |
+
self.indexes = [faiss.IndexFlat(
|
21 |
+
self.head_dim, faiss.METRIC_INNER_PRODUCT) for _ in range(self.num_heads)]
|
22 |
+
self.values = [deque([]) for _ in range(self.num_heads)]
|
23 |
+
self.cache = cache
|
24 |
+
|
25 |
+
def __del__(self):
|
26 |
+
if hasattr(self, 'indexes'):
|
27 |
+
del self.indexes
|
28 |
+
del self.values
|
29 |
+
|
30 |
+
def clear(self):
|
31 |
+
for index in self.indexes:
|
32 |
+
index.reset()
|
33 |
+
|
34 |
+
for value in self.values:
|
35 |
+
value.clear()
|
36 |
+
|
37 |
+
def shrink(self):
|
38 |
+
"""Shrinks index to memories_size"""
|
39 |
+
|
40 |
+
for i, index in enumerate(self.indexes):
|
41 |
+
if index.ntotal > self.shrink_size:
|
42 |
+
to_delete = index.ntotal - self.memories_size
|
43 |
+
index.remove_ids(np.arange(0, to_delete))
|
44 |
+
|
45 |
+
for _ in range(to_delete):
|
46 |
+
self.values[i].popleft()
|
47 |
+
|
48 |
+
def add(self, key, value):
|
49 |
+
for i, k in enumerate(key):
|
50 |
+
self.indexes[i].add(k)
|
51 |
+
for i, v in enumerate(value):
|
52 |
+
self.values[i].extend(v)
|
53 |
+
|
54 |
+
if self.cache is not None:
|
55 |
+
raise RuntimeError("Cache for KNN not implemented")
|
56 |
+
# self.cache.add(key)
|
57 |
+
|
58 |
+
self.shrink()
|
59 |
+
|
60 |
+
def search(self, query, k=32):
|
61 |
+
"""
|
62 |
+
Searchs for query in keys' index.
|
63 |
+
Returns k most relevant keys and corresponding values
|
64 |
+
"""
|
65 |
+
|
66 |
+
k = min(k, len(self.values[0]))
|
67 |
+
|
68 |
+
if k <= 0:
|
69 |
+
return torch.empty((query.shape[0], query.shape[1], 0, query.shape[2])),\
|
70 |
+
torch.empty(
|
71 |
+
(query.shape[0], query.shape[1], 0, query.shape[2]))
|
72 |
+
|
73 |
+
Ks, Vs = [], []
|
74 |
+
|
75 |
+
for i, q in enumerate(query):
|
76 |
+
D, I, K = self.indexes[i].search_and_reconstruct(q, k=k)
|
77 |
+
V = np.take(self.values[i], indices=I, axis=0)
|
78 |
+
Ks.append(K)
|
79 |
+
Vs.append(V)
|
80 |
+
|
81 |
+
return np.stack(Ks, axis=0), np.stack(Vs, axis=0)
|
82 |
+
|
83 |
+
|
84 |
+
class KNNLayer:
|
85 |
+
"""
|
86 |
+
KNN Attention layer. Handles KNN's for batch (every elemnt separately)
|
87 |
+
"""
|
88 |
+
|
89 |
+
def __init__(self, config, share_memory=True, batch_size=None, memory_size=16000, shrink_size=None, n_jobs=4, cache=None):
|
90 |
+
if not share_memory and batch_size is None:
|
91 |
+
raise RuntimeError(
|
92 |
+
"If share_memory is False, batch_size should be passed")
|
93 |
+
|
94 |
+
self.embed_dim = config.hidden_size
|
95 |
+
self.num_heads = config.num_attention_heads
|
96 |
+
self.head_dim = self.embed_dim // self.num_heads
|
97 |
+
|
98 |
+
self.share_memory = share_memory
|
99 |
+
self.batch_size = batch_size
|
100 |
+
self.memory_size = memory_size
|
101 |
+
self.shrink_size = shrink_size or self.memory_size * 1.1
|
102 |
+
self.closed = False
|
103 |
+
|
104 |
+
if not share_memory:
|
105 |
+
self.knns = [KNN(self.num_heads, self.head_dim, memory_size,
|
106 |
+
self.shrink_size, cache=cache) for _ in range(self.batch_size)]
|
107 |
+
else:
|
108 |
+
self.knn = KNN(self.num_heads, self.head_dim,
|
109 |
+
memory_size, self.shrink_size, cache=cache)
|
110 |
+
|
111 |
+
faiss.omp_set_num_threads(n_jobs)
|
112 |
+
|
113 |
+
def clear_batches(self, batch_indexes):
|
114 |
+
if self.closed:
|
115 |
+
return
|
116 |
+
|
117 |
+
if not self.share_memory:
|
118 |
+
for idx in batch_indexes:
|
119 |
+
self.knns[idx].clear()
|
120 |
+
|
121 |
+
def clear(self):
|
122 |
+
if self.closed:
|
123 |
+
return
|
124 |
+
|
125 |
+
if self.share_memory:
|
126 |
+
self.knn.clear()
|
127 |
+
else:
|
128 |
+
for idx in range(len(self.knns)):
|
129 |
+
self.knns[idx].clear()
|
130 |
+
|
131 |
+
def add(self, keys, values):
|
132 |
+
if self.closed:
|
133 |
+
return
|
134 |
+
|
135 |
+
keys, values = keys.numpy(force=True), values.numpy(force=True)
|
136 |
+
if not self.share_memory:
|
137 |
+
for i, (key, value) in enumerate(zip(keys, values)):
|
138 |
+
self.knns[i].add(key, value)
|
139 |
+
else:
|
140 |
+
for key, value in zip(keys, values):
|
141 |
+
self.knn.add(key, value)
|
142 |
+
|
143 |
+
def search(self, queries, k=32):
|
144 |
+
queries = queries.numpy(force=True)
|
145 |
+
keys, values = [], []
|
146 |
+
max_len = 0
|
147 |
+
|
148 |
+
if self.share_memory:
|
149 |
+
for query in queries:
|
150 |
+
key, value = self.knn.search(query, k)
|
151 |
+
keys.append(key)
|
152 |
+
values.append(value)
|
153 |
+
max_len = max(max_len, key.shape[2])
|
154 |
+
else:
|
155 |
+
for i, query in enumerate(queries):
|
156 |
+
key, value = self.knns[i].search(query, k)
|
157 |
+
keys.append(key)
|
158 |
+
values.append(value)
|
159 |
+
max_len = max(max_len, key.shape[2])
|
160 |
+
|
161 |
+
masks = np.ones((len(keys), max_len), dtype=np.float32)
|
162 |
+
|
163 |
+
for i, (key, value) in enumerate(zip(keys, values)):
|
164 |
+
l = key.shape[2]
|
165 |
+
|
166 |
+
if l == max_len:
|
167 |
+
continue
|
168 |
+
elif l > max_len:
|
169 |
+
raise RuntimeError("What? max_len is not max")
|
170 |
+
|
171 |
+
sh = list(key.shape)
|
172 |
+
sh[2] = max_len - sh[2]
|
173 |
+
keys[i] = np.concatenate(
|
174 |
+
(key, np.zeros(sh, dtype=np.float32)), axis=2)
|
175 |
+
values[i] = np.concatenate(
|
176 |
+
(value, np.zeros(sh, dtype=np.float32)), axis=2)
|
177 |
+
masks[i, l:] = 0
|
178 |
+
|
179 |
+
return torch.from_numpy(np.stack(keys, axis=0)),\
|
180 |
+
torch.from_numpy(np.stack(values, axis=0)),\
|
181 |
+
torch.from_numpy(masks)
|
182 |
+
|
183 |
+
def close(self):
|
184 |
+
self.closed = True
|
185 |
+
|
186 |
+
def open(self):
|
187 |
+
self.closed = False
|
188 |
+
|
189 |
+
def reset(self):
|
190 |
+
self.open()
|
191 |
+
self.clear()
|
192 |
+
|
193 |
+
|
194 |
+
class ClearMemoryLayer(nn.Module):
|
195 |
+
def __init__(self, knn_memory, bos_token, eos_token, next_layer):
|
196 |
+
super().__init__()
|
197 |
+
|
198 |
+
self.knn_memory = knn_memory
|
199 |
+
self.bos_token = bos_token
|
200 |
+
self.eos_token = eos_token
|
201 |
+
self.next_layer = next_layer
|
202 |
+
|
203 |
+
def _clear_if_token(self, tokens, token):
|
204 |
+
batches_to_clear = (tokens == token).any(dim=-1).nonzero()
|
205 |
+
|
206 |
+
if len(batches_to_clear) > 0:
|
207 |
+
self.knn_memory.clear_batches(batches_to_clear[0])
|
208 |
+
|
209 |
+
def forward(self, tokens, *args, **kwargs):
|
210 |
+
# self._clear_if_token(tokens, self.bos_token)
|
211 |
+
|
212 |
+
batches_to_clear = (tokens[:, 0] == self.bos_token).nonzero()
|
213 |
+
|
214 |
+
if len(batches_to_clear) > 0:
|
215 |
+
self.knn_memory.clear_batches(batches_to_clear[:, 0])
|
216 |
+
|
217 |
+
res = self.next_layer(tokens, *args, **kwargs)
|
218 |
+
# self._clear_if_token(tokens, self.eos_token)
|
219 |
+
|
220 |
+
return res
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers==4.28.1
|
2 |
+
torch==2.0.0
|
3 |
+
faiss-cpu==1.7.4
|
4 |
+
gradio==3.27.0
|
5 |
+
numpy==1.21.2
|
vector_cache.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
class VectorCache:
|
8 |
+
"""
|
9 |
+
Caches vectors on disk so one can later build an index on them (indexes like IVF requires big amount of vetores for building)
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, filename='vector_cache.memmap', d=768, size=7000000):
|
13 |
+
self.filename = filename
|
14 |
+
self.offset_file = filename + '.offset'
|
15 |
+
self.d = d
|
16 |
+
self.size = size
|
17 |
+
|
18 |
+
if os.path.isfile(filename):
|
19 |
+
mode = 'r+'
|
20 |
+
self.f = open(self.offset_file, mode)
|
21 |
+
data = json.load(self.f)
|
22 |
+
self.offset = data[0]
|
23 |
+
self.length = data[1]
|
24 |
+
else:
|
25 |
+
mode = 'w+'
|
26 |
+
self.f = open(self.offset_file, mode)
|
27 |
+
self.offset = 0
|
28 |
+
self.length = 0
|
29 |
+
|
30 |
+
self.db = np.memmap(filename, dtype=np.float32, mode='w+',
|
31 |
+
shape=(size, d), order='C')
|
32 |
+
|
33 |
+
def sync_offset(self):
|
34 |
+
self.f.seek(0)
|
35 |
+
self.f.truncate(0)
|
36 |
+
self.f.write(json.dumps([self.offset, self.length]))
|
37 |
+
|
38 |
+
def close(self):
|
39 |
+
self.db.flush()
|
40 |
+
self.db.close()
|
41 |
+
|
42 |
+
self.sync_offset()
|
43 |
+
self.f.flush()
|
44 |
+
self.f.close()
|
45 |
+
|
46 |
+
def add(self, vs):
|
47 |
+
l = len(vs)
|
48 |
+
to_end = self.size - self.offset
|
49 |
+
|
50 |
+
if to_end < l:
|
51 |
+
self.add(vs[:to_end])
|
52 |
+
self.add(vs[to_end:])
|
53 |
+
return
|
54 |
+
|
55 |
+
self.db[self.offset:self.offset+l+1, :] = vs
|
56 |
+
self.offset = (self.offset + l + 1) % self.size
|
57 |
+
self.length = min(self.length + l, self.size)
|