memorizing_transformer_gpt2 / vector_cache.py
lavawolfiee's picture
Finally
6bc49a9
raw
history blame
1.54 kB
import json
import os
import numpy as np
class VectorCache:
    """
    Cache vectors on disk so that an index can be built on them later
    (indexes such as IVF require a large number of vectors for training).

    Vectors are stored in a float32 ``np.memmap`` of shape ``(size, d)`` that
    is used as a ring buffer: once ``size`` rows have been written, new
    vectors overwrite the oldest ones. The ring-buffer state is persisted as
    the JSON list ``[offset, length]`` in ``filename + '.offset'`` so that an
    existing cache can be reopened.

    Args:
        filename: path of the memmap data file.
        d: dimensionality of each vector.
        size: capacity of the cache, in vectors.
    """

    def __init__(self, filename='vector_cache.memmap', d=768, size=7000000):
        self.filename = filename
        self.offset_file = filename + '.offset'
        self.d = d
        self.size = size

        if os.path.isfile(filename):
            # Reopen an existing cache and restore its ring-buffer state.
            mode = 'r+'
            self.f = open(self.offset_file, mode)
            data = json.load(self.f)
            self.offset = data[0]  # next row index to write
            self.length = data[1]  # number of valid rows, capped at size
        else:
            mode = 'w+'
            self.f = open(self.offset_file, mode)
            self.offset = 0
            self.length = 0

        # 'r+' maps the existing data file without touching its contents;
        # 'w+' would recreate (and zero) it, discarding every cached vector.
        self.db = np.memmap(filename, dtype=np.float32, mode=mode,
                            shape=(size, d), order='C')

        if mode == 'w+':
            # Write the initial state immediately so the offset file is never
            # left empty next to an existing data file (e.g. after a crash).
            self.sync_offset()

    def sync_offset(self):
        """Overwrite the offset file with the current ``[offset, length]``."""
        self.f.seek(0)
        self.f.truncate(0)
        self.f.write(json.dumps([self.offset, self.length]))
        self.f.flush()

    def close(self):
        """Flush vectors and ring-buffer state to disk and release both files."""
        self.db.flush()
        # np.memmap has no close() method; dropping the reference releases
        # the mapping once no other views of it remain.
        del self.db
        self.sync_offset()
        self.f.flush()
        self.f.close()

    def add(self, vs):
        """
        Append vectors to the cache, wrapping around at the end of the buffer.

        Once the cache is full, the oldest vectors are overwritten. If ``vs``
        holds more than ``size`` vectors, only the last ``size`` remain.

        Args:
            vs: vectors to store; must support ``len()`` and slicing and be
                assignable to memmap rows — presumably an array of shape
                (n, d); TODO confirm with callers.
        """
        total = len(vs)
        written = 0
        while written < total:
            # Write as many rows as fit before the end of the buffer.
            chunk = min(total - written, self.size - self.offset)
            self.db[self.offset:self.offset + chunk, :] = vs[written:written + chunk]
            self.offset = (self.offset + chunk) % self.size
            written += chunk
        self.length = min(self.length + total, self.size)