# imports
import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
# modelling
class NeuralStyleTransfer(nn.Module):
    def __init__(self):
        super(NeuralStyleTransfer, self).__init__()
        self.model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1).features
        # picking the feature layers from the conv layers in the model;
        # this selection was made manually by going through all the layers, and other choices are worth experimenting with
        self.feature_layers = [4, 6, 8, 10]
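        # (in torchvision's densenet121.features, indices 4, 6, 8 and 10
        # correspond to denseblock1 through denseblock4)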

    def forward(self, x):
        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            # collect the selected layers' outputs from the model as the features
            if layer_num in self.feature_layers:
                features.append(x)
        return features
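
# quick sanity check (a sketch, not part of the app):
#   extractor = NeuralStyleTransfer()
#   feats = extractor(torch.rand(1, 3, 224, 224))
#   print([f.shape for f in feats])  # one tensor per entry in feature_layers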

def get_output(style_image, content_image, alpha, beta, step, progress=gr.Progress()):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loader = transforms.Compose([transforms.Resize((224, 224)),
                                 transforms.ToTensor()])  # ToTensor automatically scales pixel values to [0, 1]
    style = loader(style_image).to(device)
    content = loader(content_image).to(device)
    # starting with the content image instead of random noise to speed up the process
    generated = loader(content_image).to(device)
    # track gradients on the generated image so its values can be modified while training
    generated.requires_grad_(True)
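    # an alternative, closer to the classic Gatys et al. setup but slower to
    # converge, would be to start from noise (a sketch, not used here):
    #   generated = torch.rand_like(content, requires_grad=True)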
    # densenet's weights need not be updated, so freeze them and run in eval mode
    model = NeuralStyleTransfer()
    model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)

    # setting parameters
    step_count = int(step)
    learning_rate = 0.001
    # the custom loss is defined inside the training loop;
    # the optimizer updates the values in the generated image, not the network weights
    optimizer = torch.optim.Adam([generated], lr=learning_rate)

    # training
    # the style and content features never change, so compute them once outside the loop
    with torch.no_grad():
        style_features = model(style.unsqueeze(0))
        content_features = model(content.unsqueeze(0))
    for i in progress.tqdm(range(step_count)):
        generated_features = model(generated.unsqueeze(0))
        # content loss
        content_loss = 0
        for cf, gf in zip(content_features, generated_features):
            content_loss += torch.sum((cf - gf) ** 2)
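
        # note for the style loss below: flattening a feature map to F with
        # shape (c, h*w), the Gram matrix G = F @ F.T is (c, c) and measures
        # channel-wise feature correlations, which capture texture while
        # discarding spatial layout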
        # style loss
        style_loss = 0
        for sf, gf in zip(style_features, generated_features):
            _, c, h, w = sf.shape
            s_gram = torch.mm(sf.view(c, h * w), sf.view(c, h * w).T)
            g_gram = torch.mm(gf.view(c, h * w), gf.view(c, h * w).T)
            style_loss += torch.sum((s_gram - g_gram) ** 2)
        # total loss
        loss = alpha * content_loss + beta * style_loss
        # update the values in the generated image
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) % 5 == 0:
            print(f"Loss at step {i+1} -----> {loss.item()}")
    # detach from the graph and clamp to [0, 1] first; ToPILImage cannot handle a tensor that still requires grad
    convertor = transforms.ToPILImage()  # converts a tensor to a PIL image for displaying in Gradio
    return convertor(generated.detach().cpu().clamp(0, 1))
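
# one possible way to expose get_output as a Gradio app (a sketch; the labels,
# default values and slider range below are assumptions, not part of the
# original module):
if __name__ == "__main__":
    demo = gr.Interface(
        fn=get_output,
        inputs=[
            gr.Image(type="pil", label="Style image"),
            gr.Image(type="pil", label="Content image"),
            gr.Number(value=1.0, label="alpha (content weight)"),
            gr.Number(value=0.01, label="beta (style weight)"),
            gr.Slider(10, 500, value=100, step=10, label="Optimization steps"),
        ],
        outputs=gr.Image(type="pil", label="Generated image"),
    )
    demo.launch()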