Mercurial > repos > bgruening > sklearn_discriminant_classifier
comparison discriminant.xml @ 13:f46da2feb233 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5d71c93a3dd804b1469852240a86021ab9130364
author | bgruening |
---|---|
date | Mon, 09 Jul 2018 14:33:25 -0400 |
parents | e0067d9baffc |
children | fb232caca397 |
comparison
equal
deleted
inserted
replaced
12:cb14b6827f70 | 13:f46da2feb233 |
---|---|
20 import numpy as np | 20 import numpy as np |
21 import sklearn.discriminant_analysis | 21 import sklearn.discriminant_analysis |
22 import pandas | 22 import pandas |
23 import pickle | 23 import pickle |
24 | 24 |
25 @COLUMNS_FUNCTION@ | |
26 @GET_X_y_FUNCTION@ | |
27 | |
25 input_json_path = sys.argv[1] | 28 input_json_path = sys.argv[1] |
26 params = json.load(open(input_json_path, "r")) | 29 params = json.load(open(input_json_path, "r")) |
27 | 30 |
28 | 31 |
29 #if $selected_tasks.selected_task == "load": | 32 #if $selected_tasks.selected_task == "load": |
30 | 33 |
31 classifier_object = pickle.load(open("$infile_model", 'r')) | 34 classifier_object = pickle.load(open("$infile_model", 'r')) |
32 | 35 |
33 data = pandas.read_csv("$selected_tasks.infile_data", sep='\t', header=0, index_col=None, parse_dates=True, encoding=None, tupleize_cols=False ) | 36 header = 'infer' if params["selected_tasks"]["header"] else None |
37 data = pandas.read_csv("$selected_tasks.infile_data", sep='\t', header=header, index_col=None, parse_dates=True, encoding=None, tupleize_cols=False) | |
34 prediction = classifier_object.predict(data) | 38 prediction = classifier_object.predict(data) |
35 prediction_df = pandas.DataFrame(prediction) | 39 prediction_df = pandas.DataFrame(prediction) |
36 res = pandas.concat([data, prediction_df], axis=1) | 40 res = pandas.concat([data, prediction_df], axis=1) |
37 res.to_csv(path_or_buf = "$outfile_predict", sep="\t", index=False) | 41 res.to_csv(path_or_buf = "$outfile_predict", sep="\t", index=False) |
38 | 42 |
39 #else: | 43 #else: |
40 | 44 |
41 data_train = pandas.read_csv("$selected_tasks.infile_train", sep='\t', header=0, index_col=None, parse_dates=True, encoding=None, tupleize_cols=False ) | 45 X, y = get_X_y(params, "$selected_tasks.selected_algorithms.input_options.infile1" ,"$selected_tasks.selected_algorithms.input_options.infile2") |
42 | |
43 data = data_train.ix[:,0:len(data_train.columns)-1] | |
44 labels = np.array(data_train[data_train.columns[len(data_train.columns)-1]]) | |
45 | 46 |
46 options = params["selected_tasks"]["selected_algorithms"]["options"] | 47 options = params["selected_tasks"]["selected_algorithms"]["options"] |
47 selected_algorithm = params["selected_tasks"]["selected_algorithms"]["selected_algorithm"] | 48 selected_algorithm = params["selected_tasks"]["selected_algorithms"]["selected_algorithm"] |
48 | 49 |
49 my_class = getattr(sklearn.discriminant_analysis, selected_algorithm) | 50 my_class = getattr(sklearn.discriminant_analysis, selected_algorithm) |
50 classifier_object = my_class(**options) | 51 classifier_object = my_class(**options) |
51 classifier_object.fit(data,labels) | 52 classifier_object.fit(X, y) |
52 pickle.dump(classifier_object,open("$outfile_fit", 'w+'), pickle.HIGHEST_PROTOCOL) | 53 pickle.dump(classifier_object,open("$outfile_fit", 'w+'), pickle.HIGHEST_PROTOCOL) |
53 | 54 |
54 #end if | 55 #end if |
55 ]]> | 56 ]]> |
56 </configfile> | 57 </configfile> |
57 </configfiles> | 58 </configfiles> |
58 <inputs> | 59 <inputs> |
59 <expand macro="train_loadConditional" model="zip"> | 60 <expand macro="sl_Conditional" model="zip"> |
60 <param name="selected_algorithm" type="select" label="Classifier type"> | 61 <param name="selected_algorithm" type="select" label="Classifier type"> |
61 <option value="LinearDiscriminantAnalysis" selected="true">Linear Discriminant Classifier</option> | 62 <option value="LinearDiscriminantAnalysis" selected="true">Linear Discriminant Classifier</option> |
62 <option value="QuadraticDiscriminantAnalysis">Quadratic Discriminant Classifier</option> | 63 <option value="QuadraticDiscriminantAnalysis">Quadratic Discriminant Classifier</option> |
63 </param> | 64 </param> |
64 <when value="LinearDiscriminantAnalysis"> | 65 <when value="LinearDiscriminantAnalysis"> |
66 <expand macro="sl_mixed_input"/> | |
65 <section name="options" title="Advanced Options" expanded="False"> | 67 <section name="options" title="Advanced Options" expanded="False"> |
66 <param argument="solver" type="select" optional="true" label="Solver" help=""> | 68 <param argument="solver" type="select" optional="true" label="Solver" help=""> |
67 <option value="svd" selected="true">Singular Value Decomposition</option> | 69 <option value="svd" selected="true">Singular Value Decomposition</option> |
68 <option value="lsqr">Least Squares Solution</option> | 70 <option value="lsqr">Least Squares Solution</option> |
69 <option value="eigen">Eigenvalue Decomposition</option> | 71 <option value="eigen">Eigenvalue Decomposition</option> |
76 <param argument="store_covariance" type="boolean" optional="true" truevalue="booltrue" falsevalue="boolflase" checked="false" | 78 <param argument="store_covariance" type="boolean" optional="true" truevalue="booltrue" falsevalue="boolflase" checked="false" |
77 label="Store covariance" help="Compute class covariance matrix."/> | 79 label="Store covariance" help="Compute class covariance matrix."/> |
78 </section> | 80 </section> |
79 </when> | 81 </when> |
80 <when value="QuadraticDiscriminantAnalysis"> | 82 <when value="QuadraticDiscriminantAnalysis"> |
83 <expand macro="sl_mixed_input"/> | |
81 <section name="options" title="Advanced Options" expanded="False"> | 84 <section name="options" title="Advanced Options" expanded="False"> |
82 <!--expand macro="priors"/--> | 85 <!--expand macro="priors"/--> |
83 <param argument="reg_param" type="float" optional="true" value="0.0" label="Regularization coefficient" help="Covariance estimate regularizer."/> | 86 <param argument="reg_param" type="float" optional="true" value="0.0" label="Regularization coefficient" help="Covariance estimate regularizer."/> |
84 <expand macro="tol" default_value="0.00001" help_text="Rank estimation threshold used in SVD solver."/> | 87 <expand macro="tol" default_value="0.00001" help_text="Rank estimation threshold used in SVD solver."/> |
85 <param argument="store_covariances" type="boolean" optional="true" truevalue="booltrue" falsevalue="boolflase" checked="false" | 88 <param argument="store_covariances" type="boolean" optional="true" truevalue="booltrue" falsevalue="boolflase" checked="false" |
89 </expand> | 92 </expand> |
90 </inputs> | 93 </inputs> |
91 <expand macro="output"/> | 94 <expand macro="output"/> |
92 <tests> | 95 <tests> |
93 <test> | 96 <test> |
94 <param name="infile_train" value="train.tabular" ftype="tabular"/> | 97 <param name="infile1" value="train.tabular" ftype="tabular"/> |
98 <param name="infile2" value="train.tabular" ftype="tabular"/> | |
99 <param name="header1" value="True"/> | |
100 <param name="header2" value="True"/> | |
101 <param name="col1" value="1,2,3,4"/> | |
102 <param name="col2" value="5"/> | |
95 <param name="selected_task" value="train"/> | 103 <param name="selected_task" value="train"/> |
96 <param name="selected_algorithm" value="LinearDiscriminantAnalysis"/> | 104 <param name="selected_algorithm" value="LinearDiscriminantAnalysis"/> |
97 <param name="solver" value="svd" /> | 105 <param name="solver" value="svd" /> |
98 <param name="store_covariances" value="True"/> | 106 <param name="store_covariances" value="True"/> |
99 <output name="outfile_fit" file="lda_model01" compare="sim_size" delta="500"/> | 107 <output name="outfile_fit" file="lda_model01" compare="sim_size" delta="500"/> |
100 </test> | 108 </test> |
101 <test> | 109 <test> |
102 <param name="infile_train" value="train.tabular" ftype="tabular"/> | 110 <param name="infile1" value="train.tabular" ftype="tabular"/> |
111 <param name="infile2" value="train.tabular" ftype="tabular"/> | |
112 <param name="header1" value="True"/> | |
113 <param name="header2" value="True"/> | |
114 <param name="col1" value="1,2,3,4"/> | |
115 <param name="col2" value="5"/> | |
103 <param name="selected_task" value="train"/> | 116 <param name="selected_task" value="train"/> |
104 <param name="selected_algorithm" value="LinearDiscriminantAnalysis"/> | 117 <param name="selected_algorithm" value="LinearDiscriminantAnalysis"/> |
105 <param name="solver" value="lsqr"/> | 118 <param name="solver" value="lsqr"/> |
106 <output name="outfile_fit" file="lda_model02" compare="sim_size" delta="500"/> | 119 <output name="outfile_fit" file="lda_model02" compare="sim_size" delta="500"/> |
107 </test> | 120 </test> |
108 <test> | 121 <test> |
109 <param name="infile_train" value="train.tabular" ftype="tabular"/> | 122 <param name="infile1" value="train.tabular" ftype="tabular"/> |
123 <param name="infile2" value="train.tabular" ftype="tabular"/> | |
124 <param name="header1" value="True"/> | |
125 <param name="header2" value="True"/> | |
126 <param name="col1" value="1,2,3,4"/> | |
127 <param name="col2" value="5"/> | |
110 <param name="selected_task" value="train"/> | 128 <param name="selected_task" value="train"/> |
111 <param name="selected_algorithm" value="QuadraticAnalysis"/> | 129 <param name="selected_algorithm" value="QuadraticDiscriminantAnalysis"/> |
112 <output name="outfile_fit" file="qda_model01" compare="sim_size" delta="500"/> | 130 <output name="outfile_fit" file="qda_model01" compare="sim_size" delta="500"/> |
113 </test> | 131 </test> |
114 <test> | 132 <test> |
115 <param name="infile_model" value="lda_model01" ftype="zip"/> | 133 <param name="infile_model" value="lda_model01" ftype="zip"/> |
116 <param name="infile_data" value="test.tabular" ftype="tabular"/> | 134 <param name="infile_data" value="test.tabular" ftype="tabular"/> |
135 <param name="header" value="True"/> | |
117 <param name="selected_task" value="load"/> | 136 <param name="selected_task" value="load"/> |
118 <output name="outfile_predict" file="lda_prediction_result01.tabular"/> | 137 <output name="outfile_predict" file="lda_prediction_result01.tabular"/> |
119 </test> | 138 </test> |
120 <test> | 139 <test> |
121 <param name="infile_model" value="lda_model02" ftype="zip"/> | 140 <param name="infile_model" value="lda_model02" ftype="zip"/> |
122 <param name="infile_data" value="test.tabular" ftype="tabular"/> | 141 <param name="infile_data" value="test.tabular" ftype="tabular"/> |
142 <param name="header" value="True"/> | |
123 <param name="selected_task" value="load"/> | 143 <param name="selected_task" value="load"/> |
124 <output name="outfile_predict" file="lda_prediction_result02.tabular"/> | 144 <output name="outfile_predict" file="lda_prediction_result02.tabular"/> |
125 </test> | 145 </test> |
126 <test> | 146 <test> |
127 <param name="infile_model" value="qda_model01" ftype="zip"/> | 147 <param name="infile_model" value="qda_model01" ftype="zip"/> |
128 <param name="infile_data" value="test.tabular" ftype="tabular"/> | 148 <param name="infile_data" value="test.tabular" ftype="tabular"/> |
149 <param name="header" value="True"/> | |
129 <param name="selected_task" value="load"/> | 150 <param name="selected_task" value="load"/> |
130 <output name="outfile_predict" file="qda_prediction_result01.tabular"/> | 151 <output name="outfile_predict" file="qda_prediction_result01.tabular"/> |
131 </test> | 152 </test> |
132 </tests> | 153 </tests> |
133 <help><![CDATA[ | 154 <help><![CDATA[ |