import numpy as np
import scipy.io
from matplotlib import pyplot as plt
from keras.utils import np_utils
from keras.models import Sequential, Input, Model
from keras.layers.core import Dense, Dropout, Activation, Reshape, Flatten, Lambda
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D
from keras.callbacks import EarlyStopping
# Load the training images from the .mat file (presumably the SVHN
# train_32x32 set). The 'X' array is unpacked below as
# (rows, cols, channels, samples); the channel axis is the one rgb2gray reduces.
mat = scipy.io.loadmat('Data/train_32x32.mat')['X']
b, h, d, n = mat.shape

# Pre-allocate one single-channel (rows, cols, 1) slot per sample to receive
# the greyscale versions of the images.
img_gray = np.zeros(shape=(n, b, h, 1))
def rgb2gray(rgb, weights=(0.299, 0.587, 0.114)):
    """Convert an RGB image (channels on the last axis) to greyscale.

    Computes a weighted sum of the first three channels of the last axis;
    any additional channels (e.g. alpha) are ignored.

    Args:
        rgb: NumPy array whose last axis holds at least three colour channels.
        weights: Coefficients applied to the R, G and B channels. The default
            is the ITU-R BT.601 luma weighting used by the original code.

    Returns:
        Array with the same shape as ``rgb`` minus its last axis.
    """
    # np.dot of an N-D array with a 1-D vector sums over the last axis.
    return np.dot(rgb[..., :3], weights)
# Convert every image to greyscale, writing straight into channel 0 of the
# pre-allocated buffer. The spatial shape comes from the data (b, h) rather
# than being hard-coded to 32x32.
for i in range(n):
    img_gray[i, :, :, 0] = rgb2gray(mat[:, :, :, i])

# Normalize pixel values to [0, 1] (assumes 8-bit source pixels in [0, 255]).
# Dividing in place avoids allocating a second full-size float64 array.
img_gray /= 255.
# Convolutional autoencoder.
# - Encoder: three conv + 2x2 max-pool stages compress the image.
# - Decoder: three conv + 2x2 upsample stages expand it back.
# - Output: a final single-filter sigmoid conv produces the reconstruction.
# The output spatial size equals the input size only when b and h are
# multiples of 8.
#
# `img_size` is the model's input tensor (one greyscale channel), not a size.
img_size = Input(shape=(b, h, 1))

x = img_size
for n_filters in (16, 8, 8):
    x = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(x)
    x = MaxPooling2D((2, 2), padding='same')(x)
encoded = x

for n_filters in (8, 8, 16):
    x = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(x)
    x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = Model(img_size, decoded)
autoencoder.compile(optimizer='rmsprop', loss='binary_crossentropy')

# Print the layer-by-layer architecture and parameter counts.
autoencoder.summary()
# Stop training once the validation loss has not improved for 5 consecutive
# epochs; the large epoch budget below is only an upper bound.
early_stopping = EarlyStopping(monitor='val_loss', patience=5)
# Keras expects a list of callbacks, not a single callback object.
callbacks = [early_stopping]
n_epochs = 1000
batch_size = 128

# Train the autoencoder to reproduce its own input.
# NOTE(review): Keras takes the validation_split samples from the end of the
# arrays before shuffling — confirm for the installed Keras version.
autoencoder.fit(
    img_gray, img_gray,
    epochs=n_epochs,
    batch_size=batch_size,
    shuffle=True,
    validation_split=0.2,
    callbacks=callbacks,
)
# Reconstruct all images, then show randomly chosen originals (top row) above
# their reconstructions (bottom row).
pred = autoencoder.predict(img_gray)

# Number of examples to display. Kept separate from `n`, which holds the
# dataset size.
n_show = min(5, len(img_gray))
# Draw distinct random sample indices. The same index is used for both rows
# so each original is paired with its own reconstruction.
sample_idx = np.random.choice(len(img_gray), size=n_show, replace=False)

plt.figure(figsize=(15, 5))
for col, idx in enumerate(sample_idx):
    plt.subplot(2, n_show, col + 1)
    plt.imshow(img_gray[idx].reshape(b, h), cmap='gray')
    plt.subplot(2, n_show, col + 1 + n_show)
    plt.imshow(pred[idx].reshape(b, h), cmap='gray')
plt.show()