annotate train_transformer.py @ 6:e94dc7945639 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
author bgruening
date Sun, 16 Oct 2022 11:52:10 +0000
parents
children
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
6
e94dc7945639 planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff changeset
1 import tensorflow as tf
e94dc7945639 planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff changeset
2 import transformer_network
e94dc7945639 planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff changeset
3 import utils
e94dc7945639 planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff changeset
4 from tensorflow.keras.layers import (Dense, Dropout, GlobalAveragePooling1D,
e94dc7945639 planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff changeset
5 Input)
e94dc7945639 planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff changeset
6 from tensorflow.keras.models import Model
e94dc7945639 planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff changeset
7
e94dc7945639 planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff changeset
8
e94dc7945639 planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff changeset
def create_model(vocab_size, config):
    """Build the transformer-encoder network used for training.

    Args:
        vocab_size: size of the token vocabulary; used both for the
            token/position embedding and as the width of the output layer.
        config: mapping that must provide the keys "embedding_dim",
            "feed_forward_dim", "maximum_path_length", "dropout" and
            "n_heads".

    Returns:
        A Keras ``Model`` taking an input of shape
        ``(config["maximum_path_length"],)`` and producing two outputs:
        a ``vocab_size``-wide sigmoid score vector and the second value
        returned by ``transformer_network.TransformerBlock`` (bound as the
        attention weights).
    """
    dim_embed = config["embedding_dim"]
    dim_ff = config["feed_forward_dim"]
    seq_len = config["maximum_path_length"]
    drop_rate = config["dropout"]

    sequence = Input(shape=(seq_len,))
    embedder = transformer_network.TokenAndPositionEmbedding(seq_len, vocab_size, dim_embed)
    hidden = embedder(sequence)

    # The transformer block returns a pair: encoded sequence and attention weights.
    encoder = transformer_network.TransformerBlock(dim_embed, config["n_heads"], dim_ff)
    hidden, attention = encoder(hidden)

    # Average over the sequence axis (assumes the embedding yields
    # (batch, seq, dim) -- TODO confirm), then a dropout-regularised ReLU head.
    hidden = GlobalAveragePooling1D()(hidden)
    hidden = Dropout(drop_rate)(hidden)
    hidden = Dense(dim_ff, activation="relu")(hidden)
    hidden = Dropout(drop_rate)(hidden)

    # Sigmoid rather than softmax: every vocabulary entry is scored independently.
    scores = Dense(vocab_size, activation="sigmoid")(hidden)
    return Model(inputs=sequence, outputs=[scores, attention])
e94dc7945639 planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff changeset
26
e94dc7945639 planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff changeset
27
e94dc7945639 planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
bgruening
parents:
diff changeset
def _train_step(model, optimizer, x_batch, y_batch):
    """Apply one gradient update to ``model`` on a batch and return the loss tensor."""
    with tf.GradientTape() as tape:
        # The model also returns attention weights; they take no part in the loss.
        prediction, _ = model(x_batch, training=True)
        # utils.compute_loss returns a pair; only the first element is optimised.
        loss, _ = utils.compute_loss(y_batch, prediction)
    # Gradients are taken after leaving the tape context, so the gradient
    # computation itself is not recorded on the non-persistent tape.
    variables = model.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss


def _is_due(step_number, interval):
    """Return True when ``step_number`` is a multiple of ``interval``.

    A non-positive ``interval`` disables the event instead of raising
    ``ZeroDivisionError``.
    """
    return interval > 0 and step_number % interval == 0


def create_enc_transformer(train_data, train_labels, test_data, test_labels, f_dict, r_dict, c_wts, c_tools, pub_conn, tr_t_freq, config):
    """Train the transformer model, validate it periodically, then save it.

    Args:
        train_data, train_labels: training inputs and labels. ``train_data``
            and ``train_labels`` must expose ``.shape``, which is printed in logs.
        test_data, test_labels: held-out inputs and labels passed to
            ``utils.validate_model``.
        f_dict: mapping whose length (+1) sets the model vocabulary size; also
            forwarded to ``utils.validate_model``.
        r_dict: forwarded to ``utils.validate_model`` and ``utils.save_model_file``.
        c_wts, c_tools, pub_conn: forwarded unchanged to ``utils.save_model_file``.
        tr_t_freq: forwarded to ``utils.get_low_freq_te_samples`` and
            ``utils.sample_balanced_tr_y``.
        config: mapping providing "learning_rate", "tr_logging_step",
            "te_logging_step", "n_train_iter", "te_batch_size",
            "tr_batch_size", "trained_model_path" plus the keys
            ``create_model`` reads. A logging step <= 0 disables that log.

    Returns:
        None. The trained model is written via ``utils.save_model_file`` to
        ``config["trained_model_path"]``.
    """
    print("Train transformer...")
    # One extra slot beyond f_dict (presumably index 0 is reserved, e.g. for
    # padding -- confirm against the data encoding).
    vocab_size = len(f_dict) + 1

    enc_optimizer = tf.keras.optimizers.Adam(learning_rate=config["learning_rate"])

    model = create_model(vocab_size, config)

    _, u_tr_y_labels_dict = utils.get_u_tr_labels(train_labels)
    _, u_te_y_labels_dict = utils.get_u_tr_labels(test_labels)

    trained_on_labels = [int(label) for label in u_tr_y_labels_dict]

    te_lowest_t_ids = utils.get_low_freq_te_samples(test_data, test_labels, tr_t_freq)
    tr_log_step = config["tr_logging_step"]
    te_log_step = config["te_logging_step"]
    n_train_steps = config["n_train_iter"]
    te_batch_size = config["te_batch_size"]
    tr_batch_size = config["tr_batch_size"]

    # Fed back into the sampler on every step so it can take the previous
    # selection into account.
    sel_tools = list()
    for batch in range(n_train_steps):
        step = batch + 1
        x_train, y_train, sel_tools = utils.sample_balanced_tr_y(train_data, train_labels, u_tr_y_labels_dict, tr_batch_size, tr_t_freq, sel_tools)
        tr_loss = _train_step(model, enc_optimizer, x_train, y_train)
        if _is_due(step, tr_log_step):
            print("Total train data size: ", train_data.shape, train_labels.shape)
            print("Batch train data size: ", x_train.shape, y_train.shape)
            print("At Step {}/{} training loss:".format(str(step), str(n_train_steps)))
            print(tr_loss.numpy())
        if _is_due(step, te_log_step):
            print("Predicting on test data...")
            utils.validate_model(test_data, test_labels, te_batch_size, model, f_dict, r_dict, u_te_y_labels_dict, trained_on_labels, te_lowest_t_ids)
    print("Saving model after training for {} steps".format(n_train_steps))
    utils.save_model_file(model, r_dict, c_wts, c_tools, pub_conn, config["trained_model_path"])