Mercurial > repos > bgruening > sklearn_ensemble
comparison ensemble.xml @ 2:6e6726be0728 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tools/sklearn commit 641ac64ded23fbb6fe85d5f13926da12dcce4e76
author | bgruening |
---|---|
date | Tue, 13 Mar 2018 04:56:08 -0400 |
parents | 569eefee7ed8 |
children | 0431274c367d |
comparison
equal
deleted
inserted
replaced
1:883f2973d37d | 2:6e6726be0728 |
---|---|
23 from scipy.io import mmread | 23 from scipy.io import mmread |
24 | 24 |
25 input_json_path = sys.argv[1] | 25 input_json_path = sys.argv[1] |
26 params = json.load(open(input_json_path, "r")) | 26 params = json.load(open(input_json_path, "r")) |
27 | 27 |
| | 28 @COLUMNS_FUNCTION@ |
| | 29 |
28 #if $selected_tasks.selected_task == "train": | 30 #if $selected_tasks.selected_task == "train": |
29 | 31 |
30 algorithm = params["selected_tasks"]["selected_algorithms"]["selected_algorithm"] | 32 algorithm = params["selected_tasks"]["selected_algorithms"]["selected_algorithm"] |
31 options = params["selected_tasks"]["selected_algorithms"]["options"] | 33 options = params["selected_tasks"]["selected_algorithms"]["options"] |
32 input_type = params["selected_tasks"]["selected_algorithms"]["input_options"]["selected_input"] | 34 input_type = params["selected_tasks"]["selected_algorithms"]["input_options"]["selected_input"] |
33 if input_type=="tabular": | 35 if input_type=="tabular": |
34 col1 = params["selected_tasks"]["selected_algorithms"]["input_options"]["col1"] | 36 X = read_columns( |
35 col1 = list(map(lambda x: x - 1, col1)) | 37 "$selected_tasks.selected_algorithms.input_options.infile1", |
36 f1 = pandas.read_csv("$selected_tasks.selected_algorithms.input_options.infile1", sep='\t', header=None, index_col=None, parse_dates=True, encoding=None, tupleize_cols=False ) | 38 "$selected_tasks.selected_algorithms.input_options.col1", |
37 X = f1.iloc[:,col1].values | 39 sep='\t', |
| | 40 header=None, |
| | 41 parse_dates=True |
| | 42 ) |
38 else: | 43 else: |
39 X = mmread(open("$selected_tasks.selected_algorithms.input_options.infile1", 'r')) | 44 X = mmread(open("$selected_tasks.selected_algorithms.input_options.infile1", 'r')) |
40 | 45 |
41 col2 = params["selected_tasks"]["selected_algorithms"]["input_options"]["col2"] | 46 y = read_columns( |
42 col2 = list(map(lambda x: x - 1, col2)) | 47 "$selected_tasks.selected_algorithms.input_options.infile2", |
43 f2 = pandas.read_csv("$selected_tasks.selected_algorithms.input_options.infile2", sep='\t', header=None, index_col=None, parse_dates=True, encoding=None, tupleize_cols=False ) | 48 "$selected_tasks.selected_algorithms.input_options.col2", |
44 y = f2.iloc[:,col2].values | 49 sep='\t', |
| | 50 header=None, |
| | 51 parse_dates=True |
| | 52 ) |
45 | 53 |
46 my_class = getattr(sklearn.ensemble, algorithm) | 54 my_class = getattr(sklearn.ensemble, algorithm) |
47 estimator = my_class(**options) | 55 estimator = my_class(**options) |
48 estimator.fit(X,y) | 56 estimator.fit(X,y) |
49 pickle.dump(estimator,open("$outfile_fit", 'w+'), pickle.HIGHEST_PROTOCOL) | 57 pickle.dump(estimator,open("$outfile_fit", 'w+'), pickle.HIGHEST_PROTOCOL) |
50 | 58 |
51 #else: | 59 #else: |
52 classifier_object = pickle.load(open("$selected_tasks.infile_model", 'r')) | 60 classifier_object = pickle.load(open("$selected_tasks.infile_model", 'r')) |
53 data = pandas.read_csv("$selected_tasks.infile_data", sep='\t', header=0, index_col=None, parse_dates=True, encoding=None, tupleize_cols=False ) | 61 data = pandas.read_csv("$selected_tasks.infile_data", sep='\t', header=0, index_col=None, parse_dates=True, encoding=None, tupleize_cols=False) |
54 prediction = classifier_object.predict(data) | 62 prediction = classifier_object.predict(data) |
55 prediction_df = pandas.DataFrame(prediction) | 63 prediction_df = pandas.DataFrame(prediction) |
56 res = pandas.concat([data, prediction_df], axis=1) | 64 res = pandas.concat([data, prediction_df], axis=1) |
57 res.to_csv(path_or_buf = "$outfile_predict", sep="\t", index=False) | 65 res.to_csv(path_or_buf = "$outfile_predict", sep="\t", index=False) |
58 #end if | 66 #end if |