# Demystifying neural style transfer

Whether neural style transfer has much concrete business value is questionable, but it remains a fun thing to do and to look at. Its theoretical value, however, is considerable: it shows how ‘style’ can be captured in a matrix (the Gram matrix) and suggests how the same idea might carry over to text and sound. It makes the notion of ‘style’ tangible, and it is remarkably effective. Below you can find an implementation using PyTorch, but there are heaps of (TensorFlow) alternatives around.

Some terminology first. The image onto which a style is applied is called the content image, the image supplying a particular style is called the style image, and the result is called the pastiche (19th century French for imitation).

Style transfer is just another neural learning process, but with two pulling forces: it tries to match the content image and the style image at the same time, thus creating a tension or mixture. This tension is captured by the loss function, which measures how much the result resembles each of the two images. The crucial (and magical) bit is that the style loss is not simply the difference between the images themselves but the difference between their Gram matrices. Why this works is a bit of a mystery, but you can find in this article an alternative view on things. So, if $C, S, P$ are the matrices representing the content, style and pastiche respectively, the content loss is $\mathcal{L}_C = \|C-P\|$

with respect to some norm, usually the Frobenius norm, and the style loss is $\mathcal{L}_S = \|G(S)-G(P)\|$

with $G(\cdot)$ the Gram matrix of the style and pastiche. The total loss is then $\mathcal{L} = \alpha\,\mathcal{L}_C + (1-\alpha)\,\mathcal{L}_S$

where $\alpha$ acts as a weight emphasizing either content or style. Below you can see how this parameter affects the resulting pastiche. With this, all the rest is basic neural network mechanics, i.e. gradient-based optimization (the code below actually uses L-BFGS rather than plain SGD).
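To make the formulas concrete, here is a minimal sketch (plain PyTorch; the function name and the random tensors standing in for layer activations are illustrative) that computes a Gram matrix and combines the content and style losses with a weight $\alpha$:

```python
import torch

def gram_matrix(features):
    # features: (batch, channels, height, width) activations from some layer
    b, c, h, w = features.size()
    flat = features.view(b * c, h * w)   # one row per feature map
    G = flat @ flat.t()                  # feature-map correlations
    return G / (b * c * h * w)           # normalize by number of elements

torch.manual_seed(0)
content = torch.rand(1, 3, 8, 8)   # stand-ins for activations of C, S, P
style = torch.rand(1, 3, 8, 8)
pastiche = torch.rand(1, 3, 8, 8)

alpha = 0.5
content_loss = torch.norm(content - pastiche)                        # ||C - P||
style_loss = torch.norm(gram_matrix(style) - gram_matrix(pastiche))  # ||G(S) - G(P)||
total_loss = alpha * content_loss + (1 - alpha) * style_loss
print(total_loss.item())
```

Note that the Gram matrix throws away all spatial information: only the correlations between feature maps survive, which is exactly why it captures ‘style’ rather than layout.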

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import scipy.misc

imsize = 256

class GramMatrix(nn.Module):

    def forward(self, input):
        # a: batch size, b: number of feature maps, (c, d): feature map dimensions
        a, b, c, d = input.size()
        features = input.view(a * b, c * d)
        G = torch.mm(features, features.t())
        # normalize by the number of elements in each feature map
        return G.div(a * b * c * d)

class StyleCNN(object):
    def __init__(self, style, content, pastiche):
        super(StyleCNN, self).__init__()

        self.style = style
        self.content = content
        self.pastiche = nn.Parameter(pastiche.data)

        self.content_layers = ['conv_4']
        self.style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
        self.content_weight = 1
        self.style_weight = 1000

        self.loss_network = models.vgg19(pretrained=True)

        self.gram = GramMatrix()
        self.loss = nn.MSELoss()
        self.optimizer = optim.LBFGS([self.pastiche])

        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            self.loss_network.cuda()
            self.gram.cuda()

    def train(self):
        def closure():
            self.optimizer.zero_grad()

            pastiche = self.pastiche.clone()
            pastiche.data.clamp_(0, 1)
            content = self.content.clone()
            style = self.style.clone()

            content_loss = 0
            style_loss = 0

            i = 1
            not_inplace = lambda layer: nn.ReLU(inplace=False) if isinstance(layer, nn.ReLU) else layer
            for layer in list(self.loss_network.features):
                layer = not_inplace(layer)
                if self.use_cuda:
                    layer.cuda()

                pastiche, content, style = layer.forward(pastiche), layer.forward(content), layer.forward(style)

                if isinstance(layer, nn.Conv2d):
                    name = "conv_" + str(i)

                    if name in self.content_layers:
                        content_loss += self.loss(pastiche * self.content_weight, content.detach() * self.content_weight)

                    if name in self.style_layers:
                        pastiche_g, style_g = self.gram.forward(pastiche), self.gram.forward(style)
                        style_loss += self.loss(pastiche_g * self.style_weight, style_g.detach() * self.style_weight)

                if isinstance(layer, nn.ReLU):
                    i += 1

            total_loss = content_loss + style_loss
            total_loss.backward()
            # LBFGS requires the closure to return the loss
            return total_loss

        self.optimizer.step(closure)
        return self.pastiche

# transforms.Scale is renamed transforms.Resize in newer torchvision
loader = transforms.Compose([
    transforms.Scale(imsize),
    transforms.ToTensor()
])

unloader = transforms.ToPILImage()

def load_image(image_name):
    image = Image.open(image_name)
    image = loader(image)
    image = image.unsqueeze(0)  # add a batch dimension
    return image

def save_image(inputdata, path):
    image = inputdata.data.clone().cpu()
    image = image.view(3, imsize, imsize)
    image = unloader(image)
    # scipy.misc.imsave was removed in SciPy 1.2; imageio.imwrite is the replacement
    scipy.misc.imsave(path, image)

import torch.utils.data
import torchvision.datasets as datasets

def main():

    # CUDA Configurations
    dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

    # Content and style (the file paths are placeholders)
    style = load_image("style.jpg").type(dtype)
    content = load_image("content.jpg").type(dtype)

    # start the pastiche from random noise with the shape of the content image
    pastiche = load_image("content.jpg").type(dtype)
    pastiche.data = torch.randn(pastiche.size()).type(dtype)

    num_epochs = 31
    style_cnn = StyleCNN(style, content, pastiche)

    for i in range(num_epochs):
        pastiche = style_cnn.train()

        if i % 10 == 0:
            print("Iteration: %d" % (i))

            pastiche.data.clamp_(0, 1)
            save_image(pastiche, "pastiche_%d.png" % i)  # placeholder output path

main()
```


Upon animating the balance between content and style, you get an animation like the one below.