Whether neural style transfer has much concrete business value is questionable, but it remains a fun thing to do and to look at. Its theoretical value, however, is considerable: it shows how 'style' can be captured in a matrix (the Gram matrix) and how the same idea could potentially be applied to text and sound. It makes the notion of 'style' more tangible, and it's amazingly effective. Below you can find an implementation using PyTorch, but there are heaps of (TensorFlow) alternatives around.
Some terminology first. The image on which a style is applied is called the content image, the image supplying that particular style is called the style image, and the result is called the pastiche (a 19th-century French term for an imitation).
Style transfer is just another neural learning process, but with two pulling forces: it tries to match the content image and the style image at the same time, thus creating a tension or mixture. The tension is captured by the loss function, which measures how much the result resembles each of the two images. The crucial (and magical) bit is that the loss with respect to the style is not simply the difference between the images but the difference between their Gram matrices. Why this works is a bit of a mystery, but you can find in this article an alternative view on things. So, if $C$, $S$ and $P$ are the matrices representing the content, style and pastiche respectively, the loss in content is

$$\mathcal{L}_{content} = \| C - P \|^2$$

with respect to some norm, usually the Frobenius norm, and the style loss is

$$\mathcal{L}_{style} = \| G(S) - G(P) \|^2$$

with $G$ the Gram matrix of the style and pastiche. The total loss is then

$$\mathcal{L} = \alpha\, \mathcal{L}_{content} + (1 - \alpha)\, \mathcal{L}_{style}$$

where $\alpha$ acts as a weight emphasizing either style or content. Below you can see how this parameter affects the resulting pastiche. With this, all the rest is basic neural network mechanics (gradient-based optimization; the code below uses L-BFGS).
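The two loss terms can be sketched directly in PyTorch. This is a minimal, self-contained version of the math above; the helper names gram_matrix and total_loss are ours, not part of any library:

```python
import torch

def gram_matrix(feat):
    # feat: (batch, channels, height, width) feature maps from some CNN layer
    b, c, h, w = feat.size()
    f = feat.view(b * c, h * w)
    # channel-by-channel inner products, normalised by the layer size
    return torch.mm(f, f.t()) / (b * c * h * w)

def total_loss(pastiche_feat, content_feat, style_feat, alpha=0.5):
    # content term: plain squared (Frobenius) distance between feature maps
    l_content = torch.mean((pastiche_feat - content_feat) ** 2)
    # style term: squared distance between the Gram matrices
    l_style = torch.mean((gram_matrix(pastiche_feat) - gram_matrix(style_feat)) ** 2)
    return alpha * l_content + (1 - alpha) * l_style
```

In the real thing the feature maps come from several layers of a pretrained network (VGG below) rather than from the raw pixels, but the shape of the computation is exactly this.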
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image
import scipy.misc

imsize = 256

class GramMatrix(nn.Module):
    def forward(self, input):
        a, b, c, d = input.size()
        features = input.view(a * b, c * d)
        G = torch.mm(features, features.t())
        return G.div(a * b * c * d)

class StyleCNN(object):
    def __init__(self, style, content, pastiche):
        super(StyleCNN, self).__init__()
        self.style = style
        self.content = content
        self.pastiche = nn.Parameter(pastiche.data)
        self.content_layers = ['conv_4']
        self.style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
        self.content_weight = 1
        self.style_weight = 1000
        self.loss_network = models.vgg19(pretrained=True)
        self.gram = GramMatrix()
        self.loss = nn.MSELoss()
        self.optimizer = optim.LBFGS([self.pastiche])
        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            self.loss_network.cuda()
            self.gram.cuda()

    def train(self):
        def closure():
            self.optimizer.zero_grad()
            pastiche = self.pastiche.clone()
            pastiche.data.clamp_(0, 1)
            content = self.content.clone()
            style = self.style.clone()
            content_loss = 0
            style_loss = 0
            i = 1
            # swap in-place ReLUs for out-of-place ones so the intermediate
            # activations are not overwritten
            not_inplace = lambda layer: nn.ReLU(inplace=False) if isinstance(layer, nn.ReLU) else layer
            for layer in list(self.loss_network.features):
                layer = not_inplace(layer)
                if self.use_cuda:
                    layer.cuda()
                pastiche, content, style = layer.forward(pastiche), layer.forward(content), layer.forward(style)
                if isinstance(layer, nn.Conv2d):
                    name = "conv_" + str(i)
                    if name in self.content_layers:
                        content_loss += self.loss(pastiche * self.content_weight, content.detach() * self.content_weight)
                    if name in self.style_layers:
                        pastiche_g, style_g = self.gram.forward(pastiche), self.gram.forward(style)
                        style_loss += self.loss(pastiche_g * self.style_weight, style_g.detach() * self.style_weight)
                if isinstance(layer, nn.ReLU):
                    i += 1
            total_loss = content_loss + style_loss
            total_loss.backward()
            return total_loss

        self.optimizer.step(closure)
        return self.pastiche

loader = transforms.Compose([
    transforms.Scale(imsize),  # Resize in newer torchvision versions
    transforms.ToTensor()
])
unloader = transforms.ToPILImage()

def image_loader(image_name):
    image = Image.open(image_name)
    image = Variable(loader(image))
    image = image.unsqueeze(0)
    return image

def save_image(inputdata, path):
    image = inputdata.data.clone()
    image = image.view(3, imsize, imsize)
    image = unloader(image)
    scipy.misc.imsave(path, image)

def main():
    # CUDA configuration
    dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

    # Content and style
    style = image_loader("style.jpg").type(dtype)
    content = image_loader("content.jpg").type(dtype)
    # start the pastiche from random noise with the same shape as the content
    pastiche = image_loader("content.jpg").type(dtype)
    pastiche.data = torch.randn(pastiche.data.size()).type(dtype)

    num_epochs = 31
    style_cnn = StyleCNN(style, content, pastiche)
    for i in range(num_epochs):
        pastiche = style_cnn.train()
        if i % 10 == 0:
            print("Iteration: %d" % (i))
            path = "/Users/swa/Downloads/%d.png" % (i)
            pastiche.data.clamp_(0, 1)
            save_image(pastiche, path)

main()
Varying the balance between content and style over time yields an animation like the one below.
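The pull between the two loss terms can be seen even in a one-dimensional toy version of the total loss; a minimal sketch, where the scalars c and s stand in for the content and style targets:

```python
import torch

# toy illustration of the content/style tension: optimise x under
# L = alpha*|x - c|^2 + (1 - alpha)*|x - s|^2, whose minimum sits at
# x = alpha*c + (1 - alpha)*s
c, s = torch.tensor(0.0), torch.tensor(1.0)
results = []
for alpha in (0.1, 0.5, 0.9):
    x = torch.tensor(0.5, requires_grad=True)
    opt = torch.optim.LBFGS([x])

    def closure():
        opt.zero_grad()
        loss = alpha * (x - c) ** 2 + (1 - alpha) * (x - s) ** 2
        loss.backward()
        return loss

    opt.step(closure)
    results.append((alpha, x.item()))
    # more weight on the content term pulls x toward c, and vice versa
```

Sweeping $\alpha$ between 0 and 1 and saving the pastiche at each step is exactly how the animation is produced.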