changeset 0:8918de535391 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/rna_commander/tools/rna_tools/rna_commender commit 2fc7f3c08f30e2d81dc4ad19759dfe7ba9b0a3a1
author rnateam
date Tue, 31 May 2016 05:41:03 -0400
parents
children 21130153e729
files data.py fasta_utils/__init__.py init.sh main.py model.py pfam_utils/__init__.py rbpfeatures.py recommend.py rnacommender.xml test-data/sample.fa utils/__init__.py
diffstat 11 files changed, 881 insertions(+), 0 deletions(-) [+]
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/data.py	Tue May 31 05:41:03 2016 -0400
@@ -0,0 +1,90 @@
+"""Dataset handler."""
+
+import numpy as np
+
+import pandas as pd
+
+from theano import config
+
+__author__ = "Gianluca Corrado"
+__copyright__ = "Copyright 2016, Gianluca Corrado"
+__license__ = "MIT"
+__maintainer__ = "Gianluca Corrado"
+__email__ = "gianluca.corrado@unitn.it"
+__status__ = "Production"
+
+
+class Dataset(object):
+    """General dataset."""
+
+    def __init__(self, fp, fr, standardize_proteins=False,
+                 standardize_rnas=False):
+        """
+        Constructor.
+
+        Parameters
+        ----------
+        fp : pandas.DataFrame
+            Protein features (one column per protein).
+
+        fr : str
+            The name of the HDF5 file containing features for the RNAs.
+        """
+        self.Fp = fp.astype(config.floatX)
+
+        store = pd.io.pytables.HDFStore(fr)
+        self.Fr = store.features.astype(config.floatX)
+        store.close()
+
+    def load(self):
+        """Load dataset in memory."""
+        raise NotImplementedError()
+
+
+class PredictDataset(Dataset):
+    """Test dataset."""
+
+    def __init__(self, fp, fr):
+        """
+        Constructor.
+
+        Parameters
+        ----------
+        fp : pandas.DataFrame
+            Protein features, as produced by rbpfeatures.RBPVectorizer.
+
+        fr : str
+            The name of the HDF5 file containing features for the RNAs.
+        """
+        super(PredictDataset, self).__init__(fp, fr)
+
+    def load(self):
+        """
+        Load dataset in memory.
+
+        Returns
+        -------
+        Examples to predict. For each example:
+            - p contains the protein features,
+            - r contains the RNA features,
+            - p_names contains the name of the protein,
+            - r_names contains the name of the RNA.
+
+        """
+        protein_input_dim = self.Fp.shape[0]
+        rna_input_dim = self.Fr.shape[0]
+        num_examples = self.Fp.shape[1] * self.Fr.shape[1]
+        p = np.zeros((num_examples, protein_input_dim)).astype(config.floatX)
+        p_names = []
+        r = np.zeros((num_examples, rna_input_dim)).astype(config.floatX)
+        r_names = []
+        index = 0
+        for protein in self.Fp.columns:
+            for rna in self.Fr.columns:
+                p[index] = self.Fp[protein]
+                p_names.append(protein)
+                r[index] = self.Fr[rna]
+                r_names.append(rna)
+                index += 1
+
+        return (p, np.array(p_names), r, np.array(r_names))
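+
+# Usage sketch (names are illustrative: assumes an HDF5 store whose
+# `features` DataFrame has one column per RNA, plus a protein feature
+# DataFrame as computed by rbpfeatures.RBPVectorizer):
+#
+#   >>> D = PredictDataset(fp=protein_df, fr="rna_features.h5")
+#   >>> p, p_names, r, r_names = D.load()
+#   >>> # row i of p and r holds the i-th candidate (protein, RNA) pair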
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/fasta_utils/__init__.py	Tue May 31 05:41:03 2016 -0400
@@ -0,0 +1,58 @@
+"""Util functions for FASTA format."""
+
+__author__ = "Gianluca Corrado"
+__copyright__ = "Copyright 2016, Gianluca Corrado"
+__license__ = "MIT"
+__maintainer__ = "Gianluca Corrado"
+__email__ = "gianluca.corrado@unitn.it"
+__status__ = "Production"
+
+
+def import_fasta(fasta_file):
+    """Import a fasta file as a dictionary."""
+    dic = {}
+    f = open(fasta_file)
+    fasta = f.read().strip()
+    f.close()
+    for a in fasta.split('>'):
+        k = a.split('\n')[0]
+        v = ''.join(a.split('\n')[1:])
+        if k != '':
+            dic[k] = v
+    return dic
+
+
+def export_fasta(dic):
+    """Export a dictionary."""
+    fasta = ""
+    for (k, v) in dic.iteritems():
+        fasta += ">%s\n%s\n" % (k, v)
+    return fasta
+
+
+def seq_names(fasta_file):
+    """Get sequence names from fasta file."""
+    names = []
+    f = open(fasta_file)
+    fasta = f.read()
+    f.close()
+    for a in fasta.split('>'):
+        names.append(a.split('\n')[0])
+    return [a for a in names if a != '']
+
+
+def stockholm2fasta(stockholm):
+    """Convert alignment in stockholm format to fasta format."""
+    fasta = ""
+    for line in stockholm.split("\n"):
+        # blank or comment line
+        if line == "" or line.startswith("#"):
+            continue
+        # termination line
+        elif line == "//":
+            return fasta
+        # alignment line
+        else:
+            name, align = line.split()
+            seq = align.replace(".", "")
+            fasta += ">%s\n%s\n" % (name, seq)
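+
+# Example for stockholm2fasta (a minimal, made-up two-sequence alignment):
+#
+#   >>> stockholm2fasta("# STOCKHOLM 1.0\nseq1 ACG.U\nseq2 AC..U\n//")
+#   '>seq1\nACGU\n>seq2\nACU\n'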
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/init.sh	Tue May 31 05:41:03 2016 -0400
@@ -0,0 +1,7 @@
+#!/usr/bin/env bash
+
+if [ ! -d "AURA_Human_data" ]; then
+  wget http://www.googledrive.com/host/0B9v5_ppcfmgWNTJzVjlkc0pCMVU
+  tar -xf 0B9v5_ppcfmgWNTJzVjlkc0pCMVU
+  rm 0B9v5_ppcfmgWNTJzVjlkc0pCMVU
+fi
\ No newline at end of file
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/main.py	Tue May 31 05:41:03 2016 -0400
@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+"""Recommendation."""
+
+import argparse
+import sys
+from rbpfeatures import RBPVectorizer
+from data import PredictDataset
+from recommend import Predictor
+
+from theano import config
+
+__author__ = "Gianluca Corrado"
+__copyright__ = "Copyright 2016, Gianluca Corrado"
+__license__ = "MIT"
+__maintainer__ = "Gianluca Corrado"
+__email__ = "gianluca.corrado@unitn.it"
+__status__ = "Production"
+
+config.floatX = 'float32'
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description=__doc__,
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('fasta', metavar='fasta', type=str,
+                        help="""Fasta file containing the RBP \
+                        sequences.""")
+
+    args = parser.parse_args()
+
+    v = RBPVectorizer(fasta=args.fasta)
+    rbp_fea = v.vectorize()
+
+    if rbp_fea is not None:
+        # Define and load dataset
+        D = PredictDataset(
+            fp=rbp_fea, fr="AURA_Human_data/RNA_features/HT_utrs.h5")
+        dataset = D.load()
+
+        model = "AURA_Human_data/model/trained_model.pkl"
+
+        # Define the Trainer and train the model
+        P = Predictor(predict_dataset=dataset,
+                      trained_model=model,
+                      serendipity_dic=model + '_',
+                      output="output.txt")
+        P.predict()
+    else:
+        sys.exit("""The queried proteins have no domain similarity with the proteins in the training dataset. They cannot be predicted.""")
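+
+# Example invocation (assumes AURA_Human_data has been fetched by init.sh;
+# the fasta file name is illustrative):
+#
+#   python main.py query_rbps.fa
+#
+# Ranked predictions are written to output.txt by the Predictor.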
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/model.py	Tue May 31 05:41:03 2016 -0400
@@ -0,0 +1,141 @@
+"""Recommender model."""
+from __future__ import print_function
+
+import sys
+
+import numpy as np
+
+from theano import config, function, shared
+import theano.tensor as T
+
+__author__ = "Gianluca Corrado"
+__copyright__ = "Copyright 2016, Gianluca Corrado"
+__license__ = "MIT"
+__maintainer__ = "Gianluca Corrado"
+__email__ = "gianluca.corrado@unitn.it"
+__status__ = "Production"
+
+
+class Model(object):
+    """Factorization model."""
+
+    def __init__(self, sp, sr, kp, kr, irange=0.01, learning_rate=0.01,
+                 lambda_reg=0.01, verbose=True, seed=1234):
+        """
+        Constructor.
+
+        Parameters
+        ----------
+        sp : int
+            Number of protein features.
+
+        sr : int
+            Number of RNA features.
+
+        kp : int
+            Size of the protein latent space.
+
+        kr : int
+            Size of the RNA latent space.
+
+        irange : float (default : 0.01)
+            Initialization range for the model weights.
+
+        learning_rate : float (default : 0.01)
+            Learning rate for the weights update.
+
+        lambda_reg : float (default : 0.01)
+            Lambda parameter for the regularization.
+
+        verbose : bool (default : True)
+            Print information to STDOUT.
+
+        seed : int (default : 1234)
+            Seed for random number generator.
+        """
+        if verbose:
+            print("Compiling model...", end=' ')
+            sys.stdout.flush()
+
+        self.learning_rate = learning_rate
+        self.lambda_reg = lambda_reg
+        np.random.seed(seed)
+        # explicit features for proteins
+        fp = T.matrix("Fp", dtype=config.floatX)
+        # explicit features for RNAs
+        fr = T.matrix("Fr", dtype=config.floatX)
+        # Correct label
+        y = T.vector("y")
+
+        # projection matrix for proteins
+        self.Ap = shared(((.5 - np.random.rand(kp, sp)) *
+                          irange).astype(config.floatX), name="Ap")
+        self.bp = shared(((.5 - np.random.rand(kp)) *
+                          irange).astype(config.floatX), name="bp")
+        # projection matrix for RNAs
+        self.Ar = shared(((.5 - np.random.rand(kr, sr)) *
+                          irange).astype(config.floatX), name="Ar")
+        self.br = shared(((.5 - np.random.rand(kr)) *
+                          irange).astype(config.floatX), name="br")
+        # generalization matrix
+        self.B = shared(((.5 - np.random.rand(kp, kr)) *
+                         irange).astype(config.floatX), name="B")
+
+        # Latent space for proteins
+        p = T.nnet.sigmoid(T.dot(fp, self.Ap.T) + self.bp)
+        # Latent space for RNAs
+        r = T.nnet.sigmoid(T.dot(fr, self.Ar.T) + self.br)
+        # Predicted output
+        y_hat = T.nnet.sigmoid(T.sum(T.dot(p, self.B) * r, axis=1))
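+        # i.e. y_hat[i] = sigmoid(p[i] . B . r[i]), a bilinear score for
+        # the interaction between protein i and RNA i in the batch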
+
+        def _regularization():
+            """Normalized Frobenius norm."""
+            norm_proteins = self.Ap.norm(2) + self.bp.norm(2)
+            norm_rnas = self.Ar.norm(2) + self.br.norm(2)
+            norm_b = self.B.norm(2)
+
+            num_proteins = self.Ap.flatten().shape[0] + self.bp.shape[0]
+            num_rnas = self.Ar.flatten().shape[0] + self.br.shape[0]
+            num_b = self.B.flatten().shape[0]
+
+            return (norm_proteins / num_proteins + norm_rnas / num_rnas +
+                    norm_b / num_b) / 3
+
+        # mean squared error
+        cost_ = (T.sqr(y - y_hat)).mean()
+        reg = lambda_reg * _regularization()
+        cost = cost_ + reg
+
+        # compute sgd updates
+        g_Ap, g_bp, g_Ar, g_br, g_B = T.grad(
+            cost, [self.Ap, self.bp, self.Ar, self.br, self.B])
+        updates = ((self.Ap, self.Ap - learning_rate * g_Ap),
+                   (self.bp, self.bp - learning_rate * g_bp),
+                   (self.Ar, self.Ar - learning_rate * g_Ar),
+                   (self.br, self.br - learning_rate * g_br),
+                   (self.B, self.B - learning_rate * g_B))
+
+        # training step
+        self.train = function(
+            inputs=[fp, fr, y],
+            outputs=[y_hat, cost],
+            updates=updates)
+        # test
+        self.test = function(
+            inputs=[fp, fr, y],
+            outputs=[y_hat, cost])
+
+        # predict
+        self.predict = function(
+            inputs=[fp, fr],
+            outputs=y_hat)
+
+        if verbose:
+            print("Done.")
+            sys.stdout.flush()
+
+    def get_params(self):
+        """Return the parameters of the model."""
+        return {'Ap': self.Ap.get_value(), 'bp': self.bp.get_value(),
+                'Ar': self.Ar.get_value(), 'br': self.br.get_value(),
+                'B': self.B.get_value()}
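+
+# Usage sketch (dimensions and batch arrays are illustrative; Fp is
+# (n, sp), Fr is (n, sr), y is (n,)):
+#
+#   >>> m = Model(sp=200, sr=100, kp=5, kr=5)
+#   >>> y_hat, cost = m.train(Fp, Fr, y)   # one SGD step
+#   >>> scores = m.predict(Fp, Fr)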
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pfam_utils/__init__.py	Tue May 31 05:41:03 2016 -0400
@@ -0,0 +1,139 @@
+"""Utils for PFAM."""
+
+import xml.etree.ElementTree as ET
+from math import ceil
+from time import sleep
+from xml.etree.ElementTree import ParseError
+
+import requests
+
+import pandas as pd
+
+import fasta_utils
+
+__author__ = "Gianluca Corrado"
+__copyright__ = "Copyright 2016, Gianluca Corrado"
+__license__ = "MIT"
+__maintainer__ = "Gianluca Corrado"
+__email__ = "gianluca.corrado@unitn.it"
+__status__ = "Production"
+
+
+def search_header():
+    """Return the header of a Pfam scan search."""
+    return "<seq id>        <alignment start>       <alignment end> \
+    <envelope start>        <envelope end>  <hmm acc>       <hmm name>\
+          <type>  <hmm start>     <hmm end>       <hmm length>    <bit score>\
+               <E-value>       <significance>  <clan>\n"
+
+
+def sequence_search(seq_id, seq):
+    """
+    Search a sequence against PFAM.
+
+    Input
+    -----
+    seq_id : str
+        Name of the protein sequence.
+    seq : str
+        Protein sequence.
+
+    Output
+    ------
+    ret : str
+        Formatted string containing the results of the Pfam scan for the
+        given sequence
+    """
+    def add_spaces(text, mul=8):
+        """Add spaces to a string."""
+        l = len(text)
+        next_mul = int(ceil(l / mul) + 1) * mul
+        offset = next_mul - l
+        if offset == 0:
+            offset = 8
+        return text + " " * offset
+
+    url = "http://pfam.xfam.org/search/sequence"
+    params = {'seq': seq,
+              'evalue': '1.0',
+              'output': 'xml'}
+    req = requests.get(url, params=params)
+    xml = req.text
+    try:
+        root = ET.fromstring(xml)
+    # sometimes Pfam returns the HTML code
+    except ParseError:
+        print "resending: %s" % seq_id
+        return "%s" % sequence_search(seq_id, seq)
+
+    result_url = root[0][1].text
+    # wait for Pfam to compute the results
+    sleep(4)
+    while True:
+        req2 = requests.get(result_url)
+        if req2.status_code == 200:
+            break
+        else:
+            sleep(1)
+    result_xml = req2.text
+    root = ET.fromstring(result_xml)
+    try:
+        matches = root[0][0][0][0][:]
+    # Sometimes raised when the sequence has no matches
+    except IndexError:
+        return ""
+    ret = ""
+    for match in matches:
+        for location in match:
+            ret += add_spaces(seq_id)
+            ret += add_spaces(location.attrib['ali_start'])
+            ret += add_spaces(location.attrib['ali_end'])
+            ret += add_spaces(location.attrib['start'])
+            ret += add_spaces(location.attrib['end'])
+            ret += add_spaces(match.attrib['accession'])
+            ret += add_spaces(match.attrib['id'])
+            ret += add_spaces(match.attrib['class'])
+            ret += add_spaces(location.attrib['hmm_start'])
+            ret += add_spaces(location.attrib['hmm_end'])
+            ret += add_spaces("None")
+            ret += add_spaces(location.attrib['bitscore'])
+            ret += add_spaces(location.attrib['evalue'])
+            ret += add_spaces(location.attrib['significant'])
+            ret += "None\n"
+    return ret
+
+
+def read_pfam_output(pfam_out_file):
+    """Read the output of PFAM scan."""
+    cols = ["seq_id", "alignment_start", "alignment_end", "envelope_start",
+            "envelope_end", "hmm_acc", "hmm_name", "type", "hmm_start",
+            "hmm_end", "hmm_length", "bit_score", "E-value", "significance",
+            "clan"]
+    try:
+        data = pd.read_table(pfam_out_file,
+                             sep="\s*", skip_blank_lines=True, skiprows=1,
+                             names=cols, engine='python')
+    except:
+        return None
+    return data
+
+
+def download_seed_seqs(acc):
+    """
+    Download seed sequences from PFAM.
+
+    Input
+    -----
+    acc : str
+        Accession number of a Pfam domain
+
+    Output
+    ------
+    fasta : str
+        Seed sequences in fasta format
+    """
+    url = "http://pfam.xfam.org/family/%s/alignment/seed" % acc
+    req = requests.get(url)
+    stockholm = req.text
+    fasta = fasta_utils.stockholm2fasta(stockholm)
+    return fasta
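+
+# Usage sketch (requires network access to pfam.xfam.org; the sequence
+# and accession below are just examples):
+#
+#   >>> hits = sequence_search("AGO1", seq)    # pfam_scan-formatted rows
+#   >>> seeds = download_seed_seqs("PF02170")  # seed alignment as fasta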
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/rbpfeatures.py	Tue May 31 05:41:03 2016 -0400
@@ -0,0 +1,217 @@
+"""Compute the RBP features."""
+
+import re
+import subprocess as sp
+import uuid
+from os import mkdir
+from os import listdir
+from os.path import isfile, join
+from os import devnull
+from shutil import rmtree
+
+import numpy as np
+
+import pandas as pd
+
+import fasta_utils
+import pfam_utils
+
+__author__ = "Gianluca Corrado"
+__copyright__ = "Copyright 2016, Gianluca Corrado"
+__license__ = "MIT"
+__maintainer__ = "Gianluca Corrado"
+__email__ = "gianluca.corrado@unitn.it"
+__status__ = "Production"
+
+
+class RBPVectorizer(object):
+    """Compute the RBP features."""
+
+    def __init__(self, fasta):
+        """
+        Constructor.
+
+        Parameters
+        ----------
+        fasta : str
+            Fasta file containing the RBP sequences to predict.
+        """
+        self.fasta = fasta
+
+        self._mod_fold = "AURA_Human_data/RBP_features/mod"
+        self._reference_fisher_scores = \
+            "AURA_Human_data/RBP_features/fisher_scores_ref"
+        self._train_rbps_file = \
+            "AURA_Human_data/RBP_features/rbps_in_train.txt"
+
+        self._temp_fold = "temp_" + str(uuid.uuid4())
+        self.pfam_scan = "%s/pfam_scan.txt" % self._temp_fold
+        self._dom_fold = "%s/domains" % self._temp_fold
+        self._seeds_fold = "%s/seeds" % self._temp_fold
+        self._fisher_fold = "%s/fisher_scores" % self._temp_fold
+
+    def _pfam_scan(self):
+        """Scan the sequences against the Pfam database."""
+        nf = open(self.pfam_scan, "w")
+        nf.write(pfam_utils.search_header())
+
+        fasta = fasta_utils.import_fasta(self.fasta)
+
+        for rbp in sorted(fasta.keys()):
+            seq = fasta[rbp]
+            text = pfam_utils.sequence_search(rbp, seq)
+            nf.write(text)
+
+        nf.close()
+
+    def _overlapping_domains(self):
+        """Compute the set of domains contributing to the similarity."""
+        reference_domains = set([dom.replace(".mod", "") for dom in
+                                 listdir(self._mod_fold) if
+                                 isfile(join(self._mod_fold, dom))])
+
+        data = pfam_utils.read_pfam_output(self.pfam_scan)
+        if data is None:
+            return []
+
+        prot_domains = set([a.split('.')[0] for a in data["hmm_acc"]])
+        dom_list = sorted(list(reference_domains & prot_domains))
+
+        return dom_list
+
+    def _prepare_domains(self, dom_list):
+        """Select domain subsequences from the entire protein sequences."""
+        def prepare_domains(fasta_dic, dom_list, pfam_scan, out_folder):
+            out_file_dic = {}
+            for acc in dom_list:
+                out_file_dic[acc] = open("%s/%s.fa" % (out_folder, acc), "w")
+
+            f = open(pfam_scan)
+            f.readline()
+            for line in f:
+                split = line.split()
+                rbp = split[0]
+                start = int(split[3])
+                stop = int(split[4])
+                acc = split[5].split('.')[0]
+                if acc in out_file_dic:
+                    out_file_dic[acc].write(
+                        ">%s:%i-%i\n%s\n" % (rbp, start, stop,
+                                             fasta_dic[rbp][start:stop]))
+            f.close()
+
+            for acc in dom_list:
+                out_file_dic[acc].close()
+
+        mkdir(self._dom_fold)
+        fasta = fasta_utils.import_fasta(self.fasta)
+        prepare_domains(fasta, dom_list, self.pfam_scan,
+                        self._dom_fold)
+
+    def _compute_fisher_scores(self, dom_list):
+        """Wrapper for SAM 3.5 get_fisher_scores."""
+        def get_fisher_scores(dom_list, mod_fold, dom_fold, fisher_fold):
+            _FNULL = open(devnull, 'w')
+            for acc in dom_list:
+                cmd = "get_fisher_scores run -i %s/%s.mod -db %s/%s.fa" % (
+                    mod_fold, acc, dom_fold, acc)
+                fisher = sp.check_output(
+                    cmd, shell=True, stderr=_FNULL)
+                nf = open("%s/%s.txt" % (fisher_fold, acc), "w")
+                nf.write(fisher)
+                nf.close()
+            _FNULL.close()
+
+        mkdir(self._fisher_fold)
+        get_fisher_scores(dom_list, self._mod_fold, self._dom_fold,
+                          self._fisher_fold)
+
+    def _ekm(self, dom_list):
+        """Compute the empirical kernel map from the Fisher scores."""
+        def process_seg(e):
+            """Process segment of a SAM 3.5 get_fisher_scores output file."""
+            seg = e.split()
+            c = seg[0].split(':')[0]
+            m = map(float, seg[3:])
+            return c, m
+
+        def read_sam_file(samfile):
+            """Read a SAM 3.5 get_fisher_scores output file."""
+            f = open(samfile)
+            data = f.read()
+            f.close()
+
+            columns = []
+            m = []
+            split = re.split(">A ", data)[1:]
+            for e in split:
+                c, m_ = process_seg(e)
+                columns.append(c)
+                m.append(m_)
+
+            m = np.matrix(m)
+            df = pd.DataFrame(data=m.T, columns=columns)
+            return df
+
+        def dom_features(fisher_fold, dom_list, names=None):
+            """Compute the features with respect to a domain type."""
+            dfs = []
+            for acc in dom_list:
+                df = read_sam_file("%s/%s.txt" % (fisher_fold, acc))
+                df = df.groupby(df.columns, axis=1).mean()
+                dfs.append(df)
+
+            con = pd.concat(dfs, ignore_index=True)
+
+            if names is not None:
+                add = sorted(list(set(names) - set(con.columns)))
+                fil = sorted(list(set(names) - set(add)))
+                con = con[fil]
+                for c in add:
+                    con[c] = np.zeros(len(con.index), dtype='float64')
+                con = con[names]
+
+            con = con.fillna(0.0)
+            return con
+
+        f = open(self._train_rbps_file)
+        train_rbps = f.read().strip().split('\n')
+        f.close()
+        ref = dom_features(self._reference_fisher_scores, dom_list,
+                           names=train_rbps)
+        ekm_ref = ref.T.dot(ref)
+        ekm_ref.index = ekm_ref.columns
+
+        sel = dom_features(self._fisher_fold, dom_list)
+
+        ekm_sel = sel.T.dot(sel)
+        ekm_sel.index = ekm_sel.columns
+
+        ekm = ref.T.dot(sel)
+
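+        # cosine-normalize the empirical kernel map: each entry becomes
+        # the cosine similarity between a training RBP and a query RBP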
+        for rs in ekm.columns:
+            for rr in ekm.index:
+                if ekm_ref[rr][rr] > 0 and ekm_sel[rs][rs] > 0:
+                    ekm[rs][rr] /= np.sqrt(ekm_ref[rr][rr] * ekm_sel[rs][rs])
+        return ekm
+
+    def vectorize(self):
+        """Produce the RBP features."""
+        # create a temporary folder
+        mkdir(self._temp_fold)
+        # scan the RBP sequences against Pfam
+        self._pfam_scan()
+        # determine the accession numbers of the pfam domains needed for
+        # computing the features
+        dom_list = self._overlapping_domains()
+        if len(dom_list) == 0:
+            rmtree(self._temp_fold)
+            return None
+        # prepare fasta file with the sequence of the domains
+        self._prepare_domains(dom_list)
+        # compute fisher scores using SAM 3.5
+        self._compute_fisher_scores(dom_list)
+        # compute the empirical kernel map
+        ekm = self._ekm(dom_list)
+        # remove the temporary folder
+        rmtree(self._temp_fold)
+        return ekm
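+
+# Usage sketch (the AURA_Human_data paths are the constructor defaults;
+# the fasta file name is illustrative):
+#
+#   >>> v = RBPVectorizer(fasta="query_rbps.fa")
+#   >>> ekm = v.vectorize()  # DataFrame (training RBPs x query RBPs),
+#   ...                      # or None if no Pfam domain is shared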
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/recommend.py	Tue May 31 05:41:03 2016 -0400
@@ -0,0 +1,89 @@
+"""Recommend RNAs."""
+from __future__ import print_function
+
+import cPickle
+import sys
+from itertools import izip
+
+from utils import get_serendipity_val
+
+__author__ = "Gianluca Corrado"
+__copyright__ = "Copyright 2016, Gianluca Corrado"
+__license__ = "MIT"
+__maintainer__ = "Gianluca Corrado"
+__email__ = "gianluca.corrado@unitn.it"
+__status__ = "Production"
+
+
+class Predictor(object):
+    """Predict interactions."""
+
+    def __init__(self, predict_dataset, trained_model, serendipity_dic=None,
+                 output=None):
+        """
+        Constructor.
+
+        Parameters
+        ----------
+        predict_dataset : data.PredictDataset
+            Dataset containing the examples to predict.
+
+        trained_model : str
+            File name of the trained model.
+
+        serendipity_dic : dict (default : None)
+            Dictionary with serendipity values.
+
+        output : str (default : None)
+            Output file. If None then STDOUT.
+        """
+        self.predict_dataset = predict_dataset
+        f = open(trained_model, "rb")
+        self.model = cPickle.load(f)
+        f.close()
+        try:
+            f = open(serendipity_dic, "rb")
+            self.serendipity_dic = cPickle.load(f)
+            f.close()
+        except (IOError, TypeError):
+            self.serendipity_dic = None
+        self.output = output
+
+    def predict(self):
+        """Predict interaction values."""
+        # predict the y_hat
+        (p, p_names, r, r_names) = self.predict_dataset
+        y_hat = self.model.predict(p, r)
+        # sort the interactions according to y_hat
+        ordering = sorted(range(len(y_hat)),
+                          key=lambda x: y_hat[x], reverse=True)
+        p_names = p_names[ordering]
+        r_names = r_names[ordering]
+        y_hat = y_hat[ordering]
+
+        # output to STDOUT
+        if self.output is None:
+            print("RBP\ttarget\ty_hat\tserendipity")
+            if self.serendipity_dic is None:
+                for (p_, r_, s_) in izip(p_names, r_names, y_hat):
+                    print("%s\t%s\t%.3f\t---" % (p_, r_, s_))
+                    sys.stdout.flush()
+            else:
+                for (p_, r_, s_) in izip(p_names, r_names, y_hat):
+                    print("%s\t%s\t%.3f\t%.2f" %
+                          (p_, r_, s_,
+                           get_serendipity_val(self.serendipity_dic, r_)))
+                    sys.stdout.flush()
+        # output to file
+        else:
+            nf = open(self.output, "w")
+            nf.write("RBP\ttarget\ty_hat\tserendipity\n")
+            if self.serendipity_dic is None:
+                for (p_, r_, s_) in izip(p_names, r_names, y_hat):
+                    nf.write("%s\t%s\t%.3f\t---\n" % (p_, r_, s_))
+            else:
+                for (p_, r_, s_) in izip(p_names, r_names, y_hat):
+                    nf.write("%s\t%s\t%.3f\t%.2f\n" %
+                             (p_, r_, s_,
+                              get_serendipity_val(self.serendipity_dic, r_)))
+            nf.close()
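+
+# Expected output format (tab-separated; names and values illustrative):
+#
+#   RBP     target  y_hat   serendipity
+#   AGO1    UTR_1   0.912   0.87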
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/rnacommender.xml	Tue May 31 05:41:03 2016 -0400
@@ -0,0 +1,47 @@
+<tool id="rbc_rnacommender" name="RNAcommender" version="0.1.0">
+    <description>files into a collection</description>
+    <requirements>
+        <requirement type="package" version="3.5">sam</requirement>
+        <requirement type="package" version="1.8.1">numpy</requirement>
+        <requirement type="package" version="0.17.1">pandas</requirement>
+        <requirement type="package" version="3.2.2">pytables</requirement>
+        <requirement type="package" version="0.7">theano</requirement>
+        <requirement type="package" version="2.10.0">requests</requirement>
+    </requirements>
+    <command detect_errors="aggressive">
+    <![CDATA[
+        sh $__tool_directory__/init.sh &&
+        python $__tool_directory__/main.py "$infile"
+    ]]></command>
+    <inputs>
+        <param name="infile" type="data" format="fasta" label="Fasta file to split"/>
+    </inputs>
+    <outputs>
+        <data format="tabular" from_work_dir="output.txt" name="outfile" />
+    </outputs>
+    <tests>
+        <test>
+            <param name="infile" value="sample.fa" />
+            <output name="outfile">
+                <assert_contents>
+                    <has_text_matching expression="RBP\ttarget\ty_hat"/>
+                </assert_contents>
+            </output>
+        </test>
+    </tests>
+    <help><![CDATA[
+        RNAcommender 0.1.0 (Galaxy version).
+        To get predictions for one or more RBPs, provide a fasta file with the protein sequences. The output contains a ranked list of targets for ALL the proteins in the input file.
+        The full RNAcommender package is available at https://github.com/gianlucacorrado/RNAcommender.
+    ]]></help>
+    <citations>
+        <citation type="bibtex">
+            @ARTICLE{corrado2016rnacommender,
+                Author = {Gianluca Corrado and Toma Tebaldi and Fabrizio Costa and Paolo Frasconi and Andrea Passerini},
+                keywords = {machine learning, bioinformatics, post-transcriptional regulation, gene expression},
+                title = {{RNAcommender: genome-wide recommendation of RNA-protein interactions.}},
+                url = {https://github.com/gianlucacorrado/RNAcommender}
+            }
+        </citation>
+    </citations>
+</tool>
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/sample.fa	Tue May 31 05:41:03 2016 -0400
@@ -0,0 +1,2 @@
+>AGO1
+MEAGPSGAAAGAYLPPLQQVFQAPRRPGIGTVGKPIKLLANYFEVDIPKIDVYHYEVDIKPDKCPRRVNREVVEYMVQHFKPQIFGDRKPVYDGKKNIYTVTALPIGNERVDFEVTIPGEGKDRIFKVSIKWLAIVSWRMLHEALVSGQIPVPLESVQALDVAMRHLASMRYTPVGRSFFSPPEGYYHPLGGGREVWFGFHQSVRPAMWKMMLNIDVSATAFYKAQPVIEFMCEVLDIRNIDEQPKPLTDSQRVRFTKEIKGLKVEVTHCGQMKRKYRVCNVTRRPASHQTFPLQLESGQTVECTVAQYFKQKYNLQLKYPHLPCLQVGQEQKHTYLPLEVCNIVAGQRCIKKLTDNQTSTMIKATARSAPDRQEEISRLMKNASYNLDPYIQEFGIKVKDDMTEVTGRVLPAPILQYGGRNRAIATPNQGVWDMRGKQFYNGIEIKVWAIACFAPQKQCREEVLKNFTDQLRKISKDAGMPIQGQPCFCKYAQGADSVEPMFRHLKNTYSGLQLIIVILPGKTPVYAEVKRVGDTLLGMATQCVQVKNVVKTSPQTLSNLCLKINVKLGGINNILVPHQRSAVFQQPVIFLGADVTHPPAGDGKKPSITAVVGSMDAHPSRYCATVRVQRPRQEIIEDLSYMVRELLIQFYKSTRFKPTRIIFYRDGVPEGQLPQILHYELLAIRDACIKLEKDYQPGITYIVVQKRHHTRLFCADKNERIGKSGNIPAGTTVDTNITHPFEFDFYLCSHAGIQGTSRPSHYYVLWDDNRFTADELQILTYQLCHTYVRCTRSVSIPAPAYYARLVAFRARYHLVDKEHDSGEGSHISGQSNGRDPQALAKAVQVHQDTLRTMYFA
\ No newline at end of file
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/utils/__init__.py	Tue May 31 05:41:03 2016 -0400
@@ -0,0 +1,42 @@
+"""Util functions."""
+
+import pandas as pd
+import cPickle
+
+__author__ = "Gianluca Corrado"
+__copyright__ = "Copyright 2016, Gianluca Corrado"
+__license__ = "MIT"
+__maintainer__ = "Gianluca Corrado"
+__email__ = "gianluca.corrado@unitn.it"
+__status__ = "Production"
+
+
+def feature_size(store_name):
+    """Number of features."""
+    store = pd.io.pytables.HDFStore(store_name)
+    a = store.features
+    store.close()
+    return a.shape[0]
+
+
+def save_serendipity_dic(y, filename):
+    """Save the dictionary with the serendipity values."""
+    store = pd.io.pytables.HDFStore(y)
+    mat = store.matrix
+    store.close()
+    n = len(mat.columns)
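+    # assuming a binary interaction matrix (RNAs x proteins), an RNA's
+    # serendipity is 1 minus the fraction of training proteins it binds:
+    # common targets score near 0, rarely-bound RNAs near 1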
+    ser = 1 - mat.sum(axis=1) / n
+
+    f = open(filename, "w")
+    cPickle.dump(ser.to_dict(), f, protocol=2)
+    f.close()
+
+
+def get_serendipity_val(dic, key):
+    """Return the serendipity of a RNA."""
+    # The key was in the training set
+    try:
+        return dic[key]
+    # The key wasn't in the training set, then the serendipity is 1
+    except KeyError:
+        return 1.