PK!keras_nalu/__init__.pyPK! MMkeras_nalu/nalu.py"""Keras NALU module""" from keras import backend as K from keras import constraints from keras import initializers from keras import regularizers from keras.engine import InputSpec from keras.layers import Layer from keras.utils.generic_utils import get_custom_objects class NALU(Layer): """Keras NALU layer""" def __init__( self, units, G_constraint=None, G_initializer='glorot_uniform', G_regularizer=None, M_hat_constraint=None, M_hat_initializer='glorot_uniform', M_hat_regularizer=None, W_hat_constraint=None, W_hat_initializer='glorot_uniform', W_hat_regularizer=None, e=1e-7, **kwargs, ): super(NALU, self).__init__(**kwargs) self.G = None self.G_constraint = constraints.get(G_constraint) self.G_initializer = initializers.get(G_initializer) self.G_regularizer = regularizers.get(G_regularizer) self.M_hat = None self.M_hat_constraint = constraints.get(M_hat_constraint) self.M_hat_initializer = initializers.get(M_hat_initializer) self.M_hat_regularizer = regularizers.get(M_hat_regularizer) self.W_hat = None self.W_hat_constraint = constraints.get(W_hat_constraint) self.W_hat_initializer = initializers.get(W_hat_initializer) self.W_hat_regularizer = regularizers.get(W_hat_regularizer) self.e = e self.supports_masking = True self.units = units def build(self, input_shape): input_dim = input_shape[-1] self.G = self.add_weight( constraint=self.G_constraint, initializer=self.G_initializer, name='G', regularizer=self.G_regularizer, shape=(input_dim, self.units), ) self.M_hat = self.add_weight( constraint=self.M_hat_constraint, initializer=self.M_hat_initializer, name='M_hat', regularizer=self.M_hat_regularizer, shape=(input_dim, self.units), ) self.W_hat = self.add_weight( constraint=self.W_hat_constraint, initializer=self.W_hat_initializer, name='W_hat', regularizer=self.W_hat_regularizer, shape=(input_dim, self.units), ) self.built = True self.input_spec = InputSpec(axes={-1: input_dim}, min_ndim=2) def call(self, inputs, **kwargs): W = K.tanh(self.W_hat) * K.sigmoid(self.M_hat) a = K.dot(inputs, W) m = K.exp(K.dot(K.log(K.abs(inputs) + self.e), W)) g = K.sigmoid(K.dot(inputs, self.G)) y = (g * a) + ((1 - g) * m) return y def compute_output_shape(self, input_shape): output_shape = list(input_shape) output_shape[-1] = self.units output_shape = tuple(output_shape) return output_shape def get_config(self): base_config = super(NALU, self).get_config() config = { 'G_constraint': constraints.serialize(self.G_constraint), 'G_initializer': initializers.serialize(self.G_initializer), 'G_regularizer': regularizers.serialize(self.G_regularizer), 'M_hat_constraint': constraints.serialize(self.M_hat_constraint), 'M_hat_initializer': initializers.serialize(self.M_hat_initializer), 'M_hat_regularizer': regularizers.serialize(self.M_hat_regularizer), 'W_hat_constraint': constraints.serialize(self.W_hat_constraint), 'W_hat_initializer': initializers.serialize(self.W_hat_initializer), 'W_hat_regularizer': regularizers.serialize(self.W_hat_regularizer), 'e': self.e, 'units': self.units, } return {**base_config, **config} get_custom_objects().update({'NALU': NALU}) PK!HڽTU keras_nalu-1.0.3.dist-info/WHEEL A н#Z;/"d&F[xzw@Zpy3Fv]\fi4WZ^EgM_-]#0(q7PK!Hg҅E#keras_nalu-1.0.3.dist-info/METADATAU]o:}$Ү7 ([Wb[P?B[f Nll43v4ETP!ٙ93>9{ ^‹[N& 6''~+ܴ%ߣs=rֵ]N_QPCG**с%/u S#*,޸|6߶YY 8踕,q?>:'mBoVۜBHGϵj!UNO>SJMt {l9y7ȁɍ;0yN_}c2iʊMEWZq'o{ų3<8mSCi/b|e):ӋUNOh'wzyLtQ+=s6<$哜@keAWE#K yi G"z ٓ[ý1#:`Ä_SwkSNc4\kt"D`kdٻvnw4X~3|?Ҿy bsخ 6ϳ,# mpVx˖@10EVgOЄ)ʲ0 4 8pP N{4sLqT(\;x;AfNGl#}B(~> -t'mu|;u OƂбR>ㆣMovrg"znMVjceƨp~PK!H4n!keras_nalu-1.0.3.dist-info/RECORD}̹r@oYP@-RXp4; %1gR8y #ʹƘulĘ`((+[@A_ψp8۔N VITFͪxy:Q#QWP.*=e,T2 AA}8fȱN-x._:\9!:mK7Mrwbb93]:Wv3y){ѠU8&r @(CPK!keras_nalu/__init__.pyPK! MM4keras_nalu/nalu.pyPK!HڽTU keras_nalu-1.0.3.dist-info/WHEELPK!Hg҅E#Ckeras_nalu-1.0.3.dist-info/METADATAPK!H4n!keras_nalu-1.0.3.dist-info/RECORDPKr