diff recommend.py @ 0:8918de535391 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/rna_commander/tools/rna_tools/rna_commender commit 2fc7f3c08f30e2d81dc4ad19759dfe7ba9b0a3a1
author rnateam
date Tue, 31 May 2016 05:41:03 -0400
parents
children a609d6dc8047
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/recommend.py	Tue May 31 05:41:03 2016 -0400
@@ -0,0 +1,89 @@
+"""Recommend RNAs."""
+from __future__ import print_function
+
+import cPickle
+import sys
+from itertools import izip
+
+from utils import get_serendipity_val
+
+__author__ = "Gianluca Corrado"
+__copyright__ = "Copyright 2016, Gianluca Corrado"
+__license__ = "MIT"
+__maintainer__ = "Gianluca Corrado"
+__email__ = "gianluca.corrado@unitn.it"
+__status__ = "Production"
+
+
+class Predictor():
+    """Predict interactions."""
+
+    def __init__(self, predict_dataset, trained_model, serendipity_dic=None,
+                 output=None):
+        """
+        Constructor.
+
+        Parameters
+        ------
+        predict_dataset : data.PredictDataset
+            Dataset containing the examples to predict.
+
+        trained_model : str
+            File name of the trained model.
+
+        serendipity_dic : dict (default : None)
+            Dictionary with serendipy values.
+
+        output : str (default : None)
+            Output file. If None then STDOUT.
+        """
+        self.predict_dataset = predict_dataset
+        f = open(trained_model)
+        self.model = cPickle.load(f)
+        f.close()
+        try:
+            f = open(serendipity_dic)
+            self.serendipity_dic = cPickle.load(f)
+            f.close()
+        except:
+            self.serendipity_dic = None
+        self.output = output
+
+    def predict(self):
+        """Predict interaction values."""
+        # predict the y_hat
+        (p, p_names, r, r_names) = self.predict_dataset
+        y_hat = self.model.predict(p, r)
+        # sort the interactions according to y_hat
+        ordering = sorted(range(len(y_hat)),
+                          key=lambda x: y_hat[x], reverse=True)
+        p_names = p_names[ordering]
+        r_names = r_names[ordering]
+        y_hat = y_hat[ordering]
+
+        # output to STDOUT
+        if self.output is None:
+            print("RBP\ttarget\ty_hat\tserendipity")
+            if self.serendipity_dic is None:
+                for (p_, r_, s_) in izip(p_names, r_names, y_hat):
+                    print("%s\t%s\t%.3f\t---" % (p_, r_, s_))
+                    sys.stdout.flush()
+            else:
+                for (p_, r_, s_) in izip(p_names, r_names, y_hat):
+                    print("%s\t%s\t%.3f\t%.2f" %
+                          (p_, r_, s_,
+                           get_serendipity_val(self.serendipity_dic, r_)))
+                    sys.stdout.flush()
+        # output to file
+        else:
+            nf = open(self.output, "w")
+            nf.write("RBP\ttarget\ty_hat\tserendipity\n")
+            if self.serendipity_dic is None:
+                for (p_, r_, s_) in izip(p_names, r_names, y_hat):
+                    nf.write("%s\t%s\t%.3f\t---\n" % (p_, r_, s_))
+            else:
+                for (p_, r_, s_) in izip(p_names, r_names, y_hat):
+                    nf.write("%s\t%s\t%.3f\t%.2f\n" %
+                             (p_, r_, s_,
+                              get_serendipity_val(self.serendipity_dic, r_)))
+            nf.close()