Mercurial > repos > bimib > marea
comparison Marea/marea_cluster.py @ 33:abf0bfe01c78 draft
Uploaded
| author | bimib |
|---|---|
| date | Wed, 16 Oct 2019 16:25:56 -0400 |
| parents | 944e15aa970a |
| children | 1a97d1537623 |
comparison
equal
deleted
inserted
replaced
| 32:b795e3e163e0 | 33:abf0bfe01c78 |
|---|---|
| 70 type = str, | 70 type = str, |
| 71 required = True, | 71 required = True, |
| 72 help = 'your tool directory') | 72 help = 'your tool directory') |
| 73 | 73 |
| 74 parser.add_argument('-ms', '--min_samples', | 74 parser.add_argument('-ms', '--min_samples', |
| 75 type = int, | 75 type = float, |
| 76 help = 'min samples for dbscan (optional)') | 76 help = 'min samples for dbscan (optional)') |
| 77 | 77 |
| 78 parser.add_argument('-ep', '--eps', | 78 parser.add_argument('-ep', '--eps', |
| 79 type = int, | 79 type = float, |
| 80 help = 'eps for dbscan (optional)') | 80 help = 'eps for dbscan (optional)') |
| 81 | 81 |
| 82 parser.add_argument('-bc', '--best_cluster', | 82 parser.add_argument('-bc', '--best_cluster', |
| 83 type = str, | 83 type = str, |
| 84 help = 'output of best cluster tsv') | 84 help = 'output of best cluster tsv') |
| 308 | 308 |
| 309 plt.savefig(path, bbox_inches='tight') | 309 plt.savefig(path, bbox_inches='tight') |
| 310 | 310 |
| 311 ######################## dbscan ############################################## | 311 ######################## dbscan ############################################## |
| 312 | 312 |
| 313 def dbscan(dataset, eps, min_samples): | 313 def dbscan(dataset, eps, min_samples, best_cluster): |
| 314 if not os.path.exists('clustering'): | 314 if not os.path.exists('clustering'): |
| 315 os.makedirs('clustering') | 315 os.makedirs('clustering') |
| 316 | 316 |
| 317 if eps is not None: | 317 if eps is not None: |
| 318 clusterer = DBSCAN(eps = eps, min_samples = min_samples) | 318 clusterer = DBSCAN(eps = eps, min_samples = min_samples) |
| 329 n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0) | 329 n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0) |
| 330 | 330 |
| 331 | 331 |
| 332 ##TODO: PLOT SU DBSCAN (no centers) e HIERARCHICAL | 332 ##TODO: PLOT SU DBSCAN (no centers) e HIERARCHICAL |
| 333 | 333 |
| 334 | 334 labels = labels |
| 335 write_to_csv(dataset, labels, 'clustering/dbscan_results.tsv') | 335 predict = [x+1 for x in labels] |
| | 336 classe = (pd.DataFrame(list(zip(dataset.index, predict)))).astype(str) |
| | 337 classe.to_csv(best_cluster, sep = '\t', index = False, header = ['Patient_ID', 'Class']) |
| | 338 |
| 336 | 339 |
| 337 ########################## hierachical ####################################### | 340 ########################## hierachical ####################################### |
| 338 | 341 |
| 339 def hierachical_agglomerative(dataset, k_min, k_max): | 342 def hierachical_agglomerative(dataset, k_min, k_max, best_cluster): |
| 340 | 343 |
| 341 if not os.path.exists('clustering'): | 344 if not os.path.exists('clustering'): |
| 342 os.makedirs('clustering') | 345 os.makedirs('clustering') |
| 343 | 346 |
| 344 plt.figure(figsize=(10, 7)) | 347 plt.figure(figsize=(10, 7)) |
| 347 fig = plt.gcf() | 350 fig = plt.gcf() |
| 348 fig.savefig('clustering/dendogram.png', dpi=200) | 351 fig.savefig('clustering/dendogram.png', dpi=200) |
| 349 | 352 |
| 350 range_n_clusters = [i for i in range(k_min, k_max+1)] | 353 range_n_clusters = [i for i in range(k_min, k_max+1)] |
| 351 | 354 |
| 352 for n_clusters in range_n_clusters: | 355 scores = [] |
| 353 | 356 labels = [] |
| | 357 for n_clusters in range_n_clusters: |
| 354 cluster = AgglomerativeClustering(n_clusters=n_clusters, affinity='euclidean', linkage='ward') | 358 cluster = AgglomerativeClustering(n_clusters=n_clusters, affinity='euclidean', linkage='ward') |
| 355 cluster.fit_predict(dataset) | 359 cluster.fit_predict(dataset) |
| 356 cluster_labels = cluster.labels_ | 360 cluster_labels = cluster.labels_ |
| 357 | 361 labels.append(cluster_labels) |
| 358 silhouette_avg = silhouette_score(dataset, cluster_labels) | 362 silhouette_avg = silhouette_score(dataset, cluster_labels) |
| 359 write_to_csv(dataset, cluster_labels, 'clustering/hierarchical_with_' + str(n_clusters) + '_clusters.tsv') | 363 write_to_csv(dataset, cluster_labels, 'clustering/hierarchical_with_' + str(n_clusters) + '_clusters.tsv') |
| | 364 scores.append(silhouette_avg) |
| 360 #warning("For n_clusters =", n_clusters, | 365 #warning("For n_clusters =", n_clusters, |
| 361 #"The average silhouette_score is :", silhouette_avg) | 366 #"The average silhouette_score is :", silhouette_avg) |
| | 367 |
| | 368 best = max_index(scores) + k_min |
| | 369 |
| | 370 for i in range(len(labels)): |
| | 371 if (i + k_min == best): |
| | 372 labels = labels[i] |
| | 373 predict = [x+1 for x in labels] |
| | 374 classe = (pd.DataFrame(list(zip(dataset.index, predict)))).astype(str) |
| | 375 classe.to_csv(best_cluster, sep = '\t', index = False, header = ['Patient_ID', 'Class']) |
| | 376 |
| 362 | 377 |
| 363 | 378 |
| 364 | 379 |
| 365 | 380 |
| 366 | 381 |
| 388 | 403 |
| 389 if args.cluster_type == 'kmeans': | 404 if args.cluster_type == 'kmeans': |
| 390 kmeans(args.k_min, args.k_max, X, args.elbow, args.silhouette, args.davies, args.best_cluster) | 405 kmeans(args.k_min, args.k_max, X, args.elbow, args.silhouette, args.davies, args.best_cluster) |
| 391 | 406 |
| 392 if args.cluster_type == 'dbscan': | 407 if args.cluster_type == 'dbscan': |
| 393 dbscan(X, args.eps, args.min_samples) | 408 dbscan(X, args.eps, args.min_samples, args.best_cluster) |
| 394 | 409 |
| 395 if args.cluster_type == 'hierarchy': | 410 if args.cluster_type == 'hierarchy': |
| 396 hierachical_agglomerative(X, args.k_min, args.k_max) | 411 hierachical_agglomerative(X, args.k_min, args.k_max, args.best_cluster) |
| 397 | 412 |
| 398 ############################################################################## | 413 ############################################################################## |
| 399 | 414 |
| 400 if __name__ == "__main__": | 415 if __name__ == "__main__": |
| 401 main() | 416 main() |
