diff of transformer_network.py @ changeset 6:e94dc7945639 (draft, default, tip)
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
author:    bgruening
date:      Sun, 16 Oct 2022 11:52:10 +0000
parents:
children:
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/transformer_network.py	Sun Oct 16 11:52:10 2022 +0000
@@ -0,0 +1,39 @@
+import tensorflow as tf
+from tensorflow.keras.layers import (Dense, Dropout, Embedding, Layer,
+                                     LayerNormalization, MultiHeadAttention)
+from tensorflow.keras.models import Sequential
+
+
+class TransformerBlock(Layer):
+    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
+        super(TransformerBlock, self).__init__()
+        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, dropout=rate)
+        self.ffn = Sequential(
+            [Dense(ff_dim, activation="relu"), Dense(embed_dim)]
+        )
+        self.layernorm1 = LayerNormalization(epsilon=1e-6)
+        self.layernorm2 = LayerNormalization(epsilon=1e-6)
+        self.dropout1 = Dropout(rate)
+        self.dropout2 = Dropout(rate)
+
+    def call(self, inputs, training):
+        attn_output, attention_scores = self.att(inputs, inputs, inputs, return_attention_scores=True, training=training)
+        attn_output = self.dropout1(attn_output, training=training)
+        out1 = self.layernorm1(inputs + attn_output)
+        ffn_output = self.ffn(out1)
+        ffn_output = self.dropout2(ffn_output, training=training)
+        return self.layernorm2(out1 + ffn_output), attention_scores
+
+
+class TokenAndPositionEmbedding(Layer):
+    def __init__(self, maxlen, vocab_size, embed_dim):
+        super(TokenAndPositionEmbedding, self).__init__()
+        self.token_emb = Embedding(input_dim=vocab_size, output_dim=embed_dim, mask_zero=True)
+        self.pos_emb = Embedding(input_dim=maxlen, output_dim=embed_dim, mask_zero=True)
+
+    def call(self, x):
+        maxlen = tf.shape(x)[-1]
+        positions = tf.range(start=0, limit=maxlen, delta=1)
+        positions = self.pos_emb(positions)
+        x = self.token_emb(x)
+        return x + positions
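
For context, these two layers follow the transformer-block pattern from the stock Keras worked example ("Text classification with Transformer"), extended with dropout inside MultiHeadAttention, mask_zero=True on the embeddings, and attention scores returned to the caller. Below is a minimal sketch of how the layers can be exercised on a dummy batch; the hyperparameter values and the dummy tool-id sequences are illustrative assumptions, not taken from this repository.

import tensorflow as tf

from transformer_network import TokenAndPositionEmbedding, TransformerBlock

# Illustrative hyperparameters (assumed, not from this repository).
maxlen = 25       # padded length of each tool sequence
vocab_size = 500  # number of distinct tool ids; id 0 is reserved for padding
embed_dim = 64    # embedding size per position
num_heads = 4     # attention heads
ff_dim = 64       # hidden units of the feed-forward sublayer

# Two zero-padded dummy sequences of tool ids.
seqs = tf.constant([
    [5, 12, 7] + [0] * 22,
    [3, 9] + [0] * 23,
])

embedding = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
block = TransformerBlock(embed_dim, num_heads, ff_dim)

x = embedding(seqs)                     # (batch, maxlen, embed_dim) = (2, 25, 64)
out, scores = block(x, training=False)  # out: (2, 25, 64)

# scores comes from return_attention_scores=True and has shape
# (batch, num_heads, query_len, key_len) = (2, 4, 25, 25).
print(out.shape, scores.shape)

Returning the attention scores alongside the block output is the main departure from the stock Keras example; a caller can aggregate them to see which earlier tools in a sequence the model attends to when predicting the next one.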