Mercurial > repos > bimib > marea
diff Marea/marea_cluster.py @ 28:e6831924df01 draft
small fixes (elbow plot and output management)
author | bimib |
---|---|
date | Mon, 14 Oct 2019 05:01:08 -0400 |
parents | 9992eba50cfb |
children | 944e15aa970a |
line wrap: on
line diff
--- a/Marea/marea_cluster.py Mon Oct 07 13:48:01 2019 -0400 +++ b/Marea/marea_cluster.py Mon Oct 14 05:01:08 2019 -0400 @@ -78,6 +78,11 @@ parser.add_argument('-ep', '--eps', type = int, help = 'eps for dbscan (optional)') + + parser.add_argument('-bc', '--best_cluster', + type = str, + help = 'output of best cluster tsv') + args = parser.parse_args() @@ -131,20 +136,7 @@ dest = name classe.to_csv(dest, sep = '\t', index = False, header = ['Patient_ID', 'Class']) - - - #list_labels = labels - #list_values = dataset - - #list_values = list_values.tolist() - #d = {'Label' : list_labels, 'Value' : list_values} - - #df = pd.DataFrame(d, columns=['Value','Label']) - - #dest = name + '.tsv' - #df.to_csv(dest, sep = '\t', index = False, - # header = ['Value', 'Label']) - + ########################### trova il massimo in lista ######################## def max_index (lista): best = -1 @@ -158,7 +150,7 @@ ################################ kmeans ##################################### -def kmeans (k_min, k_max, dataset, elbow, silhouette, davies): +def kmeans (k_min, k_max, dataset, elbow, silhouette, davies, best_cluster): if not os.path.exists('clustering'): os.makedirs('clustering') @@ -189,7 +181,10 @@ cluster_labels = clusterer.fit_predict(dataset) all_labels.append(cluster_labels) - silhouette_avg = silhouette_score(dataset, cluster_labels) + if n_clusters == 1: + silhouette_avg = 0 + else: + silhouette_avg = silhouette_score(dataset, cluster_labels) scores.append(silhouette_avg) distortions.append(clusterer.fit(dataset).inertia_) @@ -201,6 +196,14 @@ prefix = '_BEST' write_to_csv(dataset, all_labels[i], 'clustering/kmeans_with_' + str(i + k_min) + prefix + '_clusters.tsv') + + + if (prefix == '_BEST'): + labels = all_labels[i] + predict = [x+1 for x in labels] + classe = (pd.DataFrame(list(zip(dataset.index, predict)))).astype(str) + classe.to_csv(best_cluster, sep = '\t', index = False, header = ['Patient_ID', 'Class']) + if davies: with np.errstate(divide='ignore', 
invalid='ignore'): @@ -235,6 +238,9 @@ ############################## silhouette plot ############################### def silihouette_draw(dataset, labels, n_clusters, path): + if n_clusters == 1: + return None + silhouette_avg = silhouette_score(dataset, labels) warning("For n_clusters = " + str(n_clusters) + " The average silhouette_score is: " + str(silhouette_avg)) @@ -375,7 +381,7 @@ if args.cluster_type == 'kmeans': - kmeans(args.k_min, args.k_max, X, args.elbow, args.silhouette, args.davies) + kmeans(args.k_min, args.k_max, X, args.elbow, args.silhouette, args.davies, args.best_cluster) if args.cluster_type == 'dbscan': dbscan(X, args.eps, args.min_samples) @@ -386,4 +392,4 @@ ############################################################################## if __name__ == "__main__": - main() \ No newline at end of file + main()