    import numpy as np
    import os

    import cntk as C
    from cntk import Trainer, Axis
    from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT
    from cntk.learners import momentum_sgd, fsadagrad, momentum_as_time_constant_schedule, learning_rate_schedule, UnitType
    from cntk import input, cross_entropy_with_softmax, classification_error, sequence, element_select, alias, hardmax, placeholder, combine, parameter, times, plus, splice
    from cntk.ops.functions import CloneMethod, load_model, Function
    from cntk.initializer import glorot_uniform
    from cntk.logging import log_number_of_parameters, ProgressPrinter
    from cntk.logging.graph import plot
    from cntk.layers import *
    from cntk.layers.sequence import *
    from cntk.layers.models.attention import *
    from cntk.layers.typing import *


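    # Load the parallel corpus: one sentence per line, with line i of the French file being
    # the translation of line i of the English file.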
    data_eng = 'Data/translations/small_vocab_en'
    data_fr = 'Data/translations/small_vocab_fr'

    with open(data_eng, 'r', encoding='utf-8') as f:
        sentences_eng = f.read()

    with open(data_fr, 'r', encoding='utf-8') as f:
        sentences_fr = f.read()


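    # A quick look at the corpus: vocabulary size, sentence count and average sentence length.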
    word_counts = [len(sentence.split()) for sentence in sentences_eng.split('\n')]
    print('Number of unique words in English: {}'.format(len(set(sentences_eng.lower().split()))))
    print('Number of sentences: {}'.format(len(sentences_eng.split('\n'))))
    print('Average number of words in a sentence: {}'.format(np.average(word_counts)))

    n_examples = 5
    for i in range(n_examples):
        print('\nExample {}'.format(i))
        print(sentences_eng.split('\n')[i])
        print(sentences_fr.split('\n')[i])


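    # Word <-> id lookup tables; ids 0-3 are reserved for the special tokens
    # <S> (sentence start), <E> (sentence end), <UNK> (unknown word) and <PAD> (padding).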
    def create_lookup_tables(text):
        vocab = set(text.split())
        vocab_to_int = {'<S>': 0, '<E>': 1, '<UNK>': 2, '<PAD>': 3 }

        for i, v in enumerate(vocab, len(vocab_to_int)):
            vocab_to_int[v] = i

        int_to_vocab = {i: v for v, i in vocab_to_int.items()}

        return vocab_to_int, int_to_vocab

    vocab_to_int_eng, int_to_vocab_eng = create_lookup_tables(sentences_eng.lower())
    vocab_to_int_fr, int_to_vocab_fr = create_lookup_tables(sentences_fr.lower())


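    # Convert both corpora to lists of word-id sequences; an <E> end-of-sentence token is
    # appended to every target sentence.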
    def text_to_ids(source_text, target_text, source_vocab_to_int, target_vocab_to_int):
        source_id_text = [[source_vocab_to_int[word] for word in sentence.split()] for sentence in source_text.split('\n')]
        target_id_text = [[target_vocab_to_int[word] for word in sentence.split()]+[target_vocab_to_int['<E>']] for sentence in target_text.split('\n')]

        return source_id_text, target_id_text

    X, y = text_to_ids(sentences_eng.lower(), sentences_fr.lower(), vocab_to_int_eng, vocab_to_int_fr)


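    # Model and training hyperparameters.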
    input_vocab_dim = len(vocab_to_int_eng)   # size of the source (English) vocabulary, including the special tokens
    label_vocab_dim = len(vocab_to_int_fr)    # size of the target (French) vocabulary, including the special tokens
    hidden_dim = 256
    num_layers = 2
    attention_dim = 128
    attention_span = 12
    embedding_dim = 200
    n_epochs = 20
    learning_rate = 0.001
    batch_size = 64


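    # Sequence-to-sequence model with attention: a stacked LSTM encoder that also returns its
    # final cell state, and a stacked LSTM decoder in which every layer attends over the
    # encoder's hidden-state sequence. The returned function maps (history, input) to
    # unnormalized word scores over the target vocabulary.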
    def create_model(n_layers):
        embed = Embedding(embedding_dim, name='embed')

        LastRecurrence = C.layers.Recurrence
        encode = C.layers.Sequential([
            embed,
            C.layers.Stabilizer(),
            C.layers.For(range(n_layers-1), lambda:
                C.layers.Recurrence(C.layers.LSTM(hidden_dim))),
            LastRecurrence(C.layers.LSTM(hidden_dim), return_full_state=True),
            (C.layers.Label('encoded_h'), C.layers.Label('encoded_c')),
        ])

        with default_options(enable_self_stabilization=True):
            stab_in = Stabilizer()
            rec_blocks = [LSTM(hidden_dim) for i in range(n_layers)]
            stab_out = Stabilizer()
            out = Dense(label_vocab_dim, name='out')
            attention_model = AttentionModel(attention_dim, None, None, name='attention_model')

            @Function
            def decode(history, input):
                encoded_input = encode(input)
                r = history
                r = embed(r)
                r = stab_in(r)
                for i in range(n_layers):
                    rec_block = rec_blocks[i]
                    @Function
                    def lstm_with_attention(dh, dc, x):
                        h_att = attention_model(encoded_input.outputs[0], dh)
                        x = splice(x, h_att)
                        return rec_block(dh, dc, x)
                    r = Recurrence(lstm_with_attention)(r)
                r = stab_out(r)
                r = out(r)
                r = Label('out')(r)
                return r

        return decode


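    # The source and target sequences live on separate dynamic axes, so the criterion can
    # accept input and label sequences of different lengths.
    inputAxis = Axis('inputAxis')
    labelAxis = Axis('labelAxis')
    InputSequence = SequenceOver[inputAxis]
    LabelSequence = SequenceOver[labelAxis]

    # Criterion: cross-entropy with softmax over the target vocabulary, plus the
    # classification error for monitoring.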
    def create_loss_function(model):
        @Function
        @Signature(input = InputSequence[Tensor[input_vocab_dim]], labels = LabelSequence[Tensor[label_vocab_dim]])
        def loss(input, labels):
            postprocessed_labels = sequence.slice(labels, 1, 0)   # drop the leading <S>: <S> A B <E> --> A B <E>
            z = model(input, postprocessed_labels)
            ce = cross_entropy_with_softmax(z, postprocessed_labels)
            errs = classification_error(z, postprocessed_labels)
            return (ce, errs)
        return loss


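    # One-hot encoding of the sentence-start token <S> (index 0 in the target vocabulary);
    # the decoder history is seeded with this vector during training.
    sentence_start = C.Constant(np.array([w == vocab_to_int_fr['<S>'] for w in range(label_vocab_dim)],
                                         dtype=np.float32))
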
    def create_model_train(s2smodel):
        @Function
        def model_train(input, labels):
            past_labels = Delay(initial_state=sentence_start)(labels)
            return s2smodel(past_labels, input)
        return model_train


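    # Training loop: wrap the model for teacher forcing, build the criterion and the
    # fsadagrad learner (with gradient clipping), and feed CTF minibatches to the trainer.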
    def train(train_reader, valid_reader, vocab, i2w, s2smodel, max_epochs, epoch_size):
        model_train = create_model_train(s2smodel)
        loss = create_loss_function(model_train)
        learner = fsadagrad(model_train.parameters,
                            lr = learning_rate_schedule(learning_rate, UnitType.sample),
                            momentum = momentum_as_time_constant_schedule(1100),
                            gradient_clipping_threshold_per_sample=2.3,
                            gradient_clipping_with_truncation=True)
        trainer = Trainer(None, loss, learner)

        total_samples = 0

        for epoch in range(max_epochs):
            while total_samples < (epoch+1) * epoch_size:
                mb_train = train_reader.next_minibatch(batch_size)
                trainer.train_minibatch({loss.arguments[0]: mb_train[train_reader.streams.features],
                                         loss.arguments[1]: mb_train[train_reader.streams.labels]})
                total_samples += mb_train[train_reader.streams.labels].num_samples



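    # CTF reader: the source (S0) and target (S1) streams are sparse one-hot word sequences.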
    def create_reader(path, is_training):
        return MinibatchSource(CTFDeserializer(path, StreamDefs(
            features = StreamDef(field='S0', shape=input_vocab_dim, is_sparse=True),
            labels   = StreamDef(field='S1', shape=label_vocab_dim, is_sparse=True)
        )), randomize = is_training, max_sweeps = INFINITELY_REPEAT if is_training else 1)


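    # Put it all together: build the readers and the model, then run training.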
    # create_reader expects a CTF-formatted version of the parallel corpus; the paths below
    # are placeholders for wherever those files were written.
    train_reader = create_reader('Data/translations/train.ctf', True)
    valid_reader = create_reader('Data/translations/valid.ctf', False)

    model = create_model(num_layers)
    epoch_size = sum(len(sentence) for sentence in y)   # number of target tokens, so each epoch is roughly one sweep over the data
    train(train_reader, valid_reader, vocab_to_int_fr, int_to_vocab_fr, model,
          max_epochs=n_epochs, epoch_size=epoch_size)