simplyjaga commited on
Commit
faadaec
·
1 Parent(s): c1eb9f7

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +85 -0
model.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #imports
2
+ import gradio as gr
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.models as models
6
+ from torchvision import transforms
7
+
8
+
9
+ #modelling
10
class NeuralStyleTransfer(nn.Module):
    """Frozen DenseNet-121 feature extractor used for neural style transfer.

    The module wraps the ``features`` part of an ImageNet-pretrained
    DenseNet-121 and, on a forward pass, returns the activations of a few
    hand-picked layers instead of a classification result.
    """

    def __init__(self):
        super().__init__()
        self.model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1).features
        # The network is only used as a fixed feature extractor: the image is
        # what gets optimised, never the weights. Freezing them avoids
        # computing and accumulating weight gradients on every backward pass.
        for parameter in self.model.parameters():
            parameter.requires_grad_(False)
        # Indices (into ``self.model``) of the layers whose outputs are used as
        # features. They were chosen manually by inspecting the model and can
        # be experimented with.
        # NOTE(review): in torchvision's DenseNet-121 these indices appear to be
        # the four dense blocks -- confirm against the installed version.
        self.feature_layers = [4, 6, 8, 10]

    def forward(self, x):
        """Return the outputs of the selected layers for input batch ``x``.

        Args:
            x: input tensor fed to the first layer of ``self.model``.

        Returns:
            A list with one tensor per index in ``self.feature_layers`` that
            exists in the model, in layer order.
        """
        features = []
        # Layers after the last selected one cannot contribute to the result,
        # so the forward pass stops there.
        last_layer = max(self.feature_layers, default=-1)
        for layer_num, layer in enumerate(self.model):
            if layer_num > last_layer:
                break
            x = layer(x)
            if layer_num in self.feature_layers:
                features.append(x)
        return features
26
+
27
+
28
def _gram_matrix(feature_map):
    """Return the channel-by-channel Gram matrix of a single-image feature map."""
    channels = feature_map.shape[1]
    # Batch size is 1, so (1, C, H, W) can be flattened to (C, H*W).
    flat = feature_map.reshape(channels, -1)
    return torch.mm(flat, flat.T)


def get_output(style_image, content_image, alpha, beta, step, progress=gr.Progress()):
    """Run neural style transfer and return the stylised image.

    The result starts as a copy of the content image (instead of random noise,
    to speed up convergence) and its pixels are optimised with Adam so that
    its features stay close to the content image's features while the Gram
    matrices of its features approach those of the style image.

    Args:
        style_image: image providing the style (a PIL image is expected, since
            it is resized before being converted to a tensor).
        content_image: image providing the content; also the starting point.
        alpha: weight of the content loss.
        beta: weight of the style loss.
        step: number of optimisation steps; converted with ``int``.
        progress: gradio progress tracker used to wrap the training loop.

    Returns:
        The stylised 224x224 image as a PIL image.

    Raises:
        gr.Error: if either input image is missing.
    """
    if style_image is None or content_image is None:
        raise gr.Error("Please provide both a style image and a content image.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ToTensor scales the pixel values into the [0, 1] range.
    loader = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    style = loader(style_image).unsqueeze(0).to(device)
    content = loader(content_image).unsqueeze(0).to(device)

    # The generated image is the only thing being optimised.
    generated = content.clone().requires_grad_(True)

    # The network is a fixed feature extractor: freeze it so backward passes
    # do not compute (and accumulate) gradients for its weights.
    model = NeuralStyleTransfer()
    model.to(device)
    model.eval()
    model.requires_grad_(False)

    step_count = int(step)
    learning_rate = 0.001
    optimizer = torch.optim.Adam([generated], lr=learning_rate)

    # Style and content targets never change during optimisation, so they are
    # computed once, without building an autograd graph.
    with torch.no_grad():
        content_features = model(content)
        style_grams = [_gram_matrix(sf) for sf in model(style)]

    for i in progress.tqdm(range(step_count)):
        generated_features = model(generated)

        # Content loss: squared difference between feature maps.
        content_loss = 0
        for cf, gf in zip(content_features, generated_features):
            content_loss += torch.sum((cf - gf) ** 2)

        # Style loss: squared difference between Gram matrices.
        style_loss = 0
        for s_gram, gf in zip(style_grams, generated_features):
            style_loss += torch.sum((s_gram - _gram_matrix(gf)) ** 2)

        loss = alpha * content_loss + beta * style_loss

        # Update the pixel values of the generated image.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 5 == 0:
            print(f"\nLoss at {i+1} epoch -----> {loss.item()}", end='')

    # The optimiser can push pixel values outside [0, 1]; without clamping they
    # would wrap around when converted to 8-bit and show up as speckles.
    result = generated.detach().squeeze(0).clamp(0, 1).cpu()
    convertor = transforms.ToPILImage()  # tensor -> PIL image for display in gradio
    return convertor(result)