Mercurial > repos > bgruening > sklearn_discriminant_classifier
annotate simple_model_fit.py @ 29:95d0f81e46e8 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
author | bgruening |
---|---|
date | Fri, 01 Nov 2019 17:14:31 -0400 |
parents | |
children | 18b39ada6f35 |
rev | line source |
---|---|
29
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
1 import argparse |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
2 import json |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
3 import pandas as pd |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
4 import pickle |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
5 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
6 from galaxy_ml.utils import load_model, read_columns |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
7 from sklearn.pipeline import Pipeline |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
8 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
9 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
10 def _get_X_y(params, infile1, infile2): |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
11 """ read from inputs and output X and y |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
12 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
13 Parameters |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
14 ---------- |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
15 params : dict |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
16 Tool inputs parameter |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
17 infile1 : str |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
18 File path to dataset containing features |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
19 infile2 : str |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
20 File path to dataset containing target values |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
21 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
22 """ |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
23 # store read dataframe object |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
24 loaded_df = {} |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
25 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
26 input_type = params['input_options']['selected_input'] |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
27 # tabular input |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
28 if input_type == 'tabular': |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
29 header = 'infer' if params['input_options']['header1'] else None |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
30 column_option = (params['input_options']['column_selector_options_1'] |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
31 ['selected_column_selector_option']) |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
32 if column_option in ['by_index_number', 'all_but_by_index_number', |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
33 'by_header_name', 'all_but_by_header_name']: |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
34 c = params['input_options']['column_selector_options_1']['col1'] |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
35 else: |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
36 c = None |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
37 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
38 df_key = infile1 + repr(header) |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
39 df = pd.read_csv(infile1, sep='\t', header=header, |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
40 parse_dates=True) |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
41 loaded_df[df_key] = df |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
42 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
43 X = read_columns(df, c=c, c_option=column_option).astype(float) |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
44 # sparse input |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
45 elif input_type == 'sparse': |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
46 X = mmread(open(infile1, 'r')) |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
47 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
48 # Get target y |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
49 header = 'infer' if params['input_options']['header2'] else None |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
50 column_option = (params['input_options']['column_selector_options_2'] |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
51 ['selected_column_selector_option2']) |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
52 if column_option in ['by_index_number', 'all_but_by_index_number', |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
53 'by_header_name', 'all_but_by_header_name']: |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
54 c = params['input_options']['column_selector_options_2']['col2'] |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
55 else: |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
56 c = None |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
57 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
58 df_key = infile2 + repr(header) |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
59 if df_key in loaded_df: |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
60 infile2 = loaded_df[df_key] |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
61 else: |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
62 infile2 = pd.read_csv(infile2, sep='\t', |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
63 header=header, parse_dates=True) |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
64 loaded_df[df_key] = infile2 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
65 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
66 y = read_columns( |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
67 infile2, |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
68 c=c, |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
69 c_option=column_option, |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
70 sep='\t', |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
71 header=header, |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
72 parse_dates=True) |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
73 if len(y.shape) == 2 and y.shape[1] == 1: |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
74 y = y.ravel() |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
75 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
76 return X, y |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
77 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
78 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
79 def main(inputs, infile_estimator, infile1, infile2, out_object, |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
80 out_weights=None): |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
81 """ main |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
82 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
83 Parameters |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
84 ---------- |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
85 inputs : str |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
86 File path to galaxy tool parameter |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
87 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
88 infile_estimator : str |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
89 File paths of input estimator |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
90 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
91 infile1 : str |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
92 File path to dataset containing features |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
93 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
94 infile2 : str |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
95 File path to dataset containing target labels |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
96 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
97 out_object : str |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
98 File path for output of fitted model or skeleton |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
99 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
100 out_weights : str |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
101 File path for output of weights |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
102 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
103 """ |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
104 with open(inputs, 'r') as param_handler: |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
105 params = json.load(param_handler) |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
106 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
107 # load model |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
108 with open(infile_estimator, 'rb') as est_handler: |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
109 estimator = load_model(est_handler) |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
110 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
111 X_train, y_train = _get_X_y(params, infile1, infile2) |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
112 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
113 estimator.fit(X_train, y_train) |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
114 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
115 main_est = estimator |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
116 if isinstance(main_est, Pipeline): |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
117 main_est = main_est.steps[-1][-1] |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
118 if hasattr(main_est, 'model_') \ |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
119 and hasattr(main_est, 'save_weights'): |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
120 if out_weights: |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
121 main_est.save_weights(out_weights) |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
122 del main_est.model_ |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
123 del main_est.fit_params |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
124 del main_est.model_class_ |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
125 del main_est.validation_data |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
126 if getattr(main_est, 'data_generator_', None): |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
127 del main_est.data_generator_ |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
128 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
129 with open(out_object, 'wb') as output_handler: |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
130 pickle.dump(estimator, output_handler, |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
131 pickle.HIGHEST_PROTOCOL) |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
132 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
133 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
134 if __name__ == '__main__': |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
135 aparser = argparse.ArgumentParser() |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
136 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
137 aparser.add_argument("-X", "--infile_estimator", dest="infile_estimator") |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
138 aparser.add_argument("-y", "--infile1", dest="infile1") |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
139 aparser.add_argument("-g", "--infile2", dest="infile2") |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
140 aparser.add_argument("-o", "--out_object", dest="out_object") |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
141 aparser.add_argument("-t", "--out_weights", dest="out_weights") |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
142 args = aparser.parse_args() |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
143 |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
144 main(args.inputs, args.infile_estimator, args.infile1, |
95d0f81e46e8
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
bgruening
parents:
diff
changeset
|
145 args.infile2, args.out_object, args.out_weights) |