Mercurial > repos > bgruening > keras_batch_models
comparison to_categorical.py @ 9:0a3f113397b2 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
| author | bgruening | 
|---|---|
| date | Tue, 13 Apr 2021 17:29:01 +0000 | 
| parents | |
| children | 4a5266c96889 | 
   comparison
  equal
  deleted
  inserted
  replaced
| 8:508ce0649bec | 9:0a3f113397b2 | 
|---|---|
| 1 import argparse | |
| 2 import json | |
| 3 import warnings | |
| 4 | |
| 5 import numpy as np | |
| 6 import pandas as pd | |
| 7 from keras.utils import to_categorical | |
| 8 | |
| 9 | |
| 10 def main(inputs, infile, outfile, num_classes=None): | |
| 11 """ | |
| 12 Parameters | |
| 13 ---------- | |
| 14 inputs : str | |
| 15 File path to the JSON file holding the Galaxy tool parameters | |
| 16 | |
| 17 infile : str | |
| 18 File path to the tab-separated input vector | |
| 19 | |
| 20 outfile : str | |
| 21 File path to output matrix | |
| 22 | |
| 23 num_classes : int or None | |
| 24 Total number of classes. If None, it is inferred as (largest value in the input vector) + 1 | |
| 25 | |
| 26 """ | |
| 27 warnings.simplefilter("ignore") | |
| 28 | |
| 29 with open(inputs, "r") as param_handler: | |
| 30 params = json.load(param_handler) | |
| 31 | |
| 32 input_header = params["header0"] | |
| 33 header = "infer" if input_header else None | |
| 34 | |
| 35 input_vector = pd.read_csv(infile, sep="\t", header=header) | |
| 36 | |
| 37 output_matrix = to_categorical(input_vector, num_classes=num_classes) | |
| 38 | |
| 39 np.savetxt(outfile, output_matrix, fmt="%d", delimiter="\t") | |
| 40 | |
| 41 | |
| 42 if __name__ == "__main__": | |
| 43 aparser = argparse.ArgumentParser() | |
| 44 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | |
| 45 aparser.add_argument("-y", "--infile", dest="infile") | |
| 46 aparser.add_argument("-n", "--num_classes", dest="num_classes", type=int, default=None) | |
| 47 aparser.add_argument("-o", "--outfile", dest="outfile") | |
| 48 args = aparser.parse_args() | |
| 49 | |
| 50 main(args.inputs, args.infile, args.outfile, args.num_classes) | 
