comparison recommend.py @ 0:8918de535391 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/rna_commander/tools/rna_tools/rna_commender commit 2fc7f3c08f30e2d81dc4ad19759dfe7ba9b0a3a1
author rnateam
date Tue, 31 May 2016 05:41:03 -0400
parents
children a609d6dc8047
comparison
equal deleted inserted replaced
-1:000000000000 0:8918de535391
1 """Recommend RNAs."""
2 from __future__ import print_function
3
4 import cPickle
5 import sys
6 from itertools import izip
7
8 from utils import get_serendipity_val
9
10 __author__ = "Gianluca Corrado"
11 __copyright__ = "Copyright 2016, Gianluca Corrado"
12 __license__ = "MIT"
13 __maintainer__ = "Gianluca Corrado"
14 __email__ = "gianluca.corrado@unitn.it"
15 __status__ = "Production"
16
17
class Predictor():
    """Predict interactions."""

    def __init__(self, predict_dataset, trained_model, serendipity_dic=None,
                 output=None):
        """
        Constructor.

        Parameters
        ----------
        predict_dataset : data.PredictDataset
            Dataset containing the examples to predict. ``predict`` unpacks
            it as ``(p, p_names, r, r_names)``.

        trained_model : str
            File name of the pickled trained model. The loaded object must
            provide ``predict(p, r)``.

        serendipity_dic : str (default : None)
            File name of the pickled dictionary with serendipity values.
            If None, or if the file cannot be read or unpickled, no
            serendipity values are reported ("---" is written instead).

        output : str (default : None)
            Output file. If None then STDOUT.
        """
        self.predict_dataset = predict_dataset
        # NOTE(review): unpickling can execute arbitrary code; only load
        # model / serendipity files that come from a trusted source.
        # Pickles are binary data, so open in binary mode; ``with`` closes
        # the file even if unpickling raises.
        with open(trained_model, "rb") as f:
            self.model = cPickle.load(f)

        self.serendipity_dic = None
        if serendipity_dic is not None:
            try:
                with open(serendipity_dic, "rb") as f:
                    self.serendipity_dic = cPickle.load(f)
            except Exception:
                # Serendipity values are optional extra output: keep the
                # original best-effort behaviour and fall back to "---".
                # Unlike the former bare ``except:``, this no longer
                # swallows KeyboardInterrupt / SystemExit.
                self.serendipity_dic = None
        self.output = output

    def predict(self):
        """
        Predict interaction values and write them as a ranked table.

        A tab-separated table with the header
        ``RBP, target, y_hat, serendipity`` is written to ``self.output``
        (or STDOUT if it is None). There is one row per interaction, sorted
        by decreasing predicted value; ties keep their dataset order.
        """
        (p, p_names, r, r_names) = self.predict_dataset
        y_hat = self.model.predict(p, r)
        # Indices of the interactions sorted by decreasing y_hat (sorted() is
        # stable, also with reverse=True). Iterating over these indices reads
        # the parallel sequences in rank order without building reordered
        # copies, and also works when they are plain Python sequences.
        ordering = sorted(range(len(y_hat)),
                          key=lambda i: y_hat[i], reverse=True)

        if self.output is None:
            # Flush after every row so results stream out as they are
            # written, as the original STDOUT path did.
            self._write_table(sys.stdout, p_names, r_names, y_hat, ordering,
                              flush=True)
        else:
            with open(self.output, "w") as out:
                self._write_table(out, p_names, r_names, y_hat, ordering)

    def _write_table(self, out, p_names, r_names, y_hat, ordering,
                     flush=False):
        """Write the header and one row per index of ``ordering`` to ``out``.

        ``flush`` calls ``out.flush()`` after each row.
        """
        seren = self.serendipity_dic
        out.write("RBP\ttarget\ty_hat\tserendipity\n")
        for i in ordering:
            target = r_names[i]
            if seren is None:
                row = "%s\t%s\t%.3f\t---\n" % (p_names[i], target, y_hat[i])
            else:
                row = "%s\t%s\t%.3f\t%.2f\n" % (
                    p_names[i], target, y_hat[i],
                    get_serendipity_val(seren, target))
            out.write(row)
            if flush:
                out.flush()