|
|
|
import gradio as gr |
|
import torch |
|
import torch.nn as nn |
|
import torchvision.models as models |
|
from torchvision import transforms |
|
|
|
|
|
|
|
class NeuralStyleTransfer(nn.Module):
    """Frozen feature extractor used as the loss network for style transfer.

    By default the backbone is the ImageNet-pretrained DenseNet-121 ``features``
    stack. ``forward`` runs the input through the backbone's children in order
    and returns the outputs of the children whose indices are listed in
    ``feature_layers``.
    """

    def __init__(self, feature_layers=(4, 6, 8, 10), backbone=None):
        """Build the extractor.

        Args:
            feature_layers: Indices (into the backbone's children) whose
                outputs ``forward`` returns. The default ``(4, 6, 8, 10)``
                matches the original hard-coded list (presumably the four
                dense-block outputs of torchvision's DenseNet-121 — verify).
            backbone: Optional iterable module (e.g. ``nn.Sequential``) to use
                instead of the pretrained DenseNet-121 ``features`` stack.
        """
        super().__init__()
        if backbone is None:
            backbone = models.densenet121(
                weights=models.DenseNet121_Weights.IMAGENET1K_V1
            ).features
        self.model = backbone
        self.feature_layers = list(feature_layers)
        # The backbone is a fixed loss network; only the generated image is
        # optimised. Without this, every backward pass also computes (and,
        # since nothing zeroes them, accumulates) gradients for every weight.
        self.model.requires_grad_(False)

    def forward(self, x):
        """Return the activations of the selected backbone layers for ``x``.

        Args:
            x: Input batch fed to the first backbone layer.

        Returns:
            List of tensors, one per requested layer index present in the
            backbone, in layer order. Layers after the last requested index
            are not executed.
        """
        wanted = set(self.feature_layers)
        if not wanted:
            return []
        last = max(wanted)

        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            if layer_num in wanted:
                features.append(x)
            if layer_num >= last:
                # Nothing past this point contributes to the returned features.
                break
        return features
|
|
|
|
|
def _gram_matrix(features):
    """Return the (C, C) Gram matrix of a batch-1 feature map of shape (1, C, H, W)."""
    _, c, h, w = features.shape
    flat = features.reshape(c, h * w)
    return torch.mm(flat, flat.T)


def get_output(style_image, content_image, alpha, beta, step, progress=gr.Progress()):
    """Run neural style transfer and return the stylised image.

    Starting from a copy of the content image, the pixels of ``generated`` are
    optimised with Adam so that its backbone features match the content
    image's features (weighted by ``alpha``) and its per-layer Gram matrices
    match the style image's (weighted by ``beta``).

    Args:
        style_image: Style reference; anything ``transforms.Resize`` accepts.
            PIL images are converted to RGB first.
        content_image: Content reference; same accepted types as ``style_image``.
        alpha: Weight of the content loss.
        beta: Weight of the style loss.
        step: Number of optimisation steps (truncated with ``int``).
        progress: Gradio progress tracker used to report per-step progress.

    Returns:
        A PIL image of the optimised 224x224 result.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    loader = transforms.Compose([transforms.Resize((224, 224)),
                                 transforms.ToTensor()])

    def _load(image):
        # Grayscale or RGBA inputs would give 1 or 4 channels, which the
        # 3-channel backbone rejects.
        if hasattr(image, "convert"):
            image = image.convert("RGB")
        return loader(image).to(device)

    style = _load(style_image)
    content = _load(content_image)
    # The optimised image starts as a copy of the content image.
    generated = content.clone().requires_grad_(True)

    model = NeuralStyleTransfer()
    model.to(device)
    model.eval()
    # Only `generated` is optimised; don't build gradients for the weights.
    model.requires_grad_(False)

    # The content/style targets never change, so compute them once, outside
    # autograd, instead of re-running two extra forward passes every step.
    with torch.no_grad():
        content_features = model(content.unsqueeze(0))
        style_grams = [_gram_matrix(sf) for sf in model(style.unsqueeze(0))]

    learning_rate = 0.001
    optimizer = torch.optim.Adam([generated], lr=learning_rate)

    for i in progress.tqdm(range(int(step))):
        generated_features = model(generated.unsqueeze(0))

        content_loss = sum(
            torch.sum((cf - gf) ** 2)
            for cf, gf in zip(content_features, generated_features)
        )
        style_loss = sum(
            torch.sum((sg - _gram_matrix(gf)) ** 2)
            for sg, gf in zip(style_grams, generated_features)
        )
        loss = alpha * content_loss + beta * style_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 5 == 0:
            print(f"\nLoss at {i+1} epoch -----> {loss.item()}", end='')

    # Adam can push pixels outside [0, 1]. ToPILImage scales floats by 255 and
    # casts to uint8, so out-of-range values would wrap instead of saturating.
    convertor = transforms.ToPILImage()
    return convertor(generated.detach().clamp(0, 1).cpu())
|
|