license: mit

A Text-Conditioned Diffusion Prior

Training Details

Training details can be found here: https://wandb.ai/nousr_laion/conditioned-prior/reports/LAION-DALLE2-PyTorch-Prior--VmlldzoyMDI2OTIx

Source Code

The models are diffusion prior trainers from https://github.com/lucidrains/DALLE2-pytorch, using the defaults specified in the train_diffusion_prior.py script.

Community: LAION

Join us: https://discord.gg/uPMftTmrvS


Models

depth=12
d_model=768 (512 when using ViT-B/32)
clip = OpenAIClipAdapter("ViT-L/14") or OpenAIClipAdapter("ViT-B/32")
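As a rough illustration, the hyperparameters above, together with the ones passed in the loading code below, could be collected in a small config object. This is a sketch only: `PriorConfig`, `CLIP_EMBED_DIM`, and `prior_config_for` are hypothetical names, not part of DALLE2-pytorch. The only value that depends on the CLIP variant is the embedding width (512 for ViT-B/32, 768 for ViT-L/14).

```python
from dataclasses import dataclass

# Embedding width per CLIP variant, matching the dims used by the loader in this README.
CLIP_EMBED_DIM = {
    "ViT-B/32": 512,
    "ViT-L/14": 768,
}


@dataclass(frozen=True)
class PriorConfig:
    """Hypothetical container for the prior's hyperparameters (not a DALLE2-pytorch class)."""
    clip_choice: str
    dim: int
    depth: int = 12
    dim_head: int = 64
    heads: int = 12
    normformer: bool = True
    timesteps: int = 1000
    cond_drop_prob: float = 0.1
    loss_type: str = "l2"


def prior_config_for(clip_choice: str) -> PriorConfig:
    """Build a config for a CLIP variant, rejecting unknown names instead of guessing."""
    if clip_choice not in CLIP_EMBED_DIM:
        raise ValueError(
            f"unsupported clip_choice {clip_choice!r}; "
            f"expected one of {sorted(CLIP_EMBED_DIM)}"
        )
    return PriorConfig(clip_choice=clip_choice, dim=CLIP_EMBED_DIM[clip_choice])


config = prior_config_for("ViT-L/14")
print(config.dim, config.heads * config.dim_head)  # prints: 768 768
```

Unlike the loader, which treats any name other than "ViT-B/32" as ViT-L/14, this sketch raises on an unknown name. For ViT-L/14 the 12 heads of width 64 also add up to the 768-wide embedding.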

Loading the models might look something like this:

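```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer


def load_diffusion_model(dprior_path, device, clip_choice):
    # Load the checkpoint onto the CPU first; the modules are moved to `device` below.
    loaded_obj = torch.load(str(dprior_path), map_location="cpu")

    # The prior's width must match the CLIP embedding size.
    if clip_choice == "ViT-B/32":
        dim = 512
    else:  # "ViT-L/14"
        dim = 768

    prior_network = DiffusionPriorNetwork(
        dim=dim,
        depth=12,
        dim_head=64,
        heads=12,
        normformer=True,
    ).to(device)

    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter(clip_choice),
        image_embed_dim=dim,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
    ).to(device)

    # Restore the model weights; strict=True fails loudly on any key mismatch.
    diffusion_prior.load_state_dict(loaded_obj["model"], strict=True)

    # Wrap the prior in a trainer so the optimizer and grad-scaler state can be restored too.
    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
    ).to(device)

    trainer.optimizer.load_state_dict(loaded_obj["optimizer"])
    trainer.scaler.load_state_dict(loaded_obj["scaler"])

    return trainer
```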
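The loading function above assumes each checkpoint is a dict with "model", "optimizer", and "scaler" entries. A check like the one below could run right after torch.load so that an incompatible file fails with a clear message instead of a bare KeyError. `check_checkpoint` and `REQUIRED_CHECKPOINT_KEYS` are hypothetical helpers, not part of DALLE2-pytorch; since the check works on any mapping, it is shown with a plain dict standing in for a loaded checkpoint.

```python
from collections.abc import Mapping

# Keys the loading function reads from the checkpoint (hypothetical helper constant).
REQUIRED_CHECKPOINT_KEYS = ("model", "optimizer", "scaler")


def check_checkpoint(loaded_obj) -> None:
    """Raise a descriptive error if a loaded checkpoint lacks the expected entries."""
    if not isinstance(loaded_obj, Mapping):
        raise TypeError(f"expected a dict-like checkpoint, got {type(loaded_obj).__name__}")
    missing = [key for key in REQUIRED_CHECKPOINT_KEYS if key not in loaded_obj]
    if missing:
        raise KeyError(f"checkpoint is missing {missing}; found keys {list(loaded_obj)}")


# Stand-in for the result of torch.load(...); real entries would be state dicts.
fake_checkpoint = {"model": {}, "optimizer": {}, "scaler": {}}
check_checkpoint(fake_checkpoint)  # passes silently
```

A weights-only export without optimizer or scaler state would fail this check, which matches the loader's requirement that all three entries be present.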