Mercurial > repos > rnateam > rnacommender
comparison recommend.py @ 0:8918de535391 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/rna_commander/tools/rna_tools/rna_commender commit 2fc7f3c08f30e2d81dc4ad19759dfe7ba9b0a3a1
author | rnateam |
---|---|
date | Tue, 31 May 2016 05:41:03 -0400 |
parents | |
children | a609d6dc8047 |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:8918de535391 |
---|---|
1 """Recommend RNAs.""" | |
2 from __future__ import print_function | |
3 | |
4 import cPickle | |
5 import sys | |
6 from itertools import izip | |
7 | |
8 from utils import get_serendipity_val | |
9 | |
# Module metadata: authorship, license and maintenance status.
__author__ = "Gianluca Corrado"
__copyright__ = "Copyright 2016, Gianluca Corrado"
__license__ = "MIT"
__maintainer__ = "Gianluca Corrado"
__email__ = "gianluca.corrado@unitn.it"
__status__ = "Production"
16 | |
17 | |
class Predictor(object):
    """Predict RBP-RNA interactions and report them ranked by score."""

    def __init__(self, predict_dataset, trained_model, serendipity_dic=None,
                 output=None):
        """
        Constructor.

        Parameters
        ----------
        predict_dataset : data.PredictDataset
            Dataset containing the examples to predict. ``predict`` unpacks
            it as ``(p, p_names, r, r_names)``.

        trained_model : str
            File name of the pickled trained model. The loaded object must
            provide a ``predict(p, r)`` method.

        serendipity_dic : str (default : None)
            File name of a pickled dictionary with serendipity values. If
            None, or if the file cannot be loaded, no serendipity values are
            reported.

        output : str (default : None)
            Output file. If None then STDOUT.
        """
        self.predict_dataset = predict_dataset
        # Pickles are binary data, so they are opened in "rb" mode.
        # NOTE(review): unpickling can execute arbitrary code; only load
        # model / serendipity files that come from trusted sources.
        with open(trained_model, "rb") as f:
            self.model = cPickle.load(f)
        self.serendipity_dic = None
        if serendipity_dic is not None:
            try:
                with open(serendipity_dic, "rb") as f:
                    self.serendipity_dic = cPickle.load(f)
            except Exception:
                # Deliberate best effort: an unreadable serendipity file only
                # disables the serendipity column instead of aborting.
                self.serendipity_dic = None
        self.output = output

    def predict(self):
        """
        Predict interaction values and output them ranked by score.

        Every (RBP, target) pair of the dataset is scored with the trained
        model and written as one tab-separated line: RBP name, target name,
        predicted value ``y_hat`` (3 decimals) and the target's serendipity
        (2 decimals, or ``---`` when no serendipity dictionary was loaded).
        Lines are ordered from the highest to the lowest ``y_hat`` and go to
        ``self.output`` if it is set, otherwise to STDOUT.
        """
        (p, p_names, r, r_names) = self.predict_dataset
        y_hat = self.model.predict(p, r)
        # Indices of the interactions by decreasing y_hat. sorted() is stable
        # (also with reverse=True), so ties keep their dataset order.
        ordering = sorted(range(len(y_hat)),
                          key=lambda i: y_hat[i], reverse=True)

        if self.output is None:
            self._write_predictions(sys.stdout, p_names, r_names, y_hat,
                                    ordering)
            sys.stdout.flush()
        else:
            with open(self.output, "w") as nf:
                self._write_predictions(nf, p_names, r_names, y_hat,
                                        ordering)

    def _write_predictions(self, out, p_names, r_names, y_hat, ordering):
        """Write the header, then one line per interaction in ``ordering``."""
        out.write("RBP\ttarget\ty_hat\tserendipity\n")
        for i in ordering:
            target = r_names[i]
            if self.serendipity_dic is None:
                serendipity = "---"
            else:
                serendipity = "%.2f" % get_serendipity_val(
                    self.serendipity_dic, target)
            out.write("%s\t%s\t%.3f\t%s\n" %
                      (p_names[i], target, y_hat[i], serendipity))