comparison COBRAxy/marea_cluster.py @ 147:3fca9b568faf draft

Uploaded
author bimib
date Wed, 06 Nov 2024 13:57:24 +0000
parents 41f35c2f0c7b
children 7f3552eaf774
comparison
equal deleted inserted replaced
146:88cf4543e210 147:3fca9b568faf
18 import scipy.cluster.hierarchy as shc 18 import scipy.cluster.hierarchy as shc
19 import matplotlib.cm as cm 19 import matplotlib.cm as cm
20 from typing import Optional, Dict, List 20 from typing import Optional, Dict, List
21 21
22 ################################# process args ############################### 22 ################################# process args ###############################
23 def process_args(args :List[str]) -> argparse.Namespace: 23 def process_args(args :List[str] = None) -> argparse.Namespace:
24 """ 24 """
25 Processes command-line arguments. 25 Processes command-line arguments.
26 26
27 Args: 27 Args:
28 args (list): List of command-line arguments. 28 args (list): List of command-line arguments.
84 84
85 parser.add_argument('-bc', '--best_cluster', 85 parser.add_argument('-bc', '--best_cluster',
86 type = str, 86 type = str,
87 help = 'output of best cluster tsv') 87 help = 'output of best cluster tsv')
88 88
89 89 parser.add_argument(
90 90 '-idop', '--output_path',
91 args = parser.parse_args() 91 type = str,
92 default='result',
93 help = 'output path for maps')
94
95 args = parser.parse_args(args)
92 return args 96 return args
93 97
94 ########################### warning ########################################### 98 ########################### warning ###########################################
95 def warning(s :str) -> None: 99 def warning(s :str) -> None:
96 """ 100 """
215 best_cluster (str): The file path to save the output of the best cluster. 219 best_cluster (str): The file path to save the output of the best cluster.
216 220
217 Returns: 221 Returns:
218 None 222 None
219 """ 223 """
220 if not os.path.exists('clustering'): 224 if not os.path.exists(args.output_path):
221 os.makedirs('clustering') 225 os.makedirs(args.output_path)
222 226
223 227
224 if elbow == 'true': 228 if elbow == 'true':
225 elbow = True 229 elbow = True
226 else: 230 else:
257 for i in range(len(all_labels)): 261 for i in range(len(all_labels)):
258 prefix = '' 262 prefix = ''
259 if (i + k_min == best): 263 if (i + k_min == best):
260 prefix = '_BEST' 264 prefix = '_BEST'
261 265
262 write_to_csv(dataset, all_labels[i], 'clustering/kmeans_with_' + str(i + k_min) + prefix + '_clusters.tsv') 266 write_to_csv(dataset, all_labels[i], f'{args.output_path}/kmeans_with_' + str(i + k_min) + prefix + '_clusters.tsv')
263 267
264 268
265 if (prefix == '_BEST'): 269 if (prefix == '_BEST'):
266 labels = all_labels[i] 270 labels = all_labels[i]
267 predict = [x+1 for x in labels] 271 predict = [x+1 for x in labels]
270 274
271 275
272 276
273 277
274 if silhouette: 278 if silhouette:
275 silhouette_draw(dataset, all_labels[i], i + k_min, 'clustering/silhouette_with_' + str(i + k_min) + prefix + '_clusters.png') 279 silhouette_draw(dataset, all_labels[i], i + k_min, f'{args.output_path}/silhouette_with_' + str(i + k_min) + prefix + '_clusters.png')
276 280
277 281
278 if elbow: 282 if elbow:
279 elbow_plot(distortions, k_min,k_max) 283 elbow_plot(distortions, k_min,k_max)
280 284
301 x = list(range(k_min, k_max + 1)) 305 x = list(range(k_min, k_max + 1))
302 x.insert(0, 1) 306 x.insert(0, 1)
303 plt.plot(x, distortions, marker = 'o') 307 plt.plot(x, distortions, marker = 'o')
304 plt.xlabel('Number of clusters (k)') 308 plt.xlabel('Number of clusters (k)')
305 plt.ylabel('Distortion') 309 plt.ylabel('Distortion')
306 s = 'clustering/elbow_plot.png' 310 s = f'{args.output_path}/elbow_plot.png'
307 fig = plt.gcf() 311 fig = plt.gcf()
308 fig.set_size_inches(18.5, 10.5, forward = True) 312 fig.set_size_inches(18.5, 10.5, forward = True)
309 fig.savefig(s, dpi=100) 313 fig.savefig(s, dpi=100)
310 314
311 315
404 best_cluster (str): The file path to save the output of the best cluster. 408 best_cluster (str): The file path to save the output of the best cluster.
405 409
406 Returns: 410 Returns:
407 None 411 None
408 """ 412 """
409 if not os.path.exists('clustering'): 413 if not os.path.exists(args.output_path):
410 os.makedirs('clustering') 414 os.makedirs(args.output_path)
411 415
412 if eps is not None: 416 if eps is not None:
413 clusterer = DBSCAN(eps = eps, min_samples = min_samples) 417 clusterer = DBSCAN(eps = eps, min_samples = min_samples)
414 else: 418 else:
415 clusterer = DBSCAN() 419 clusterer = DBSCAN()
443 silhouette (str): Whether to generate silhouette plots ('true' or 'false'). 447 silhouette (str): Whether to generate silhouette plots ('true' or 'false').
444 448
445 Returns: 449 Returns:
446 None 450 None
447 """ 451 """
448 if not os.path.exists('clustering'): 452 if not os.path.exists(args.output_path):
449 os.makedirs('clustering') 453 os.makedirs(args.output_path)
450 454
451 plt.figure(figsize=(10, 7)) 455 plt.figure(figsize=(10, 7))
452 plt.title("Customer Dendograms") 456 plt.title("Customer Dendograms")
453 shc.dendrogram(shc.linkage(dataset, method='ward'), labels=dataset.index.values.tolist()) 457 shc.dendrogram(shc.linkage(dataset, method='ward'), labels=dataset.index.values.tolist())
454 fig = plt.gcf() 458 fig = plt.gcf()
455 fig.savefig('clustering/dendogram.png', dpi=200) 459 fig.savefig(f'{args.output_path}/dendogram.png', dpi=200)
456 460
457 range_n_clusters = [i for i in range(k_min, k_max+1)] 461 range_n_clusters = [i for i in range(k_min, k_max+1)]
458 462
459 scores = [] 463 scores = []
460 labels = [] 464 labels = []
464 for n_clusters in range_n_clusters: 468 for n_clusters in range_n_clusters:
465 cluster = AgglomerativeClustering(n_clusters=n_clusters, affinity='euclidean', linkage='ward') 469 cluster = AgglomerativeClustering(n_clusters=n_clusters, affinity='euclidean', linkage='ward')
466 cluster.fit_predict(dataset) 470 cluster.fit_predict(dataset)
467 cluster_labels = cluster.labels_ 471 cluster_labels = cluster.labels_
468 labels.append(cluster_labels) 472 labels.append(cluster_labels)
469 write_to_csv(dataset, cluster_labels, 'clustering/hierarchical_with_' + str(n_clusters) + '_clusters.tsv') 473 write_to_csv(dataset, cluster_labels, f'{args.output_path}/hierarchical_with_' + str(n_clusters) + '_clusters.tsv')
470 474
471 best = max_index(scores) + k_min 475 best = max_index(scores) + k_min
472 476
473 for i in range(len(labels)): 477 for i in range(len(labels)):
474 prefix = '' 478 prefix = ''
475 if (i + k_min == best): 479 if (i + k_min == best):
476 prefix = '_BEST' 480 prefix = '_BEST'
477 if silhouette == 'true': 481 if silhouette == 'true':
478 silhouette_draw(dataset, labels[i], i + k_min, 'clustering/silhouette_with_' + str(i + k_min) + prefix + '_clusters.png') 482 silhouette_draw(dataset, labels[i], i + k_min, f'{args.output_path}/silhouette_with_' + str(i + k_min) + prefix + '_clusters.png')
479 483
480 for i in range(len(labels)): 484 for i in range(len(labels)):
481 if (i + k_min == best): 485 if (i + k_min == best):
482 labels = labels[i] 486 labels = labels[i]
483 predict = [x+1 for x in labels] 487 predict = [x+1 for x in labels]
484 classe = (pd.DataFrame(list(zip(dataset.index, predict)))).astype(str) 488 classe = (pd.DataFrame(list(zip(dataset.index, predict)))).astype(str)
485 classe.to_csv(best_cluster, sep = '\t', index = False, header = ['Patient_ID', 'Class']) 489 classe.to_csv(best_cluster, sep = '\t', index = False, header = ['Patient_ID', 'Class'])
486 490
487 491
488 ############################# main ########################################### 492 ############################# main ###########################################
489 def main() -> None: 493 def main(args_in:List[str] = None) -> None:
490 """ 494 """
491 Initializes everything and sets the program in motion based on the fronted input arguments. 495 Initializes everything and sets the program in motion based on the fronted input arguments.
492 496
493 Returns: 497 Returns:
494 None 498 None
495 """ 499 """
496 if not os.path.exists('clustering'): 500 global args
497 os.makedirs('clustering') 501 args = process_args(args_in)
498 502
499 args = process_args(sys.argv) 503 if not os.path.exists(args.output_path):
504 os.makedirs(args.output_path)
500 505
501 #Data read 506 #Data read
502 507
503 X = read_dataset(args.input) 508 X = read_dataset(args.input)
504 X = pd.DataFrame.to_dict(X, orient='list') 509 X = pd.DataFrame.to_dict(X, orient='list')