Mercurial > repos > bgruening > sklearn_discriminant_classifier
diff model_prediction.py @ 27:8e49f26b14d3 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ba6a47bdf76bbf4cb276206ac1a8cbf61332fd16"
author | bgruening |
---|---|
date | Fri, 13 Sep 2019 12:10:41 -0400 |
parents | 9bb505eafac9 |
children | 64b771b1471a |
line wrap: on
line diff
--- a/model_prediction.py Fri Aug 09 07:06:17 2019 -0400 +++ b/model_prediction.py Fri Sep 13 12:10:41 2019 -0400 @@ -2,11 +2,13 @@ import json import numpy as np import pandas as pd +import tabix import warnings from scipy.io import mmread from sklearn.pipeline import Pipeline +from galaxy_ml.externals.selene_sdk.sequences import Genome from galaxy_ml.utils import (load_model, read_columns, get_module, try_get_attr) @@ -138,53 +140,108 @@ pred_data_generator.fit() - preds = estimator.model_.predict_generator( - pred_data_generator.flow(batch_size=32), - workers=N_JOBS, - use_multiprocessing=True) + variants = pred_data_generator.variants + # TODO : remove the following block after galaxy-ml v0.7.13 + blacklist_tabix = getattr(pred_data_generator.reference_genome_, + '_blacklist_tabix', None) + clean_variants = [] + if blacklist_tabix: + start_radius = pred_data_generator.start_radius_ + end_radius = pred_data_generator.end_radius_ + + for chrom, pos, name, ref, alt, strand in variants: + center = pos + len(ref) // 2 + start = center - start_radius + end = center + end_radius - if preds.min() < 0. or preds.max() > 1.: - warnings.warn('Network returning invalid probability values. ' - 'The last layer might not normalize predictions ' - 'into probabilities ' - '(like softmax or sigmoid would).') + if isinstance(pred_data_generator.reference_genome_, Genome): + if "chr" not in chrom: + chrom = "chr" + chrom + if "MT" in chrom: + chrom = chrom[:-1] + try: + rows = blacklist_tabix.query(chrom, start, end) + found = 0 + for row in rows: + found = 1 + break + if found: + continue + except tabix.TabixError: + pass - if params['method'] == 'predict_proba' and preds.shape[1] == 1: - # first column is probability of class 0 and second is of class 1 - preds = np.hstack([1 - preds, preds]) + clean_variants.append((chrom, pos, name, ref, alt, strand)) + else: + clean_variants = variants + + setattr(pred_data_generator, 'variants', clean_variants) + + variants = np.array(clean_variants) + # predict 1600 sample at once then write to file + gen_flow = pred_data_generator.flow(batch_size=1600) + + file_writer = open(outfile_predict, 'w') + header_row = '\t'.join(['chrom', 'pos', 'name', 'ref', + 'alt', 'strand']) + file_writer.write(header_row) + header_done = False - elif params['method'] == 'predict': - if preds.shape[-1] > 1: - # if the last activation is `softmax`, the sum of all - # probibilities will 1, the classification is considered as - # multi-class problem, otherwise, we take it as multi-label. - act = getattr(estimator.model_.layers[-1], 'activation', None) - if act and act.__name__ == 'softmax': - classes = preds.argmax(axis=-1) + steps_done = 0 + + # TODO: multiple threading + try: + while steps_done < len(gen_flow): + index_array = next(gen_flow.index_generator) + batch_X = gen_flow._get_batches_of_transformed_samples( + index_array) + + if params['method'] == 'predict': + batch_preds = estimator.predict( + batch_X, + # The presence of `pred_data_generator` below is to + # override model carrying data_generator if there + # is any. + data_generator=pred_data_generator) else: - preds = (preds > 0.5).astype('int32') - else: - classes = (preds > 0.5).astype('int32') + batch_preds = estimator.predict_proba( + batch_X, + # The presence of `pred_data_generator` below is to + # override model carrying data_generator if there + # is any. + data_generator=pred_data_generator) + + if batch_preds.ndim == 1: + batch_preds = batch_preds[:, np.newaxis] + + batch_meta = variants[index_array] + batch_out = np.column_stack([batch_meta, batch_preds]) - preds = estimator.classes_[classes] + if not header_done: + heads = np.arange(batch_preds.shape[-1]).astype(str) + heads_str = '\t'.join(heads) + file_writer.write("\t%s\n" % heads_str) + header_done = True + + for row in batch_out: + row_str = '\t'.join(row) + file_writer.write("%s\n" % row_str) + + steps_done += 1 + + finally: + file_writer.close() + # TODO: make api `pred_data_generator.close()` + pred_data_generator.close() + return 0 # end input # output - if input_type == 'variant_effect': # TODO: save in batchs - rval = pd.DataFrame(preds) - meta = pd.DataFrame( - pred_data_generator.variants, - columns=['chrom', 'pos', 'name', 'ref', 'alt', 'strand']) - - rval = pd.concat([meta, rval], axis=1) - - elif len(preds.shape) == 1: + if len(preds.shape) == 1: rval = pd.DataFrame(preds, columns=['Predicted']) else: rval = pd.DataFrame(preds) - rval.to_csv(outfile_predict, sep='\t', - header=True, index=False) + rval.to_csv(outfile_predict, sep='\t', header=True, index=False) if __name__ == '__main__':