""" Training Flant_T5 model on tner/mit_restaurant on seq2seq task """ from dataclasses import asdict import torch import evaluate import datasets from torch.utils.data import DataLoader from transformers import ( AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, get_scheduler, ) from accelerate import Accelerator import numpy as np import mlflow from tqdm.auto import tqdm from utils.logger import get_logger from configs import T5TrainingConfig from data import MITRestaurants, get_default_transforms log = get_logger("Flan_T5") log.debug("heloooooooooooo?") # get dataset transforms = get_default_transforms() dataset = ( MITRestaurants.from_hf("tner/mit_restaurant") .set_transforms(transforms) .hf_training() ) dataset["train"] = datasets.concatenate_datasets([dataset["train"], dataset["test"]]) # log.info(dataset) print(dataset) tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base") model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base") def tokenize(example): """Tokenizes dataset for seq2seq task""" tokenized = tokenizer( example["tokens"], text_target=example["labels"], max_length=512, truncation=True, ) return tokenized tokenized_datasets = dataset.map( tokenize, batched=True, remove_columns=dataset["train"].column_names, ) # bleu metric metric = evaluate.load("sacrebleu") def postprocess(predictions, labels): """Post processing to convert model output for evaluation""" predictions = predictions.cpu().numpy() labels = labels.cpu().numpy() decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True) # Replace -100 in the labels as we can't decode them. labels = np.where(labels != -100, labels, tokenizer.pad_token_id) decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) # Some simple post-processing decoded_preds = [pred.strip() for pred in decoded_preds] decoded_labels = [[label.strip()] for label in decoded_labels] return decoded_preds, decoded_labels config = T5TrainingConfig() # data collator data_collator = DataCollatorForSeq2Seq(tokenizer, model=model) # data loaders tokenized_datasets.set_format("torch") train_dataloader = DataLoader( tokenized_datasets["train"], shuffle=True, collate_fn=data_collator, batch_size=config.train_batch_size, ) eval_dataloader = DataLoader( tokenized_datasets["validation"], collate_fn=data_collator, batch_size=config.eval_batch_size, ) # optimizer optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate) num_update_steps_per_epoch = len(train_dataloader) num_training_steps = config.epochs * num_update_steps_per_epoch lr_scheduler = get_scheduler( "linear", optimizer=optimizer, num_warmup_steps=config.num_warmup_steps, num_training_steps=num_training_steps, ) # accelerator accelerator = Accelerator( mixed_precision=config.mixed_precision, gradient_accumulation_steps=config.gradient_accumulation_steps, ) model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare( model, optimizer, train_dataloader, eval_dataloader ) progress_bar = tqdm(range(num_training_steps)) def train(): """Training function for finetuing flanT5""" # log.info("Starting Training") print("Starting Traning") for epoch in range(config.epochs): # Training model.train() for batch in train_dataloader: with accelerator.accumulate(model): outputs = model(**batch) loss = outputs.loss accelerator.backward(loss) optimizer.step() lr_scheduler.step() optimizer.zero_grad() progress_bar.update(1) # Evaluation model.eval() for batch in tqdm(eval_dataloader): with torch.no_grad(): generated_tokens = accelerator.unwrap_model(model).generate( batch["input_ids"], attention_mask=batch["attention_mask"], max_length=128, ) labels = batch["labels"] # Necessary to pad predictions and labels for being gathered generated_tokens = accelerator.pad_across_processes( generated_tokens, dim=1, pad_index=tokenizer.pad_token_id ) labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100) predictions_gathered = accelerator.gather(generated_tokens) labels_gathered = accelerator.gather(labels) decoded_preds, decoded_labels = postprocess( predictions_gathered, labels_gathered ) metric.add_batch(predictions=decoded_preds, references=decoded_labels) results = metric.compute() mlflow.log_metrics({"epoch": epoch, "BLEU score": results["score"]}) print(f"epoch {epoch}, BLEU score: {results['score']:.2f}") # Save and upload accelerator.wait_for_everyone() unwrapped_model = accelerator.unwrap_model(model) unwrapped_model.save_pretrained( config.output_dir, save_function=accelerator.save ) if accelerator.is_main_process: tokenizer.save_pretrained(config.output_dir) # save model with mlflow mlflow.transformers.log_model( transformers_model={"model": unwrapped_model, "tokenizer": tokenizer}, task="text2text-generation", artifact_path="seq2seq_model", registered_model_name="FlanT5_MIT", ) mlflow.set_tracking_uri("http://127.0.0.1:5000") with mlflow.start_run() as mlflow_run: mlflow.log_params(asdict(config)) train()