import numpy as np
from keras.preprocessing import sequence
from keras.models import Sequential, Model
from keras.layers import Input, Dense, multiply, Dropout, Lambda, Activation, Flatten, Embedding, LSTM, TimeDistributed, RepeatVector, Permute
from keras.callbacks import EarlyStopping
from keras.datasets import imdb
from keras import backend as K
n_words = 1000
(X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=n_words)
print('Train seq: {}'.format(len(X_train)))
print('Test seq: {}'.format(len(X_test)))
Train seq: 25000
Test seq: 25000
print('Train example: \n{}'.format(X_train[0]))
print('\nTest example: \n{}'.format(X_test[0]))
# Note: the data is already tokenized (each word is mapped to an integer
# index); with num_words=1000, words outside the top 1,000 show up as the
# out-of-vocabulary index 2
Train example:
[1, 14, 22, 16, 43, 530, 973, 2, 2, 65, 458, 2, 66, 2, 4, 173, 36, 256, 5, 25, 100, 43, 838, 112, 50, 670, 2, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 2, 336, 385, 39, 4, 172, 2, 2, 17, 546, 38, 13, 447, 4, 192, 50, 16, 6, 147, 2, 19, 14, 22, 4, 2, 2, 469, 4, 22, 71, 87, 12, 16, 43, 530, 38, 76, 15, 13, 2, 4, 22, 17, 515, 17, 12, 16, 626, 18, 2, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2, 2, 16, 480, 66, 2, 33, 4, 130, 12, 16, 38, 619, 5, 25, 124, 51, 36, 135, 48, 25, 2, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 2, 8, 4, 107, 117, 2, 15, 256, 4, 2, 7, 2, 5, 723, 36, 71, 43, 530, 476, 26, 400, 317, 46, 7, 4, 2, 2, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2, 56, 26, 141, 6, 194, 2, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 2, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 2, 88, 12, 16, 283, 5, 16, 2, 113, 103, 32, 15, 16, 2, 19, 178, 32]
Test example:
[1, 89, 27, 2, 2, 17, 199, 132, 5, 2, 16, 2, 24, 8, 760, 4, 2, 7, 4, 22, 2, 2, 16, 2, 17, 2, 7, 2, 2, 9, 4, 2, 8, 14, 991, 13, 877, 38, 19, 27, 239, 13, 100, 235, 61, 483, 2, 4, 7, 4, 20, 131, 2, 72, 8, 14, 251, 27, 2, 7, 308, 16, 735, 2, 17, 29, 144, 28, 77, 2, 18, 12]
y_train
array([1, 0, 0, ..., 0, 1, 0])
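# Optional sketch: the indices can be mapped back to words with the built-in
# word index (imdb.get_word_index()); by default load_data offsets every
# index by 3, reserving 0 for padding, 1 for the start marker and 2 for
# out-of-vocabulary words
word_index = imdb.get_word_index()
index_to_word = {index + 3: word for word, index in word_index.items()}
index_to_word.update({0: '<pad>', 1: '<start>', 2: '<oov>'})
print(' '.join(index_to_word.get(i, '<oov>') for i in X_train[0][:25]))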
# Pad sequences with max_len
max_len = 200
X_train = sequence.pad_sequences(X_train, maxlen=max_len)
X_test = sequence.pad_sequences(X_test, maxlen=max_len)
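# Sanity check: reviews shorter than max_len are left-padded with zeros and
# longer ones are truncated from the front (the pad_sequences defaults)
print('X_train shape: {}'.format(X_train.shape))  # (25000, 200)
print('X_test shape: {}'.format(X_test.shape))    # (25000, 200)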
# Define network architecture and compile
_input = Input(shape=[max_len], dtype='int32')
# get the embedding layer
embedded = Embedding(
    input_dim=n_words,
    output_dim=50,
    input_length=max_len,
    trainable=True,
    # mask_zero=True,        # optionally mask the zero-padding timesteps
    # weights=[embeddings],  # optionally initialize with pretrained vectors
)(_input)
activations = LSTM(100, return_sequences=True)(embedded)
# Attention mechanism: score every timestep, normalize the scores over time
# with a softmax, and reweight the LSTM activations with them
attention = TimeDistributed(Dense(1, activation='tanh'))(activations)  # (batch, max_len, 1)
attention = Flatten()(attention)                                       # (batch, max_len)
attention = Activation('softmax')(attention)                           # weights sum to 1 over time
attention = RepeatVector(100)(attention)                               # (batch, 100, max_len)
attention = Permute([2, 1])(attention)                                 # (batch, max_len, 100)
# Element-wise product followed by a sum over time gives a weighted average
# of the LSTM activations: a single 100-dimensional sentence representation
sent_representation = multiply([activations, attention])
sent_representation = Lambda(lambda xin: K.sum(xin, axis=1))(sent_representation)
# Sigmoid output for binary classification; softmax on a single unit would
# always output 1, so the model could never learn
probabilities = Dense(1, activation='sigmoid')(sent_representation)
model = Model(inputs=_input, outputs=probabilities)
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to
====================================================================================================
input_14 (InputLayer)            (None, 200)           0
____________________________________________________________________________________________________
embedding_11 (Embedding)         (None, 200, 50)       50000       input_14[0][0]
____________________________________________________________________________________________________
lstm_11 (LSTM)                   (None, 200, 100)      60400       embedding_11[0][0]
____________________________________________________________________________________________________
time_distributed_11 (TimeDistrib (None, 200, 1)        101         lstm_11[0][0]
____________________________________________________________________________________________________
flatten_10 (Flatten)             (None, 200)           0           time_distributed_11[0][0]
____________________________________________________________________________________________________
activation_10 (Activation)       (None, 200)           0           flatten_10[0][0]
____________________________________________________________________________________________________
repeat_vector_9 (RepeatVector)   (None, 100, 200)      0           activation_10[0][0]
____________________________________________________________________________________________________
permute_8 (Permute)              (None, 200, 100)      0           repeat_vector_9[0][0]
____________________________________________________________________________________________________
multiply_1 (Multiply)            (None, 200, 100)      0           lstm_11[0][0]
                                                                   permute_8[0][0]
____________________________________________________________________________________________________
lambda_5 (Lambda)                (None, 100)           0           multiply_1[0][0]
____________________________________________________________________________________________________
dense_15 (Dense)                 (None, 1)             101         lambda_5[0][0]
====================================================================================================
Total params: 110,602
Trainable params: 110,602
Non-trainable params: 0
____________________________________________________________________________________________________
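# Optional sketch: a second Model that shares the (trained) layers exposes
# the softmax attention weights, assuming the Activation layer defined above
# is the only one in the graph
softmax_layer = next(layer for layer in model.layers if isinstance(layer, Activation))
attention_model = Model(inputs=model.input, outputs=softmax_layer.output)
# after training, attention_model.predict(batch) returns one weight per
# timestep, shape (batch_size, max_len), with each row summing to 1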
callbacks = [EarlyStopping(monitor='val_acc', patience=3)]
batch_size = 128
n_epochs = 100
model.fit(X_train, y_train, batch_size=batch_size, epochs=n_epochs, validation_split=0.2, callbacks=callbacks)
Train on 20000 samples, validate on 5000 samples
print('\nAccuracy on test set: {}'.format(model.evaluate(X_test, y_test)[1]))
# Accuracy on test set: 0.81326
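# Optional sketch: score a few padded test reviews with the trained model;
# predict returns the sigmoid probability that each review is positive
predictions = model.predict(X_test[:3])
print((predictions.ravel() > 0.5).astype(int), y_test[:3])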