import argparse
import os
import torch
import mlflow
from data_ingestion import create_dataloaders, test_data_ingestion
from model import Generator, Discriminator, init_weights, test_models
from train import train, test_training
from app import setup_gradio_app

# Name of the MLflow experiment all pipeline runs are logged under.
EXPERIMENT_NAME = "Colorizer_Experiment"


def setup_mlflow():
    """Ensure the MLflow experiment exists and return its experiment id.

    Creates the experiment named ``EXPERIMENT_NAME`` if it does not already
    exist; otherwise reuses the existing one.
    """
    experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
    if experiment is None:
        experiment_id = mlflow.create_experiment(EXPERIMENT_NAME)
    else:
        experiment_id = experiment.experiment_id
    return experiment_id


def _ensure_dataloader(train_loader, batch_size, purpose):
    """Return *train_loader* if already built, else create one.

    Prints progress/failure messages mentioning *purpose* (e.g. "training",
    "testing"). Returns None when creation fails, mirroring
    ``create_dataloaders``'s failure contract.
    """
    if train_loader is not None:
        return train_loader
    print(f"Creating dataloader for {purpose}...")
    train_loader = create_dataloaders(batch_size=batch_size)
    if train_loader is None:
        print(f"Failed to create dataloader for {purpose}.")
    return train_loader


def _build_models(device):
    """Construct a fresh (generator, discriminator) pair on *device*.

    Both networks are weight-initialized via ``init_weights``.
    """
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)
    generator.apply(init_weights)
    discriminator.apply(init_weights)
    return generator, discriminator


def run_pipeline(args):
    """Run the colorizer pipeline stages selected by the CLI flags.

    Stages (each gated by its flag, or by ``--run_all``):
      - ingest_data: build dataloaders
      - create_model: build and smoke-test the GAN models
      - train: full training loop; persists the MLflow run id to
        ``latest_run_id.txt``
      - test_training: one-off training-process sanity check
      - serve: launch the Gradio inference app (needs a run id, either
        from ``--run_id`` or from ``latest_run_id.txt``)

    Returns None; failures are reported via prints and early return.
    """
    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Called for its side effect: guarantees the experiment exists before
    # any stage logs to MLflow. The returned id is not needed here.
    setup_mlflow()

    if args.ingest_data or args.run_all:
        print("Starting data ingestion...")
        train_loader = create_dataloaders(batch_size=args.batch_size)
        if train_loader is None:
            print("Data ingestion failed.")
            return
    else:
        train_loader = None

    if args.create_model or args.train or args.run_all:
        print("Creating and testing models...")
        generator, discriminator = _build_models(device)
        if not test_models():
            print("Model creation or testing failed.")
            return
    else:
        generator = None
        discriminator = None

    if args.train or args.run_all:
        print("Starting model training...")
        train_loader = _ensure_dataloader(train_loader, args.batch_size, "training")
        if train_loader is None:
            return
        if generator is None or discriminator is None:
            print("Creating models for training...")
            generator, discriminator = _build_models(device)
        run_id = train(generator, discriminator, train_loader,
                       num_epochs=args.num_epochs, device=device)
        if run_id:
            print(f"Training completed. Run ID: {run_id}")
            # Persist the run id so a later --serve invocation can find
            # the trained model without an explicit --run_id.
            with open("latest_run_id.txt", "w") as f:
                f.write(run_id)
        else:
            print("Training failed.")
            return

    if args.test_training:
        print("Testing training process...")
        train_loader = _ensure_dataloader(train_loader, args.batch_size, "testing")
        if train_loader is None:
            return
        if generator is None or discriminator is None:
            print("Creating models for testing...")
            generator, discriminator = _build_models(device)
        if test_training(generator, discriminator, train_loader, device):
            print("Training process test passed.")
        else:
            print("Training process test failed.")

    if args.serve or args.run_all:
        print("Setting up Gradio app for serving...")
        if not args.run_id:
            # Fall back to the run id saved by the most recent training run.
            try:
                with open("latest_run_id.txt", "r") as f:
                    args.run_id = f.read().strip()
            except FileNotFoundError:
                print("No run ID provided and couldn't find latest_run_id.txt")
                return
        iface = setup_gradio_app(args.run_id, device)
        iface.launch(share=args.share)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run Colorizer Pipeline")
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to use (cuda/cpu)")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size for training")
    parser.add_argument("--num_epochs", type=int, default=50,
                        help="Number of epochs to train")
    parser.add_argument("--run_id", type=str,
                        help="MLflow run ID of the trained model for inference")
    parser.add_argument("--ingest_data", action="store_true",
                        help="Run data ingestion")
    parser.add_argument("--create_model", action="store_true",
                        help="Create and test the model")
    parser.add_argument("--train", action="store_true",
                        help="Train the model")
    parser.add_argument("--test_training", action="store_true",
                        help="Test the training process")
    parser.add_argument("--serve", action="store_true",
                        help="Serve the model using Gradio")
    parser.add_argument("--run_all", action="store_true",
                        help="Run all steps")
    parser.add_argument("--share", action="store_true",
                        help="Share the Gradio app publicly")
    args = parser.parse_args()
    run_pipeline(args)