import glob
    import os
    import random
    import re

    import numpy as np
    import scipy.misc
    import tensorflow as tf



    # Data can be downloaded at http://www.cvlibs.net/download.php?file=data_road.zip


    def get_tensors(sess, path, tags=('vgg16',)):
        """Load a SavedModel into ``sess`` and return the VGG16 tensors the decoder needs.

        Args:
            sess: TensorFlow session whose graph receives the loaded model.
            path: Directory containing the SavedModel export.
            tags: Tag set identifying the MetaGraphDef to load.
                NOTE(review): the default assumes the model was exported with the
                single tag 'vgg16' -- confirm against the model you load.

        Returns:
            Tuple ``(image_input, keep_prob, layer3_out, layer4_out, layer7_out)``
            of tensors looked up by name in ``sess.graph``.
        """
        # loader.load takes the tag set *before* the export directory; without
        # it the directory lands in the tags slot and export_dir is missing.
        tf.saved_model.loader.load(sess, list(tags), path)
        graph = sess.graph
        tensor_names = ('image_input:0', 'keep_prob:0', 'layer3_out:0',
                        'layer4_out:0', 'layer7_out:0')
        return tuple(graph.get_tensor_by_name(name) for name in tensor_names)


    def layers(vgg_layer3_out, vgg_layer4_out, vgg_layer7_out, n_classes, l2_scale=1e-3):
        """Build an FCN-8 style decoder over three VGG16 feature maps.

        Each feature map is projected to ``n_classes`` channels with a 1x1
        convolution. The layer-7 projection is upsampled x2 and added to the
        layer-4 projection. That sum is upsampled x2 and added to the layer-3
        projection, and the result is upsampled x8 to give the output scores.

        Args:
            vgg_layer3_out, vgg_layer4_out, vgg_layer7_out: Encoder feature maps.
                NOTE(review): the skip additions assume each map has twice the
                spatial size of the next deeper one -- confirm for your model.
            n_classes: Number of output channels (classes).
            l2_scale: Scale of the L2 kernel penalty attached to every new layer.
                The penalties are only registered in TensorFlow's
                regularization-loss collection; the caller must add them to the
                training objective for them to take effect.

        Returns:
            Un-normalised per-pixel class scores (logits) with ``n_classes`` channels.
        """
        # One regularizer function shared by all layers instead of six copies.
        regularizer = tf.contrib.layers.l2_regularizer(l2_scale)

        def project(features):
            # 1x1 convolution: reduce channel depth to n_classes, keep spatial size.
            return tf.layers.conv2d(features, n_classes, kernel_size=1, padding='same',
                                    kernel_regularizer=regularizer)

        def upsample(features, factor, kernel_size):
            # Transposed convolution with stride `factor` enlarges the map by `factor`.
            return tf.layers.conv2d_transpose(features, n_classes, kernel_size=kernel_size,
                                              strides=(factor, factor), padding='same',
                                              kernel_regularizer=regularizer)

        # Layers are created in a fixed order so TensorFlow's auto-generated
        # layer (and variable) names remain stable.
        encoder_layer1 = project(vgg_layer3_out)
        encoder_layer2 = project(vgg_layer4_out)
        encoder_layer3 = project(vgg_layer7_out)
        decoder_layer1 = upsample(encoder_layer3, 2, 4)
        decoder_layer2 = tf.add(decoder_layer1, encoder_layer2)
        decoder_layer3 = upsample(decoder_layer2, 2, 4)
        decoder_layer4 = tf.add(decoder_layer3, encoder_layer1)
        return upsample(decoder_layer4, 8, 16)


    def batch_generator(batch_size, data_folder=None, image_shape=None):
        """Yield shuffled ``(images, gt_images)`` batches from a road-segmentation folder.

        Args:
            batch_size: Maximum samples per batch; the final batch may be smaller.
            data_folder: Directory holding ``image_2/*.png`` inputs and
                ``gt_image_2/*_road_*.png`` labels. Defaults to the module-level
                ``data_folder``.
            image_shape: ``(height, width)`` every image and label is resized to.
                Defaults to the module-level ``image_shape``.

        Yields:
            ``(images, gt_images)`` numpy arrays. ``gt_images`` has shape
            (batch, height, width, 2) of booleans: channel 0 marks pixels whose
            label colour is exactly [255, 0, 0] (background), channel 1 marks
            every other pixel.

        Raises:
            KeyError: if an input image has no matching ``*_road_*`` label file,
                or a default is requested but the module-level name is undefined.
        """
        # Fall back to the module globals so existing `batch_generator(n)` calls keep working.
        module_globals = globals()
        if data_folder is None:
            data_folder = module_globals['data_folder']
        if image_shape is None:
            image_shape = module_globals['image_shape']

        image_paths = glob.glob(os.path.join(data_folder, 'image_2', '*.png'))
        # Label "<prefix>_road_<id>.png" belongs to image "<prefix>_<id>.png".
        label_paths = {
            re.sub(r'_(lane|road)_', '_', os.path.basename(label_path)): label_path
            for label_path in glob.glob(os.path.join(data_folder, 'gt_image_2', '*_road_*.png'))}
        background_color = np.array([255, 0, 0])

        random.shuffle(image_paths)
        for start in range(0, len(image_paths), batch_size):
            images = []
            gt_images = []
            for image_file in image_paths[start:start + batch_size]:
                gt_image_file = label_paths[os.path.basename(image_file)]

                image = scipy.misc.imresize(scipy.misc.imread(image_file), image_shape)
                # Nearest-neighbour keeps label colours exact. Bilinear blending would
                # turn background pixels on class boundaries into colours that fail
                # the exact background match below, mislabelling them as foreground.
                gt_image = scipy.misc.imresize(scipy.misc.imread(gt_image_file), image_shape,
                                               interp='nearest')

                gt_bg = np.all(gt_image == background_color, axis=2)
                gt_bg = gt_bg.reshape(*gt_bg.shape, 1)
                gt_image = np.concatenate((gt_bg, np.invert(gt_bg)), axis=2)

                images.append(image)
                gt_images.append(gt_image)

            yield np.array(images), np.array(gt_images)


    n_classes = 2
    image_shape = (160, 576)
    n_epochs = 23
    batch_size = 16
    path = '../Data/data_road/'
    # batch_generator reads `data_folder` (image_2/ + gt_image_2/) from module scope.
    # NOTE(review): assumes the KITTI layout data_road/{training,testing}/image_2 -- confirm.
    data_folder = os.path.join(path, 'training')
    test_folder = os.path.join(path, 'testing')
    # NOTE(review): assumed location of the pretrained VGG16 SavedModel -- confirm.
    vgg_path = '../Data/vgg'
    # Directory where the overlaid segmentation images are written.
    output_dir = 'runs'

    with tf.Session() as sess:
        # loader.load(sess, tags, export_dir).
        # NOTE(review): assumes the model was exported with the 'vgg16' tag -- confirm.
        tf.saved_model.loader.load(sess, ['vgg16'], vgg_path)
        vgg_image_input = sess.graph.get_tensor_by_name('image_input:0')
        vgg_keep_prob = sess.graph.get_tensor_by_name('keep_prob:0')
        vgg_layer3_out = sess.graph.get_tensor_by_name('layer3_out:0')
        vgg_layer4_out = sess.graph.get_tensor_by_name('layer4_out:0')
        vgg_layer7_out = sess.graph.get_tensor_by_name('layer7_out:0')

        # Inputs fed on every training step.
        correct_label = tf.placeholder(tf.float32, [None, None, None, n_classes], name='correct_label')
        learning_rate = tf.placeholder(tf.float32, name='learning_rate')

        # Snapshot the restored variables so only the newly created ones
        # (decoder and optimizer state) are initialised below.
        temp = set(tf.global_variables())
        out_layer = layers(vgg_layer3_out, vgg_layer4_out, vgg_layer7_out, n_classes)
        softmax = tf.nn.softmax(out_layer, name='softmax')
        logits = tf.reshape(out_layer, (-1, n_classes), name='logits')
        labels = tf.reshape(correct_label, (-1, n_classes))
        cross_entropy_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels))
        # kernel_regularizer only registers penalties; add them to the objective explicitly.
        total_loss = cross_entropy_loss + tf.losses.get_regularization_loss()
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)

        sess.run(tf.variables_initializer(set(tf.global_variables()) - temp))
        for i in range(n_epochs):
            epoch_loss = 0.0
            epoch_size = 0
            for batch_input, batch_label in batch_generator(batch_size):
                _, loss = sess.run([train_op, cross_entropy_loss], feed_dict={vgg_image_input: batch_input,
                                                                              correct_label: batch_label,
                                                                              vgg_keep_prob: 0.5,
                                                                              learning_rate: 1e-4})
                # Weight by batch size so a smaller final batch does not skew the mean.
                epoch_loss += loss * len(batch_input)
                epoch_size += len(batch_input)
            if epoch_size == 0:
                raise RuntimeError('No training images found under {}'.format(data_folder))
            print("Loss at epoch {}: {}".format(i, epoch_loss / epoch_size))

        # Inference must run while the session is still open; reuse the existing
        # softmax op instead of adding a new one to the graph per image.
        os.makedirs(output_dir, exist_ok=True)
        for image_file in glob.glob(os.path.join(test_folder, 'image_2', '*.png')):
            image = scipy.misc.imresize(scipy.misc.imread(image_file), image_shape)

            # softmax is (1, H, W, n_classes); channel 1 is the non-background class.
            road_prob = sess.run(softmax, {vgg_keep_prob: 1.0, vgg_image_input: [image]})[0, :, :, 1]
            segmentation = (road_prob > 0.5).reshape(image_shape[0], image_shape[1], 1)
            # Semi-transparent green overlay where the non-background class wins.
            mask = np.dot(segmentation, np.array([[0, 255, 0, 127]]))
            mask = scipy.misc.toimage(mask, mode="RGBA")
            street_im = scipy.misc.toimage(image)
            street_im.paste(mask, box=None, mask=mask)
            street_im.save(os.path.join(output_dir, os.path.basename(image_file)))