from torch.utils.data import Dataset


class CustomDataset(Dataset):
    """Expose an indexable, sized container through the ``Dataset`` interface.

    Items are returned exactly as stored; no transformation is applied.
    """

    def __init__(self, data) -> None:
        """Store ``data``; it must support ``len()`` and ``data[index]``."""
        super().__init__()
        self.data = data

    def __len__(self):
        """Return the number of items in the wrapped container."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the item stored at ``index``, unchanged."""
        return self.data[index]


class EarlyStopping:
    """Signal that training should stop after repeated non-improving steps.

    Call the instance once per evaluation step with the current loss and the
    reference (best-so-far) loss.  A call counts as "bad" when
    ``train_loss - min_loss`` exceeds ``min_delta``.  Once ``tolerance``
    consecutive bad calls have been seen, ``early_stop`` becomes ``True``
    and stays ``True``.

    Attributes:
        tolerance: Number of consecutive bad calls that triggers the stop.
        min_delta: Margin by which ``train_loss`` must exceed ``min_loss``
            for a call to count as bad.
        counter: Current streak of consecutive bad calls.
        early_stop: Latched flag; ``True`` once the stop condition was met.
    """

    def __init__(self, tolerance=10, min_delta=0):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False

    def __call__(self, train_loss, min_loss):
        """Update the bad-call streak and the ``early_stop`` flag.

        Args:
            train_loss: Loss measured at the current step.
            min_loss: Reference loss to compare against.
        """
        if (train_loss - min_loss) > self.min_delta:
            self.counter += 1
            if self.counter >= self.tolerance:
                self.early_stop = True
        else:
            # The loss is back within the allowed margin, so the streak is
            # broken.  Without this reset, isolated bad steps spread over a
            # long run would accumulate and eventually stop training even
            # though the loss kept recovering in between.
            self.counter = 0


# def gen_text_from_center(args,plugin_vae, vae_model, decoder_tokenizer,label,epoch,pos):
#     gen_text = []
#     latent_z = gen_latent_center(plugin_vae,pos).to(args.device).repeat((1,1))
#     print("latent_z",latent_z.shape)
#     text_analogy = text_from_latent_code_batch(latent_z, vae_model, args, decoder_tokenizer)
#     print("label",label)
#     print(text_analogy)
#     gen_text.extend([(label,y,epoch) for y in text_analogy])
#     text2out(gen_text, '/cognitive_comp/liangyuxin/projects/cond_vae/outputs/test.json')