Mercurial > repos > bimib > marea
comparison Marea/marea_cluster.py @ 33:abf0bfe01c78 draft
Uploaded
author | bimib |
---|---|
date | Wed, 16 Oct 2019 16:25:56 -0400 |
parents | 944e15aa970a |
children | 1a97d1537623 |
comparison
equal
deleted
inserted
replaced
32:b795e3e163e0 | 33:abf0bfe01c78 |
---|---|
70 type = str, | 70 type = str, |
71 required = True, | 71 required = True, |
72 help = 'your tool directory') | 72 help = 'your tool directory') |
73 | 73 |
74 parser.add_argument('-ms', '--min_samples', | 74 parser.add_argument('-ms', '--min_samples', |
75 type = int, | 75 type = float, |
76 help = 'min samples for dbscan (optional)') | 76 help = 'min samples for dbscan (optional)') |
77 | 77 |
78 parser.add_argument('-ep', '--eps', | 78 parser.add_argument('-ep', '--eps', |
79 type = int, | 79 type = float, |
80 help = 'eps for dbscan (optional)') | 80 help = 'eps for dbscan (optional)') |
81 | 81 |
82 parser.add_argument('-bc', '--best_cluster', | 82 parser.add_argument('-bc', '--best_cluster', |
83 type = str, | 83 type = str, |
84 help = 'output of best cluster tsv') | 84 help = 'output of best cluster tsv') |
308 | 308 |
309 plt.savefig(path, bbox_inches='tight') | 309 plt.savefig(path, bbox_inches='tight') |
310 | 310 |
311 ######################## dbscan ############################################## | 311 ######################## dbscan ############################################## |
312 | 312 |
313 def dbscan(dataset, eps, min_samples): | 313 def dbscan(dataset, eps, min_samples, best_cluster): |
314 if not os.path.exists('clustering'): | 314 if not os.path.exists('clustering'): |
315 os.makedirs('clustering') | 315 os.makedirs('clustering') |
316 | 316 |
317 if eps is not None: | 317 if eps is not None: |
318 clusterer = DBSCAN(eps = eps, min_samples = min_samples) | 318 clusterer = DBSCAN(eps = eps, min_samples = min_samples) |
329 n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0) | 329 n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0) |
330 | 330 |
331 | 331 |
332 ##TODO: PLOT SU DBSCAN (no centers) e HIERARCHICAL | 332 ##TODO: PLOT SU DBSCAN (no centers) e HIERARCHICAL |
333 | 333 |
334 | 334 labels = labels |
335 write_to_csv(dataset, labels, 'clustering/dbscan_results.tsv') | 335 predict = [x+1 for x in labels] |
336 classe = (pd.DataFrame(list(zip(dataset.index, predict)))).astype(str) | |
337 classe.to_csv(best_cluster, sep = '\t', index = False, header = ['Patient_ID', 'Class']) | |
338 | |
336 | 339 |
337 ########################## hierachical ####################################### | 340 ########################## hierachical ####################################### |
338 | 341 |
339 def hierachical_agglomerative(dataset, k_min, k_max): | 342 def hierachical_agglomerative(dataset, k_min, k_max, best_cluster): |
340 | 343 |
341 if not os.path.exists('clustering'): | 344 if not os.path.exists('clustering'): |
342 os.makedirs('clustering') | 345 os.makedirs('clustering') |
343 | 346 |
344 plt.figure(figsize=(10, 7)) | 347 plt.figure(figsize=(10, 7)) |
347 fig = plt.gcf() | 350 fig = plt.gcf() |
348 fig.savefig('clustering/dendogram.png', dpi=200) | 351 fig.savefig('clustering/dendogram.png', dpi=200) |
349 | 352 |
350 range_n_clusters = [i for i in range(k_min, k_max+1)] | 353 range_n_clusters = [i for i in range(k_min, k_max+1)] |
351 | 354 |
352 for n_clusters in range_n_clusters: | 355 scores = [] |
353 | 356 labels = [] |
357 for n_clusters in range_n_clusters: | |
354 cluster = AgglomerativeClustering(n_clusters=n_clusters, affinity='euclidean', linkage='ward') | 358 cluster = AgglomerativeClustering(n_clusters=n_clusters, affinity='euclidean', linkage='ward') |
355 cluster.fit_predict(dataset) | 359 cluster.fit_predict(dataset) |
356 cluster_labels = cluster.labels_ | 360 cluster_labels = cluster.labels_ |
357 | 361 labels.append(cluster_labels) |
358 silhouette_avg = silhouette_score(dataset, cluster_labels) | 362 silhouette_avg = silhouette_score(dataset, cluster_labels) |
359 write_to_csv(dataset, cluster_labels, 'clustering/hierarchical_with_' + str(n_clusters) + '_clusters.tsv') | 363 write_to_csv(dataset, cluster_labels, 'clustering/hierarchical_with_' + str(n_clusters) + '_clusters.tsv') |
364 scores.append(silhouette_avg) | |
360 #warning("For n_clusters =", n_clusters, | 365 #warning("For n_clusters =", n_clusters, |
361 #"The average silhouette_score is :", silhouette_avg) | 366 #"The average silhouette_score is :", silhouette_avg) |
367 | |
368 best = max_index(scores) + k_min | |
369 | |
370 for i in range(len(labels)): | |
371 if (i + k_min == best): | |
372 labels = labels[i] | |
373 predict = [x+1 for x in labels] | |
374 classe = (pd.DataFrame(list(zip(dataset.index, predict)))).astype(str) | |
375 classe.to_csv(best_cluster, sep = '\t', index = False, header = ['Patient_ID', 'Class']) | |
376 | |
362 | 377 |
363 | 378 |
364 | 379 |
365 | 380 |
366 | 381 |
388 | 403 |
389 if args.cluster_type == 'kmeans': | 404 if args.cluster_type == 'kmeans': |
390 kmeans(args.k_min, args.k_max, X, args.elbow, args.silhouette, args.davies, args.best_cluster) | 405 kmeans(args.k_min, args.k_max, X, args.elbow, args.silhouette, args.davies, args.best_cluster) |
391 | 406 |
392 if args.cluster_type == 'dbscan': | 407 if args.cluster_type == 'dbscan': |
393 dbscan(X, args.eps, args.min_samples) | 408 dbscan(X, args.eps, args.min_samples, args.best_cluster) |
394 | 409 |
395 if args.cluster_type == 'hierarchy': | 410 if args.cluster_type == 'hierarchy': |
396 hierachical_agglomerative(X, args.k_min, args.k_max) | 411 hierachical_agglomerative(X, args.k_min, args.k_max, args.best_cluster) |
397 | 412 |
398 ############################################################################## | 413 ############################################################################## |
399 | 414 |
400 if __name__ == "__main__": | 415 if __name__ == "__main__": |
401 main() | 416 main() |