# tikendraw's picture
# application
# a3e82d3
# raw
# history blame
# 2.18 kB
# train.py
import sys
import tensorflow as tf
# def create_datasets(x_train, y_train, text_vectorizer, batch_size):
# print('Building slices...')
# train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size)
# print('Mapping...')
# train_dataset = train_dataset.map(lambda x, y: (text_vectorizer(x), y), tf.data.AUTOTUNE)
# print('Prefetching...')
# train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
# return train_dataset
def sizeof_fmt(num, suffix='B'):
    """Format a quantity using binary (IEC) prefixes, e.g. ``1536`` -> ``'1.5 KiB'``.

    Adapted from Fred Cirera, https://stackoverflow.com/a/1094933/1870254.

    Args:
        num: The quantity to format; may be negative or fractional.
        suffix: Unit symbol appended after the prefix (``'B'`` by default).

    Returns:
        ``num`` divided by 1024 repeatedly until its magnitude drops below
        1024, labelled with the matching prefix ('' through 'Zi'). Values that
        are still 1024 or more after the 'Zi' step are reported in 'Yi'.
    """
    prefixes = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi')
    scaled = num
    for prefix in prefixes:
        if abs(scaled) < 1024.0:
            return "%3.1f %s%s" % (scaled, prefix, suffix)
        scaled = scaled / 1024.0
    # Ran out of prefixes: everything left is expressed in yobi-units.
    return "%.1f %s%s" % (scaled, 'Yi', suffix)
# Debug report, executed at import time: print the ten largest names in the
# module namespace (``locals()`` at module level), largest first.
# NOTE: sys.getsizeof is shallow -- containers report only their own size,
# not the size of the objects they reference.
for name, size in sorted(
    ((ident, sys.getsizeof(obj)) for ident, obj in list(locals().items())),
    key=lambda pair: pair[1],
    reverse=True,  # stable: equal sizes keep namespace order, as before
)[:10]:
    print("{:>30}: {:>8}".format(name, sizeof_fmt(size)))
def data_generator(x, y):
    """Lazily yield ``(x[i], y[i])`` pairs for each ``i`` in ``range(len(x))``.

    Nothing is read until the generator is first advanced. ``y`` must accept
    every index that ``x`` does; otherwise the indexing error propagates to
    the consumer. Entries of ``y`` past ``len(x)`` are never visited.
    """
    yield from ((x[idx], y[idx]) for idx in range(len(x)))
def create_datasets(x, y, text_vectorizer, batch_size: int = 32, shuffle: bool = False,
                    n_repeat: int = 0, buffer_size: int = 1_000_000):
    """Build a vectorized, batched, cached and prefetched ``tf.data`` pipeline.

    Elements are produced by ``data_generator(x, y)`` and passed through
    ``text_vectorizer``. They are then batched and cached in memory, and
    optionally shuffled and repeated.

    Args:
        x: Text inputs; must support ``len``, ``x[i]`` and ``x.shape[1]``.
            NOTE(review): ``output_signature`` declares each element as rank 2,
            ``(None, x.shape[1])``, and the map step takes ``[0]`` of it, but
            ``data_generator`` yields ``x[i]``. That only lines up if ``x`` is
            rank 3 (e.g. ``(N, 1, k)``). For a 2-D ``x``, TF's element check
            would presumably reject the rows -- confirm against callers.
        y: Integer labels; must support ``len``, ``y[i]`` and ``y.shape[1]``.
        text_vectorizer: Callable mapping a string tensor to token ids
            (presumably a ``TextVectorization`` layer -- TODO confirm).
        batch_size: Number of elements per batch.
        shuffle: If true, shuffle the cached batches. Shuffling happens after
            ``cache()``, so the order is re-drawn on every pass.
        n_repeat: ``0`` for a single pass, ``> 0`` to repeat that many times,
            ``-1`` to repeat indefinitely.
        buffer_size: Shuffle buffer size, counted in batches (shuffling is
            applied after batching).

    Returns:
        The prepared ``tf.data.Dataset``.

    Raises:
        ValueError: If ``n_repeat`` is below ``-1`` (this used to silently
            return ``None``).
    """
    if n_repeat < -1:
        raise ValueError(
            f"n_repeat must be -1, 0 or a positive integer, got {n_repeat}")

    print('Generating...')
    train_dataset = tf.data.Dataset.from_generator(
        # from_generator expects a callable that returns a *fresh* iterator
        # every time the dataset is iterated. The previous code returned one
        # shared generator object, so any re-iteration (e.g. after a partial
        # first pass, before the cache was complete) saw an exhausted generator.
        lambda: data_generator(x, y),
        output_signature=(
            tf.TensorSpec(shape=(None, x.shape[1]), dtype=tf.string),
            tf.TensorSpec(shape=(None, y.shape[1]), dtype=tf.int32),
        ),
    )
    print('Mapping...')
    train_dataset = train_dataset.map(
        lambda features, labels: (
            tf.cast(text_vectorizer(features), tf.int32)[0],
            labels[0],
        ),
        tf.data.AUTOTUNE,
    )
    # Cache *before* shuffling: cache() replays identical elements on every
    # pass, so shuffling ahead of it would freeze the first epoch's order.
    train_dataset = train_dataset.batch(batch_size).cache()
    if shuffle:
        train_dataset = train_dataset.shuffle(buffer_size)
    if n_repeat == -1:
        train_dataset = train_dataset.repeat()
    elif n_repeat > 0:
        train_dataset = train_dataset.repeat(n_repeat)
    return train_dataset.prefetch(tf.data.AUTOTUNE)