Mercurial > repos > bimib > marea
comparison Marea/marea_cluster.py @ 28:e6831924df01 draft
small fixes (elbow plot and output management)
author | bimib |
---|---|
date | Mon, 14 Oct 2019 05:01:08 -0400 |
parents | 9992eba50cfb |
children | 944e15aa970a |
comparison
equal
deleted
inserted
replaced
27:8c480c977a12 | 28:e6831924df01 |
---|---|
76 help = 'min samples for dbscan (optional)') | 76 help = 'min samples for dbscan (optional)') |
77 | 77 |
78 parser.add_argument('-ep', '--eps', | 78 parser.add_argument('-ep', '--eps', |
79 type = int, | 79 type = int, |
80 help = 'eps for dbscan (optional)') | 80 help = 'eps for dbscan (optional)') |
81 | |
82 parser.add_argument('-bc', '--best_cluster', | |
83 type = str, | |
84 help = 'output of best cluster tsv') | |
85 | |
81 | 86 |
82 | 87 |
83 args = parser.parse_args() | 88 args = parser.parse_args() |
84 return args | 89 return args |
85 | 90 |
129 classe = (pd.DataFrame(list(zip(dataset.index, predict)))).astype(str) | 134 classe = (pd.DataFrame(list(zip(dataset.index, predict)))).astype(str) |
130 | 135 |
131 dest = name | 136 dest = name |
132 classe.to_csv(dest, sep = '\t', index = False, | 137 classe.to_csv(dest, sep = '\t', index = False, |
133 header = ['Patient_ID', 'Class']) | 138 header = ['Patient_ID', 'Class']) |
134 | 139 |
135 | |
136 #list_labels = labels | |
137 #list_values = dataset | |
138 | |
139 #list_values = list_values.tolist() | |
140 #d = {'Label' : list_labels, 'Value' : list_values} | |
141 | |
142 #df = pd.DataFrame(d, columns=['Value','Label']) | |
143 | |
144 #dest = name + '.tsv' | |
145 #df.to_csv(dest, sep = '\t', index = False, | |
146 # header = ['Value', 'Label']) | |
147 | |
148 ########################### trova il massimo in lista ######################## | 140 ########################### trova il massimo in lista ######################## |
149 def max_index (lista): | 141 def max_index (lista): |
150 best = -1 | 142 best = -1 |
151 best_index = 0 | 143 best_index = 0 |
152 for i in range(len(lista)): | 144 for i in range(len(lista)): |
156 | 148 |
157 return best_index | 149 return best_index |
158 | 150 |
159 ################################ kmeans ##################################### | 151 ################################ kmeans ##################################### |
160 | 152 |
161 def kmeans (k_min, k_max, dataset, elbow, silhouette, davies): | 153 def kmeans (k_min, k_max, dataset, elbow, silhouette, davies, best_cluster): |
162 if not os.path.exists('clustering'): | 154 if not os.path.exists('clustering'): |
163 os.makedirs('clustering') | 155 os.makedirs('clustering') |
164 | 156 |
165 | 157 |
166 if elbow == 'true': | 158 if elbow == 'true': |
187 for n_clusters in range_n_clusters: | 179 for n_clusters in range_n_clusters: |
188 clusterer = KMeans(n_clusters=n_clusters, random_state=10) | 180 clusterer = KMeans(n_clusters=n_clusters, random_state=10) |
189 cluster_labels = clusterer.fit_predict(dataset) | 181 cluster_labels = clusterer.fit_predict(dataset) |
190 | 182 |
191 all_labels.append(cluster_labels) | 183 all_labels.append(cluster_labels) |
192 silhouette_avg = silhouette_score(dataset, cluster_labels) | 184 if n_clusters == 1: |
185 silhouette_avg = 0 | |
186 else: | |
187 silhouette_avg = silhouette_score(dataset, cluster_labels) | |
193 scores.append(silhouette_avg) | 188 scores.append(silhouette_avg) |
194 distortions.append(clusterer.fit(dataset).inertia_) | 189 distortions.append(clusterer.fit(dataset).inertia_) |
195 | 190 |
196 best = max_index(scores) + k_min | 191 best = max_index(scores) + k_min |
197 | 192 |
199 prefix = '' | 194 prefix = '' |
200 if (i + k_min == best): | 195 if (i + k_min == best): |
201 prefix = '_BEST' | 196 prefix = '_BEST' |
202 | 197 |
203 write_to_csv(dataset, all_labels[i], 'clustering/kmeans_with_' + str(i + k_min) + prefix + '_clusters.tsv') | 198 write_to_csv(dataset, all_labels[i], 'clustering/kmeans_with_' + str(i + k_min) + prefix + '_clusters.tsv') |
199 | |
200 | |
201 if (prefix == '_BEST'): | |
202 labels = all_labels[i] | |
203 predict = [x+1 for x in labels] | |
204 classe = (pd.DataFrame(list(zip(dataset.index, predict)))).astype(str) | |
205 classe.to_csv(best_cluster, sep = '\t', index = False, header = ['Patient_ID', 'Class']) | |
206 | |
204 | 207 |
205 if davies: | 208 if davies: |
206 with np.errstate(divide='ignore', invalid='ignore'): | 209 with np.errstate(divide='ignore', invalid='ignore'): |
207 davies_bouldin = davies_bouldin_score(dataset, all_labels[i]) | 210 davies_bouldin = davies_bouldin_score(dataset, all_labels[i]) |
208 warning("\nFor n_clusters = " + str(i + k_min) + | 211 warning("\nFor n_clusters = " + str(i + k_min) + |
233 fig.savefig(s, dpi=100) | 236 fig.savefig(s, dpi=100) |
234 | 237 |
235 | 238 |
236 ############################## silhouette plot ############################### | 239 ############################## silhouette plot ############################### |
237 def silihouette_draw(dataset, labels, n_clusters, path): | 240 def silihouette_draw(dataset, labels, n_clusters, path): |
241 if n_clusters == 1: | |
242 return None | |
243 | |
238 silhouette_avg = silhouette_score(dataset, labels) | 244 silhouette_avg = silhouette_score(dataset, labels) |
239 warning("For n_clusters = " + str(n_clusters) + | 245 warning("For n_clusters = " + str(n_clusters) + |
240 " The average silhouette_score is: " + str(silhouette_avg)) | 246 " The average silhouette_score is: " + str(silhouette_avg)) |
241 | 247 |
242 plt.close('all') | 248 plt.close('all') |
373 if tmp == None: | 379 if tmp == None: |
374 X = X.drop(columns=[i]) | 380 X = X.drop(columns=[i]) |
375 | 381 |
376 | 382 |
377 if args.cluster_type == 'kmeans': | 383 if args.cluster_type == 'kmeans': |
378 kmeans(args.k_min, args.k_max, X, args.elbow, args.silhouette, args.davies) | 384 kmeans(args.k_min, args.k_max, X, args.elbow, args.silhouette, args.davies, args.best_cluster) |
379 | 385 |
380 if args.cluster_type == 'dbscan': | 386 if args.cluster_type == 'dbscan': |
381 dbscan(X, args.eps, args.min_samples) | 387 dbscan(X, args.eps, args.min_samples) |
382 | 388 |
383 if args.cluster_type == 'hierarchy': | 389 if args.cluster_type == 'hierarchy': |