simplyjaga
committed on
Commit · faadaec
1 Parent(s): c1eb9f7
Create model.py
model.py
ADDED
@@ -0,0 +1,85 @@
#imports
import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms


#modelling
class NeuralStyleTransfer(nn.Module):
    def __init__(self):
        super(NeuralStyleTransfer, self).__init__()
        self.model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1).features
        # feature layers picked from the conv layers in the model;
        # this selection was chosen manually by going through the model's layers,
        # and can be experimented with
        self.feature_layers = [4, 6, 8, 10]

    def forward(self, x):
        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            # collect the selected layers' outputs as features
            if layer_num in self.feature_layers:
                features.append(x)
        return features


def get_output(style_image, content_image, alpha, beta, step, progress=gr.Progress()):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    loader = transforms.Compose([transforms.Resize((224, 224)),
                                 transforms.ToTensor()])  # ToTensor also scales pixel values to [0, 1]
    style = loader(style_image).to(device)
    content = loader(content_image).to(device)
    # start from the content image instead of random noise to speed up the process
    generated = loader(content_image).to(device)

    # the generated image's values are the parameters tracked and modified during training
    generated.requires_grad_(True)

    # DenseNet's weights need not be updated, so freeze them and keep the model in eval mode
    model = NeuralStyleTransfer()
    model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)

    #setting parameters
    step_count = int(step)
    learning_rate = 0.001

    # the custom loss is defined inside the training loop;
    # only the values in the generated matrix are updated by the optimizer
    optimizer = torch.optim.Adam([generated], lr=learning_rate)

    #training
    for i in progress.tqdm(range(step_count)):
        style_features = model(style.unsqueeze(0))
        content_features = model(content.unsqueeze(0))
        generated_features = model(generated.unsqueeze(0))

        # content loss: squared error between content and generated features
        content_loss = 0
        for cf, gf in zip(content_features, generated_features):
            content_loss += torch.sum((cf - gf) ** 2)

        # style loss: squared error between the Gram matrices of style and generated features
        style_loss = 0
        for sf, gf in zip(style_features, generated_features):
            bs, c, h, w = sf.shape
            s_gram = torch.mm(sf.view(c, h * w), sf.view(c, h * w).T)
            g_gram = torch.mm(gf.view(c, h * w), gf.view(c, h * w).T)
            style_loss += torch.sum((s_gram - g_gram) ** 2)

        #total_loss
        loss = alpha * content_loss + beta * style_loss

        # update the values in the generated image
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 5 == 0:
            print(f"\nLoss at step {i+1} -----> {loss.item()}", end='')

    # detach before conversion: ToPILImage cannot handle a tensor that still requires grad;
    # clamp to [0, 1] since optimization can push values outside the valid pixel range
    convertor = transforms.ToPILImage()  # converts the tensor to a PIL image for display in Gradio
    return convertor(generated.detach().cpu().clamp(0, 1))
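model.py only defines get_output; the Gradio UI that calls it is not part of this commit and presumably lives elsewhere in the Space (e.g. an app.py). A minimal sketch of how it could be wired up follows; the component types, labels, and default values are assumptions, not taken from this repository:

# hypothetical app.py — a sketch of wiring get_output into a Gradio interface;
# component choices and defaults are assumptions, not part of this commit
import gradio as gr
from model import get_output

demo = gr.Interface(
    fn=get_output,
    inputs=[
        gr.Image(type="pil", label="Style image"),    # delivered to get_output as a PIL image
        gr.Image(type="pil", label="Content image"),
        gr.Number(value=1.0, label="alpha (content weight)"),
        gr.Number(value=0.01, label="beta (style weight)"),
        gr.Number(value=100, label="steps"),
    ],
    outputs=gr.Image(type="pil", label="Stylized output"),
)

demo.launch()

Note that the progress=gr.Progress() default in get_output is injected by Gradio automatically, so it does not need (and must not have) a corresponding input component.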