PK!keras_nalu/__init__.pyPK!{nyykeras_nalu/nalu.py"""Keras NALU module""" from keras import backend as K from keras import constraints from keras import initializers from keras import regularizers from keras.engine import InputSpec from keras.layers import Layer from keras.utils.generic_utils import get_custom_objects class NALU(Layer): """Keras NALU layer""" def __init__( self, units, G_constraint=None, G_initializer='glorot_uniform', G_regularizer=None, M_hat_constraint=None, M_hat_initializer='glorot_uniform', M_hat_regularizer=None, W_hat_constraint=None, W_hat_initializer='glorot_uniform', W_hat_regularizer=None, cell=None, e=1e-7, **kwargs, ): assert cell in ['a', 'm', None] super(NALU, self).__init__(**kwargs) self.cell = cell self.G = None self.G_constraint = constraints.get(G_constraint) self.G_initializer = initializers.get(G_initializer) self.G_regularizer = regularizers.get(G_regularizer) self.M_hat = None self.M_hat_constraint = constraints.get(M_hat_constraint) self.M_hat_initializer = initializers.get(M_hat_initializer) self.M_hat_regularizer = regularizers.get(M_hat_regularizer) self.W_hat = None self.W_hat_constraint = constraints.get(W_hat_constraint) self.W_hat_initializer = initializers.get(W_hat_initializer) self.W_hat_regularizer = regularizers.get(W_hat_regularizer) self.e = e self.supports_masking = True self.units = units def build(self, input_shape): input_dim = input_shape[-1] if (self.cell is None): self.G = self.add_weight( constraint=self.G_constraint, initializer=self.G_initializer, name='G', regularizer=self.G_regularizer, shape=(input_dim, self.units), ) self.M_hat = self.add_weight( constraint=self.M_hat_constraint, initializer=self.M_hat_initializer, name='M_hat', regularizer=self.M_hat_regularizer, shape=(input_dim, self.units), ) self.W_hat = self.add_weight( constraint=self.W_hat_constraint, initializer=self.W_hat_initializer, name='W_hat', regularizer=self.W_hat_regularizer, shape=(input_dim, self.units), ) self.built = True self.input_spec = InputSpec(axes={-1: input_dim}, min_ndim=2) def call(self, inputs, **kwargs): W = K.tanh(self.W_hat) * K.sigmoid(self.M_hat) a = K.dot(inputs, W) m = K.exp(K.dot(K.log(K.abs(inputs) + self.e), W)) if self.cell == 'a': y = a elif self.cell == 'm': y = m else: g = K.sigmoid(K.dot(inputs, self.G)) y = (g * a) + ((1 - g) * m) return y def compute_output_shape(self, input_shape): output_shape = list(input_shape) output_shape[-1] = self.units output_shape = tuple(output_shape) return output_shape def get_config(self): base_config = super(NALU, self).get_config() config = { 'G_constraint': constraints.serialize(self.G_constraint), 'G_initializer': initializers.serialize(self.G_initializer), 'G_regularizer': regularizers.serialize(self.G_regularizer), 'M_hat_constraint': constraints.serialize(self.M_hat_constraint), 'M_hat_initializer': initializers.serialize(self.M_hat_initializer), 'M_hat_regularizer': regularizers.serialize(self.M_hat_regularizer), 'W_hat_constraint': constraints.serialize(self.W_hat_constraint), 'W_hat_initializer': initializers.serialize(self.W_hat_initializer), 'W_hat_regularizer': regularizers.serialize(self.W_hat_regularizer), 'cell': self.cell, 'e': self.e, 'units': self.units, } return {**base_config, **config} get_custom_objects().update({'NALU': NALU}) PK!HڽTU keras_nalu-1.1.0.dist-info/WHEEL A н#Z;/"d&F[xzw@Zpy3Fv]\fi4WZ^EgM_-]#0(q7PK!H.iM #keras_nalu-1.1.0.dist-info/METADATAVmo6_qC>DlN0vARA˰lF:Kl)#);wD?swϽQ>E O4Vf{ɔg ˖i4ཇ@-HAީǚrNYUM:-AKK7ke A*efpM#,\*c1\]O(GcO+H6!(%:{^NYJnX $=1hm &N[F׵hJ8MRE=?-^HXח^7z,Č|܍OfpZYA݌~Ghs#jro]JP놱~Dg.Dj1c~t(L.hK%\tdIA֡I|cFbf+1Q^ ps+V2ee:}9yL&cBہ:.eb`Zht^ ,A@ [Uw$ qVnUN}OC߁ ==?_kWE;2H|T] ]ANV\moob Foi?19-0Rv7Q1-e{GA lA;}GY4 g&,n3 E7HLpτC94"mpZ{rCzS+M4*G ]xQoh]R[nZ<0O oA՚n?j3 +W#X ʵJвmKƩHi|x@ZSrGq ud[c&证4^t ӽV[!)%%D/n /O'hH 蚆˼C_JkHqC.KPK!HYn!keras_nalu-1.1.0.dist-info/RECORD}̻v0o CDeɡm,{:xrK-PW!`撲~URmCɾr00sVxܦ@X돧BT>k)ˡ\&)/ ($O'"/_d`i.l{LO=я "!|VUXx֧^(ДWo'e$c{uYYU 3qq2 \kcx.*9i@Zi;SdPK!keras_nalu/__init__.pyPK!{nyy4keras_nalu/nalu.pyPK!HڽTU keras_nalu-1.1.0.dist-info/WHEELPK!H.iM #okeras_nalu-1.1.0.dist-info/METADATAPK!HYn!keras_nalu-1.1.0.dist-info/RECORDPKr