diff Marea/marea_cluster.py @ 28:e6831924df01 draft

small fixes (elbow plot and output management)
author bimib
date Mon, 14 Oct 2019 05:01:08 -0400
parents 9992eba50cfb
children 944e15aa970a
line wrap: on
line diff
--- a/Marea/marea_cluster.py	Mon Oct 07 13:48:01 2019 -0400
+++ b/Marea/marea_cluster.py	Mon Oct 14 05:01:08 2019 -0400
@@ -78,6 +78,11 @@
     parser.add_argument('-ep', '--eps',
                         type = int,
                         help = 'eps for dbscan (optional)')
+                        
+    parser.add_argument('-bc', '--best_cluster',
+                        type = str,
+                        help = 'output of best cluster tsv')
+    				
     
     
     args = parser.parse_args()
@@ -131,20 +136,7 @@
     dest = name
     classe.to_csv(dest, sep = '\t', index = False,
                       header = ['Patient_ID', 'Class'])
-    
-
-      #list_labels = labels
-    #list_values = dataset
-
-    #list_values = list_values.tolist()
-    #d = {'Label' : list_labels, 'Value' : list_values}
-    
-    #df = pd.DataFrame(d, columns=['Value','Label'])
-
-    #dest = name + '.tsv'
-    #df.to_csv(dest, sep = '\t', index = False,
-     #                 header = ['Value', 'Label'])
-    
+   
 ########################### trova il massimo in lista ########################
 def max_index (lista):
     best = -1
@@ -158,7 +150,7 @@
     
 ################################ kmeans #####################################
     
-def kmeans (k_min, k_max, dataset, elbow, silhouette, davies):
+def kmeans (k_min, k_max, dataset, elbow, silhouette, davies, best_cluster):
     if not os.path.exists('clustering'):
         os.makedirs('clustering')
     
@@ -189,7 +181,10 @@
         cluster_labels = clusterer.fit_predict(dataset)
         
         all_labels.append(cluster_labels)
-        silhouette_avg = silhouette_score(dataset, cluster_labels)
+        if n_clusters == 1:
+        	silhouette_avg = 0
+        else:
+            silhouette_avg = silhouette_score(dataset, cluster_labels)
         scores.append(silhouette_avg)
         distortions.append(clusterer.fit(dataset).inertia_)
         
@@ -201,6 +196,14 @@
             prefix = '_BEST'
             
         write_to_csv(dataset, all_labels[i], 'clustering/kmeans_with_' + str(i + k_min) + prefix + '_clusters.tsv')
+        
+        
+        if (prefix == '_BEST'):
+            labels = all_labels[i]
+            predict = [x+1 for x in labels]
+            classe = (pd.DataFrame(list(zip(dataset.index, predict)))).astype(str)
+            classe.to_csv(best_cluster, sep = '\t', index = False, header = ['Patient_ID', 'Class'])
+            
             
         if davies:
             with np.errstate(divide='ignore', invalid='ignore'):
@@ -235,6 +238,9 @@
     
 ############################## silhouette plot ###############################
 def silihouette_draw(dataset, labels, n_clusters, path):
+    if n_clusters == 1:
+        return None
+        
     silhouette_avg = silhouette_score(dataset, labels)
     warning("For n_clusters = " + str(n_clusters) +
           " The average silhouette_score is: " + str(silhouette_avg))
@@ -375,7 +381,7 @@
     
     
     if args.cluster_type == 'kmeans':
-        kmeans(args.k_min, args.k_max, X, args.elbow, args.silhouette, args.davies)
+        kmeans(args.k_min, args.k_max, X, args.elbow, args.silhouette, args.davies, args.best_cluster)
     
     if args.cluster_type == 'dbscan':
         dbscan(X, args.eps, args.min_samples)
@@ -386,4 +392,4 @@
 ##############################################################################
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()