PK!keras_nalu/__init__.pyPK!Fwxxkeras_nalu/nalu.py"""Keras NALU module""" from keras import backend as K from keras import constraints from keras import initializers from keras import regularizers from keras.engine import InputSpec from keras.layers import Layer from keras.utils.generic_utils import get_custom_objects class NALU(Layer): """Keras NALU layer""" def __init__( self, units, G_constraint=None, G_initializer='glorot_uniform', G_regularizer=None, M_hat_constraint=None, M_hat_initializer='glorot_uniform', M_hat_regularizer=None, W_hat_constraint=None, W_hat_initializer='glorot_uniform', W_hat_regularizer=None, cell=None, e=1e-28, **kwargs, ): assert cell in ['a', 'm', None] super(NALU, self).__init__(**kwargs) self.cell = cell self.G = None self.G_constraint = constraints.get(G_constraint) self.G_initializer = initializers.get(G_initializer) self.G_regularizer = regularizers.get(G_regularizer) self.M_hat = None self.M_hat_constraint = constraints.get(M_hat_constraint) self.M_hat_initializer = initializers.get(M_hat_initializer) self.M_hat_regularizer = regularizers.get(M_hat_regularizer) self.W_hat = None self.W_hat_constraint = constraints.get(W_hat_constraint) self.W_hat_initializer = initializers.get(W_hat_initializer) self.W_hat_regularizer = regularizers.get(W_hat_regularizer) self.e = e self.supports_masking = True self.units = units def build(self, input_shape): input_dim = input_shape[-1] if self.cell is None: self.G = self.add_weight( constraint=self.G_constraint, initializer=self.G_initializer, name='G', regularizer=self.G_regularizer, shape=(input_dim, self.units), ) self.M_hat = self.add_weight( constraint=self.M_hat_constraint, initializer=self.M_hat_initializer, name='M_hat', regularizer=self.M_hat_regularizer, shape=(input_dim, self.units), ) self.W_hat = self.add_weight( constraint=self.W_hat_constraint, initializer=self.W_hat_initializer, name='W_hat', regularizer=self.W_hat_regularizer, shape=(input_dim, self.units), ) self.built = True self.input_spec = InputSpec(axes={-1: input_dim}, min_ndim=2) def call(self, inputs, **kwargs): W = K.tanh(self.W_hat) * K.sigmoid(self.M_hat) a = K.dot(inputs, W) m = K.exp(K.dot(K.log(K.abs(inputs) + self.e), W)) if self.cell == 'a': y = a elif self.cell == 'm': y = m else: g = K.sigmoid(K.dot(inputs, self.G)) y = (g * a) + ((1 - g) * m) return y def compute_output_shape(self, input_shape): output_shape = list(input_shape) output_shape[-1] = self.units output_shape = tuple(output_shape) return output_shape def get_config(self): base_config = super(NALU, self).get_config() config = { 'G_constraint': constraints.serialize(self.G_constraint), 'G_initializer': initializers.serialize(self.G_initializer), 'G_regularizer': regularizers.serialize(self.G_regularizer), 'M_hat_constraint': constraints.serialize(self.M_hat_constraint), 'M_hat_initializer': initializers.serialize(self.M_hat_initializer), 'M_hat_regularizer': regularizers.serialize(self.M_hat_regularizer), 'W_hat_constraint': constraints.serialize(self.W_hat_constraint), 'W_hat_initializer': initializers.serialize(self.W_hat_initializer), 'W_hat_regularizer': regularizers.serialize(self.W_hat_regularizer), 'cell': self.cell, 'e': self.e, 'units': self.units, } return {**base_config, **config} get_custom_objects().update({'NALU': NALU}) PK!!keras_nalu/pretrained/__init__.pyPK!W?Ikeras_nalu/pretrained/model.py"""Keras NALU pretrained model""" from os import path from keras.models import load_model def get_model(): """Get the NALU pretrained model""" return load_model( path.join(path.dirname(__file__), 'model.h5') ) PK!?v v keras_nalu/pretrained/train.py"""Pretrain Keras NALU model on counting task""" from os import path from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TerminateOnNaN from keras.layers import Input from keras.models import Model from keras.optimizers import RMSprop import numpy as np from keras_nalu.nalu import NALU def generate_dataset(batch_size, number_width, multiplier): """Generate dataset for a task""" X = np.zeros((batch_size, 2 * number_width)) Y = np.zeros((batch_size)) for i in range(batch_size): a = multiplier * np.random.rand(number_width) b = multiplier * np.random.rand(number_width) X[i] = np.concatenate([a, b]) Y[i] = np.sum(a * b) return X, Y def train(): """Train Keras NALU model on counting task""" model_dir = path.dirname(__file__) number_width = 16 X_train, Y_train = generate_dataset( batch_size=2**18, multiplier=1, number_width=number_width, ) X_validation, Y_validation = generate_dataset( batch_size=2**9, multiplier=9999, number_width=number_width, ) X_test, Y_test = generate_dataset( batch_size=2**9, multiplier=9999, number_width=number_width, ) inputs = Input(shape=(2 * number_width,)) hidden = NALU(units=number_width, cell='m')(inputs) outputs = NALU(units=1, cell='a')(hidden) callbacks = [ TerminateOnNaN(), ReduceLROnPlateau( factor=0.1, min_lr=1e-16, patience=50, verbose=1, ), EarlyStopping( patience=200, restore_best_weights=True, verbose=1, ), ModelCheckpoint( filepath=path.join(model_dir, 'model.checkpoint.h5'), period=10, save_best_only=True, verbose=1, ), ] model = Model(inputs=inputs, outputs=outputs) model.summary() model.compile(loss='mae', optimizer=RMSprop(lr=0.01)) model.fit( batch_size=256, callbacks=callbacks, epochs=1000, validation_data=(X_validation, Y_validation), verbose=2, x=X_train, y=Y_train, ) model.evaluate( batch_size=256, verbose=1, x=X_test, y=Y_test, ) model.save(path.join(model_dir, 'model.h5')) if __name__ == '__main__': train() PK!HڽTU keras_nalu-1.2.0.dist-info/WHEEL A н#Z;/"d&F[xzw@Zpy3Fv]\fi4WZ^EgM_-]#0(q7PK!HH,+O #keras_nalu-1.2.0.dist-info/METADATAVn8}W"[M[UARAȥؠX،4R"hя!e'J۠M 3gf(w|$3^vXsٰ;8$)hv <D%VX;jN@ {*j^˥sNGResXkKK۵2ALZXTƠHXq!P|vAUzFl%(%*K/%V, N1hm &N[FWK8uPEOlr$۔W$ ^%igd 稕Tv =Bx UhV;u#/Z׌lF*Ot yЉNTIϫZ83OT EGf#>#g)uh*_[JR,vE~e!~W.+܊UL9b4~HttLb;p\[ǥ L ث9HݲQUJ@_GnuJ(ԧW+o%y8?=Hcu}VYtC $I'Hժ0Odť(:=,&a-= @[Z#eg]媩qL"75mp$LtYgsI; c7)t5K11[F-Q1QgL5nY Rxͺlwܡ۽lvQ Tf]2ݎ^HAmmd!\Ā>7Q9-eg %Yҝ3ߥ,Sw mv=}3t<5Dq3 Iq8w>ܴށd쐞4R5%nNy 7|"^ۏlsC4.)حHW5 -E H5gFX/EZPTl,ɆC.=R"EvkEmHS)̽yuȣZ_Cx%L7m)f (CAE"mpE5\HU{FJ}e,PK!HyrVbar!keras_nalu-1.2.0.dist-info/RECORD9s@>R@X( ( 7r3ɌXypsJ$U#s:>%&%yiuS(