view recommend.py @ 5:b3462a72ff76 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/rna_commander/tools/rna_tools/rna_commender commit 7ef62aa3d86abd4b911e35447646712a4628e7fe
author rnateam
date Fri, 29 Jul 2016 03:27:18 -0400
parents a609d6dc8047
children
line wrap: on
line source

"""Recommend RNAs."""
from __future__ import print_function

import cPickle
import sys
from itertools import izip

from utils import get_serendipity_val

__author__ = "Gianluca Corrado"
__copyright__ = "Copyright 2016, Gianluca Corrado"
__license__ = "MIT"
__maintainer__ = "Gianluca Corrado"
__email__ = "gianluca.corrado@unitn.it"
__status__ = "Production"


class Predictor():
    """Predict interactions."""

    def __init__(self, predict_dataset, trained_model, serendipity_dic=None,
                 output=None):
        """
        Constructor.

        Parameters
        ------
        predict_dataset : data.PredictDataset
            Dataset containing the examples to predict.

        trained_model : str
            File name of the trained model.

        serendipity_dic : dict (default : None)
            Dictionary with serendipy values.

        output : str (default : None)
            Output file. If None then STDOUT.
        """
        self.predict_dataset = predict_dataset
        f = open(trained_model)
        self.model = cPickle.load(f)
        f.close()
        try:
            f = open(serendipity_dic)
            self.serendipity_dic = cPickle.load(f)
            f.close()
        except:
            self.serendipity_dic = None
        self.output = output

    def predict(self):
        """Predict interaction values."""
        # predict the y_hat
        (p, p_names, r, r_names) = self.predict_dataset
        assert p.dtype == 'float32'
        assert r.dtype == 'float32'
        y_hat = self.model.predict(p, r)
        # sort the interactions according to y_hat
        ordering = sorted(range(len(y_hat)),
                          key=lambda x: y_hat[x], reverse=True)
        p_names = p_names[ordering]
        r_names = r_names[ordering]
        y_hat = y_hat[ordering]

        # output to STDOUT
        if self.output is None:
            print("RBP\ttarget\ty_hat\tserendipity")
            if self.serendipity_dic is None:
                for (p_, r_, s_) in izip(p_names, r_names, y_hat):
                    print("%s\t%s\t%.3f\t---" % (p_, r_, s_))
                    sys.stdout.flush()
            else:
                for (p_, r_, s_) in izip(p_names, r_names, y_hat):
                    print("%s\t%s\t%.3f\t%.2f" %
                          (p_, r_, s_,
                           get_serendipity_val(self.serendipity_dic, r_)))
                    sys.stdout.flush()
        # output to file
        else:
            nf = open(self.output, "w")
            nf.write("RBP\ttarget\ty_hat\tserendipity\n")
            if self.serendipity_dic is None:
                for (p_, r_, s_) in izip(p_names, r_names, y_hat):
                    nf.write("%s\t%s\t%.3f\t---\n" % (p_, r_, s_))
            else:
                for (p_, r_, s_) in izip(p_names, r_names, y_hat):
                    nf.write("%s\t%s\t%.3f\t%.2f\n" %
                             (p_, r_, s_,
                              get_serendipity_val(self.serendipity_dic, r_)))
            nf.close()