# Source: Mercurial > repos > bgruening > run_jupyter_job
# File: test-data/tf-script.py, annotated @ revision 0:f4619200cb0a (draft)
# Commit message: "planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/jupyter_job commit f945b1bff5008ba01da31c7de64e5326579394d6"
# Author: bgruening
# Date: Sat, 11 Dec 2021 17:56:38 +0000
# Parents: (none)
# Children: (none)
# Every line below was introduced in revision 0 (changeset f4619200cb0a).
"""Minimal TensorFlow 2 eager-mode training run used as test data.

Trains a tiny convolutional classifier on the first 128 MNIST training
images with a hand-written ``tf.GradientTape`` loop. The mean training loss
of each epoch is appended to ``tot_loss``.

Results are left as module-level names (``tot_loss``, ``mnist_model``, ...)
instead of being wrapped in a ``main()`` function.
NOTE(review): presumably the jupyter_job tool inspects the script's globals
after executing it; confirm before restructuring.
"""
import numpy as np
import tensorflow as tf

# Keep the first (training) split and discard the second (test) split, then
# take a small 128-example slice so the job stays fast.
(mnist_images, mnist_labels), _ = tf.keras.datasets.mnist.load_data()
mnist_images, mnist_labels = mnist_images[:128], mnist_labels[:128]

# Add a trailing channel axis and scale pixel values by 1/255 as float32.
# Pair each image with its label cast to int64.
dataset = tf.data.Dataset.from_tensor_slices(
    (
        tf.cast(mnist_images[..., tf.newaxis] / 255, tf.float32),
        tf.cast(mnist_labels, tf.int64),
    )
)
# The shuffle buffer (1000) is larger than the 128 examples, so the whole
# slice is shuffled. 128 / 32 gives exactly 4 batches per epoch.
dataset = dataset.shuffle(1000).batch(32)

tot_loss = []  # mean training loss of each epoch, in epoch order
epochs = 1

# Two 3x3 convolutions, global average pooling, then a 10-unit Dense layer.
# The Dense layer has no activation, so the model emits raw logits.
mnist_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, [3, 3], activation='relu'),
    tf.keras.layers.Conv2D(16, [3, 3], activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10),
])

optimizer = tf.keras.optimizers.Adam()
# from_logits=True matches the activation-less final layer above.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

for epoch in range(epochs):
    loss_history = []  # per-batch losses for the current epoch only
    for images, labels in dataset:
        # Record the forward pass so gradients w.r.t. the weights can be taken.
        with tf.GradientTape() as tape:
            logits = mnist_model(images, training=True)
            loss_value = loss_object(labels, logits)
        loss_history.append(loss_value.numpy().mean())
        grads = tape.gradient(loss_value, mnist_model.trainable_variables)
        optimizer.apply_gradients(zip(grads, mnist_model.trainable_variables))
    # One summary value per epoch: the mean of that epoch's batch losses.
    tot_loss.append(np.mean(loss_history))