Mercurial > repos > bimib > marea
comparison Marea/marea_cluster.py @ 13:e96f3b85e5a0 draft
Uploaded
author | bimib |
---|---|
date | Wed, 13 Feb 2019 05:42:20 -0500 |
parents | 3d77287caf22 |
children | 1a0c8c2780f2 |
comparison
equal
deleted
inserted
replaced
12:3d77287caf22 | 13:e96f3b85e5a0 |
---|---|
1 | |
2 from __future__ import division | 1 from __future__ import division |
3 import os | 2 import os |
4 import sys | 3 import sys |
5 import pandas as pd | 4 import pandas as pd |
6 import collections | 5 import collections |
7 import pickle as pk | 6 import pickle as pk |
8 import argparse | 7 import argparse |
9 from sklearn.cluster import KMeans | 8 from sklearn.cluster import KMeans |
9 import matplotlib | |
10 matplotlib.use('GTKAgg') | |
10 import matplotlib.pyplot as plt | 11 import matplotlib.pyplot as plt |
11 | 12 |
12 ########################## argparse ########################################### | 13 ########################## argparse ########################################### |
13 | 14 |
14 def process_args(args): | 15 def process_args(args): |
538 ################################# clustering ################################## | 539 ################################# clustering ################################## |
539 | 540 |
540 def f_cluster(resolve_rules): | 541 def f_cluster(resolve_rules): |
541 os.makedirs('cluster_out') | 542 os.makedirs('cluster_out') |
542 args = process_args(sys.argv) | 543 args = process_args(sys.argv) |
544 k_min = args.k_min | |
545 k_max = args.k_max | |
546 if k_min > k_max: | |
547 warning('k range boundaries inverted.\n') | |
548 tmp = k_min | |
549 k_min = k_max | |
550 k_max = tmp | |
551 else: | |
552 warning('k range correct.\n') | |
543 cluster_data = pd.DataFrame.from_dict(resolve_rules, orient = 'index') | 553 cluster_data = pd.DataFrame.from_dict(resolve_rules, orient = 'index') |
544 for i in cluster_data.columns: | 554 for i in cluster_data.columns: |
545 tmp = cluster_data[i][0] | 555 tmp = cluster_data[i][0] |
546 if tmp == None: | 556 if tmp == None: |
547 cluster_data = cluster_data.drop(columns=[i]) | 557 cluster_data = cluster_data.drop(columns=[i]) |
548 distorsion = [] | 558 distorsion = [] |
549 for i in range(args.k_min, args.k_max+1): | 559 for i in range(k_min, k_max+1): |
550 tmp_kmeans = KMeans(n_clusters = i, | 560 tmp_kmeans = KMeans(n_clusters = i, |
551 n_init = 100, | 561 n_init = 100, |
552 max_iter = 300, | 562 max_iter = 300, |
553 random_state = 0).fit(cluster_data) | 563 random_state = 0).fit(cluster_data) |
554 distorsion.append(tmp_kmeans.inertia_) | 564 distorsion.append(tmp_kmeans.inertia_) |
557 classe = (pd.DataFrame(list(zip(cluster_data.index, predict)))).astype(str) | 567 classe = (pd.DataFrame(list(zip(cluster_data.index, predict)))).astype(str) |
558 dest = 'cluster_out/K=' + str(i) + '_' + args.name+'.tsv' | 568 dest = 'cluster_out/K=' + str(i) + '_' + args.name+'.tsv' |
559 classe.to_csv(dest, sep = '\t', index = False, | 569 classe.to_csv(dest, sep = '\t', index = False, |
560 header = ['Patient_ID', 'Class']) | 570 header = ['Patient_ID', 'Class']) |
561 plt.figure(0) | 571 plt.figure(0) |
562 plt.plot(range(args.k_min, args.k_max+1), distorsion, marker = 'o') | 572 plt.plot(range(k_min, k_max+1), distorsion, marker = 'o') |
563 plt.xlabel('Number of cluster') | 573 plt.xlabel('Number of cluster') |
564 plt.ylabel('Distorsion') | 574 plt.ylabel('Distorsion') |
565 plt.savefig(args.elbow, dpi = 240, format = 'pdf') | 575 plt.savefig(args.elbow, dpi = 240, format = 'pdf') |
566 if args.cond_hier == 'yes': | 576 if args.cond_hier == 'yes': |
567 import scipy.cluster.hierarchy as hier | 577 import scipy.cluster.hierarchy as hier |
574 | 584 |
575 ################################# main ######################################## | 585 ################################# main ######################################## |
576 | 586 |
577 def main(): | 587 def main(): |
578 args = process_args(sys.argv) | 588 args = process_args(sys.argv) |
579 if args.k_min > args.k_max: | |
580 sys.exit('Execution aborted: max cluster > min cluster') | |
581 if args.rules_selector == 'HMRcore': | 589 if args.rules_selector == 'HMRcore': |
582 recon = pk.load(open(args.tool_dir + '/local/HMRcore_rules.p', 'rb')) | 590 recon = pk.load(open(args.tool_dir + '/local/HMRcore_rules.p', 'rb')) |
583 elif args.rules_selector == 'Recon': | 591 elif args.rules_selector == 'Recon': |
584 recon = pk.load(open(args.tool_dir + '/local/Recon_rules.p', 'rb')) | 592 recon = pk.load(open(args.tool_dir + '/local/Recon_rules.p', 'rb')) |
585 elif args.rules_selector == 'Custom': | 593 elif args.rules_selector == 'Custom': |