comparison COBRAxy/src/marea_cluster.py @ 539:2fb97466e404 draft

Uploaded
author francesco_lapi
date Sat, 25 Oct 2025 14:55:13 +0000
parents
children fcdbc81feb45
comparison
equal deleted inserted replaced
538:fd53d42348bd 539:2fb97466e404
1 # -*- coding: utf-8 -*-
2 """
3 Created on Mon Jun 3 19:51:00 2019
4 @author: Narger
5 """
6
7 import sys
8 import argparse
9 import os
10 import numpy as np
11 import pandas as pd
12 from sklearn.datasets import make_blobs
13 from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
14 from sklearn.metrics import silhouette_samples, silhouette_score, cluster
15 import matplotlib
16 matplotlib.use('agg')
17 import matplotlib.pyplot as plt
18 import scipy.cluster.hierarchy as shc
19 import matplotlib.cm as cm
20 from typing import Optional, Dict, List
21
22 ################################# process args ###############################
def process_args(args_in: Optional[List[str]] = None) -> argparse.Namespace:
    """
    Processes command-line arguments.

    Args:
        args_in (list, optional): List of command-line arguments. When None,
            argparse falls back to sys.argv[1:].

    Returns:
        Namespace: An object containing parsed arguments.
    """
    parser = argparse.ArgumentParser(usage = '%(prog)s [options]',
                                     description = 'process some value\'s' +
                                     ' genes to create class.')

    parser.add_argument('-ol', '--out_log',
                        help = "Output log")

    parser.add_argument('-in', '--input',
                        type = str,
                        help = 'input dataset')

    parser.add_argument('-cy', '--cluster_type',
                        type = str,
                        choices = ['kmeans', 'dbscan', 'hierarchy'],
                        default = 'kmeans',
                        help = 'choose clustering algorithm')

    parser.add_argument('-sc', '--scaling',
                        type = str,
                        choices = ['true', 'false'],
                        default = 'true',
                        help = 'choose if you want to scale the data')

    parser.add_argument('-k1', '--k_min',
                        type = int,
                        default = 2,
                        help = 'choose minimum cluster number to be generated')

    parser.add_argument('-k2', '--k_max',
                        type = int,
                        default = 7,
                        help = 'choose maximum cluster number to be generated')

    parser.add_argument('-el', '--elbow',
                        type = str,
                        default = 'false',
                        choices = ['true', 'false'],
                        help = 'choose if you want to generate an elbow plot for kmeans')

    parser.add_argument('-si', '--silhouette',
                        type = str,
                        default = 'false',
                        choices = ['true', 'false'],
                        help = 'choose if you want silhouette plots')

    parser.add_argument('-td', '--tool_dir',
                        type = str,
                        required = True,
                        help = 'your tool directory')

    parser.add_argument('-ms', '--min_samples',
                        type = int,
                        help = 'min samples for dbscan (optional)')

    parser.add_argument('-ep', '--eps',
                        type = float,
                        help = 'eps for dbscan (optional)')

    parser.add_argument('-bc', '--best_cluster',
                        type = str,
                        help = 'output of best cluster tsv')

    parser.add_argument('-idop', '--output_path',
                        type = str,
                        default = 'clustering/',
                        help = 'output path for maps')

    return parser.parse_args(args_in)
103
104 ########################### warning ###########################################
def warning(s: str) -> None:
    """
    Print a warning message and, when a log file was requested, append it there.

    The log path is read from the module-level ``args`` namespace populated by
    ``main`` (``args.out_log``). ``--out_log`` is optional and has no default,
    so when it is missing the message is only printed instead of crashing on
    ``open(None, 'a')``.

    Args:
        s (str): The warning message to be logged and printed.

    Returns:
        None
    """
    log_path = args.out_log
    if log_path:
        with open(log_path, 'a') as log:
            # blank line between entries, as in the original log format
            log.write(s + "\n\n")
    print(s)
119
120 ########################## read dataset ######################################
def read_dataset(dataset: str) -> pd.DataFrame:
    """
    Load a tab-separated file, using its first row as the header.

    Args:
        dataset (str): path of the TSV file to load.

    Returns:
        pandas.DataFrame: the parsed table.

    Raises:
        SystemExit: if the file is empty (pandas raises EmptyDataError, which
            is converted here) or if the parsed table has fewer than 2 columns.
    """
    abort_msg = 'Execution aborted: wrong format of dataset\n'
    try:
        frame = pd.read_csv(dataset, sep='\t', header=0)
    except pd.errors.EmptyDataError:
        sys.exit(abort_msg)
    if frame.shape[1] < 2:
        sys.exit(abort_msg)
    return frame
142
143 ############################ rewrite_input ###################################
def rewrite_input(dataset: Dict) -> Dict[str, List[Optional[float]]]:
    """
    Convert every column of ``dataset`` to a list of floats, in place.

    Each value equal to the string 'None' becomes ``None``; every other value
    is passed through ``float``. The same dict object is mutated and returned.

    Args:
        dataset (dict): mapping of column name -> iterable of raw values
            (main builds it with ``DataFrame.to_dict(orient='list')``).

    Returns:
        dict: the same dict, whose values are now lists of float or None.
    """
    for column in list(dataset.keys()):
        dataset[column] = [None if raw == 'None' else float(raw)
                           for raw in dataset[column]]
    return dataset
169
170 ############################## write to csv ##################################
def write_to_csv(dataset: pd.DataFrame, labels: List[str], name: str) -> None:
    """
    Save a two-column TSV pairing each row index of ``dataset`` with its label.

    Labels are shifted by +1 before writing, so 0-based cluster ids become
    1-based classes. Rows and labels are paired positionally.

    Args:
        dataset (pandas.DataFrame): the clustered data; only its index is used.
        labels (list): the predicted label of each row.
        name (str): destination path of the TSV file.

    Returns:
        None
    """
    rows = [(patient, label + 1) for patient, label in zip(dataset.index, labels)]
    table = pd.DataFrame(rows).astype(str)
    table.to_csv(name, sep='\t', index=False, header=['Patient_ID', 'Class'])
191
192 ########################### trova il massimo in lista ########################
def max_index(lista: List[int]) -> int:
    """
    Find the index of the maximum value in a list.

    On ties the first occurrence wins. The scan is seeded with the first
    element rather than a fixed sentinel, so lists whose values are all
    below -1 are handled correctly.

    Args:
        lista (list): The list in which we search for the index of the maximum value.

    Returns:
        int: The index of the maximum value in the list (0 for an empty list).
    """
    if not lista:
        return 0
    best_index = 0
    for i, value in enumerate(lista):
        if value > lista[best_index]:
            best_index = i
    return best_index
211
212 ################################ kmeans #####################################
def kmeans(k_min: int, k_max: int, dataset: pd.DataFrame, elbow: str, silhouette: str, best_cluster: str) -> None:
    """
    Perform k-means clustering on the given dataset for every k in [k_min, k_max].

    For each k, it writes the labels to ``<output_path>/kmeans_with_<k>[_BEST]_clusters.tsv``.
    The k with the highest average silhouette score is tagged ``_BEST``, and its
    labels are also written to ``best_cluster``.
    When requested, it also draws silhouette plots and an elbow plot.

    Reads the module-level ``args`` namespace (``args.output_path``).

    Args:
        k_min (int): The minimum number of clusters to consider.
        k_max (int): The maximum number of clusters to consider.
        dataset (pandas.DataFrame): The dataset to perform clustering on.
        elbow (str): 'true' to generate an elbow plot, anything else to skip it.
        silhouette (str): 'true' to generate silhouette plots, anything else to skip them.
        best_cluster (str): The file path to save the output of the best cluster.

    Returns:
        None

    Raises:
        SystemExit: if the range [k_min, k_max] is empty.
    """
    os.makedirs(args.output_path, exist_ok=True)

    elbow = elbow == 'true'
    silhouette = silhouette == 'true'

    range_n_clusters = list(range(k_min, k_max + 1))
    if not range_n_clusters:
        # e.g. k_max was lowered below k_min because the dataset has too few rows
        sys.exit('Execution aborted: no cluster number to test (k_min = ' + str(k_min) +
                 ' > k_max = ' + str(k_max) + ')\n')

    # distortions[0] is the k = 1 baseline of the elbow plot; the remaining
    # entries follow range_n_clusters.
    distortions = []
    if elbow:
        distortions.append(KMeans(n_clusters=1, random_state=10).fit(dataset).inertia_)
    scores = []
    all_labels = []

    for n_clusters in range_n_clusters:
        clusterer = KMeans(n_clusters=n_clusters, random_state=10)
        cluster_labels = clusterer.fit_predict(dataset)
        all_labels.append(cluster_labels)
        # the silhouette is undefined for a single cluster; score it as 0
        if n_clusters == 1:
            scores.append(0)
        else:
            scores.append(silhouette_score(dataset, cluster_labels))
        # fit_predict already fitted the model: reuse its inertia instead of refitting
        distortions.append(clusterer.inertia_)

    best = max_index(scores) + k_min

    for i, cluster_labels in enumerate(all_labels):
        n_clusters = i + k_min
        prefix = '_BEST' if n_clusters == best else ''

        write_to_csv(dataset, cluster_labels, f'{args.output_path}/kmeans_with_' + str(n_clusters) + prefix + '_clusters.tsv')

        if prefix == '_BEST':
            write_to_csv(dataset, cluster_labels, best_cluster)

        if silhouette:
            silhouette_draw(dataset, cluster_labels, n_clusters, f'{args.output_path}/silhouette_with_' + str(n_clusters) + prefix + '_clusters.png')

    if elbow:
        elbow_plot(distortions, k_min, k_max)
289
290
291
292
293
294 ############################## elbow_plot ####################################
def elbow_plot(distortions: List[float], k_min: int, k_max: int) -> None:
    """
    Generate an elbow plot of the distortion for different numbers of clusters.

    The elbow plot helps identify an appropriate number of clusters: look for
    the point after which the distortion stops dropping sharply.

    The plot is saved to ``<args.output_path>/elbow_plot.png``, where ``args``
    is the module-level namespace.

    Args:
        distortions (list): Distortion values. The first entry is for k = 1;
            the rest are for k_min..k_max.
        k_min (int): The minimum number of clusters considered.
        k_max (int): The maximum number of clusters considered.

    Returns:
        None
    """
    # x-axis: the k = 1 baseline followed by the tested range
    x = [1] + list(range(k_min, k_max + 1))

    # A dedicated figure (instead of the shared global figure 0) so repeated
    # calls do not draw on top of each other; it is closed once saved.
    fig, ax = plt.subplots()
    ax.plot(x, distortions, marker='o')
    ax.set_xlabel('Number of clusters (k)')
    ax.set_ylabel('Distortion')
    fig.set_size_inches(18.5, 10.5, forward=True)
    fig.savefig(f'{args.output_path}/elbow_plot.png', dpi=100)
    plt.close(fig)
319
320
321 ############################## silhouette plot ###############################
def silhouette_draw(dataset: pd.DataFrame, labels: List[str], n_clusters: int, path: str) -> None:
    """
    Generate a silhouette plot for the clustering results.

    The silhouette coefficient measures how similar an object is to its own
    cluster compared with other clusters. It ranges from -1 to 1:
    - close to +1: the object is well matched to its own cluster and poorly
      matched to neighbouring clusters (dense, well-separated cluster);
    - close to 0: the object lies near the boundary between two clusters;
    - close to -1: the object may have been assigned to the wrong cluster.

    The average score is also reported through ``warning``. Nothing is drawn
    when n_clusters == 1.

    Args:
        dataset (pandas.DataFrame): The dataset used for clustering.
        labels (list or numpy.ndarray): Cluster label of each row, expected to
            be the integers 0..n_clusters-1.
        n_clusters (int): The number of clusters.
        path (str): The path to save the silhouette plot image.

    Returns:
        None
    """
    if n_clusters == 1:
        return None

    # The per-cluster mask below (labels == i) is only element-wise on an
    # ndarray, so plain lists are converted first.
    labels = np.asarray(labels)

    silhouette_avg = silhouette_score(dataset, labels)
    warning("For n_clusters = " + str(n_clusters) +
            " The average silhouette_score is: " + str(silhouette_avg))

    plt.close('all')
    # A single subplot holding the silhouette plot
    fig, ax1 = plt.subplots(1, 1)
    fig.set_size_inches(18, 7)

    # Silhouette coefficients can span the full [-1, 1] range
    ax1.set_xlim([-1, 1])
    # The (n_clusters + 1) * 10 leaves blank space between the silhouettes of
    # individual clusters, to demarcate them clearly.
    ax1.set_ylim([0, len(dataset) + (n_clusters + 1) * 10])

    # Silhouette score of every sample
    sample_silhouette_values = silhouette_samples(dataset, labels)

    y_lower = 10
    for i in range(n_clusters):
        # Sorted silhouette scores of the samples belonging to cluster i
        ith_cluster_silhouette_values = np.sort(sample_silhouette_values[labels == i])

        size_cluster_i = ith_cluster_silhouette_values.shape[0]
        y_upper = y_lower + size_cluster_i

        color = cm.nipy_spectral(float(i) / n_clusters)
        ax1.fill_betweenx(np.arange(y_lower, y_upper),
                          0, ith_cluster_silhouette_values,
                          facecolor=color, edgecolor=color, alpha=0.7)

        # Label each silhouette with its cluster number at its middle
        ax1.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))

        # 10 blank units before the next cluster
        y_lower = y_upper + 10

    ax1.set_title("The silhouette plot for the various clusters.")
    ax1.set_xlabel("The silhouette coefficient values")
    ax1.set_ylabel("Cluster label")

    # Vertical line at the average silhouette score of all the values
    ax1.axvline(x=silhouette_avg, color="red", linestyle="--")

    ax1.set_yticks([])  # Clear the y-axis labels / ticks
    # Ticks covering the whole x range set above
    ax1.set_xticks([-1, -0.8, -0.6, -0.4, -0.2, 0, 0.2, 0.4, 0.6, 0.8, 1])

    plt.suptitle(("Silhouette analysis for clustering on sample data "
                  "with n_clusters = " + str(n_clusters) + "\nAverage silhouette_score = " + str(silhouette_avg)), fontsize=12, fontweight='bold')

    fig.savefig(path, bbox_inches='tight')
    plt.close(fig)
403
404 ######################## dbscan ##############################################
def dbscan(dataset: pd.DataFrame, eps: float, min_samples: float, best_cluster: str) -> None:
    """
    Perform DBSCAN clustering on the given dataset and write the labels to ``best_cluster``.

    DBSCAN groups together closely packed points based on density. Only the
    parameters actually supplied are forwarded to sklearn; the others keep
    sklearn's defaults.

    Reads the module-level ``args`` namespace (``args.output_path``).

    Args:
        dataset (pandas.DataFrame): The dataset to be clustered.
        eps (float): Maximum distance between two samples for one to be
            considered as in the neighborhood of the other (None = default).
        min_samples (int): Number of samples in a neighborhood for a point to be
            considered a core point (None = default).
        best_cluster (str): The file path to save the output of the clustering.

    Returns:
        None
    """
    os.makedirs(args.output_path, exist_ok=True)

    # Build the keyword set from the supplied values only: passing
    # min_samples=None is rejected by sklearn, and a lone min_samples must not
    # be silently dropped.
    params = {}
    if eps is not None:
        params['eps'] = eps
    if min_samples is not None:
        params['min_samples'] = int(min_samples)

    labels = DBSCAN(**params).fit(dataset).labels_

    # DBSCAN labels noise as -1; after write_to_csv's +1 shift, noise points
    # appear as class 0.
    write_to_csv(dataset, labels, best_cluster)
440
441
442 ########################## hierachical #######################################
def hierachical_agglomerative(dataset: pd.DataFrame, k_min: int, k_max: int, best_cluster: str, silhouette: str) -> None:
    """
    Perform hierarchical agglomerative (Ward) clustering for every k in [k_min, k_max].

    Steps:
    - Save a dendrogram of the full Ward linkage to ``<output_path>/dendogram.png``.
    - For each k, write the labels to ``<output_path>/hierarchical_with_<k>_clusters.tsv``.
    - Pick as best the k with the highest average silhouette score, and write
      its labels to ``best_cluster``.
    - Optionally draw silhouette plots, tagging the best one ``_BEST``.

    Reads the module-level ``args`` namespace (``args.output_path``).

    Args:
        dataset (pandas.DataFrame): The dataset to be clustered.
        k_min (int): The minimum number of clusters to consider.
        k_max (int): The maximum number of clusters to consider.
        best_cluster (str): The file path to save the output of the best cluster.
        silhouette (str): 'true' to generate silhouette plots, anything else to skip them.

    Returns:
        None

    Raises:
        SystemExit: if the range [k_min, k_max] is empty.
    """
    os.makedirs(args.output_path, exist_ok=True)

    fig = plt.figure(figsize=(10, 7))
    plt.title("Customer Dendograms")
    shc.dendrogram(shc.linkage(dataset, method='ward'), labels=dataset.index.values.tolist())
    fig.savefig(f'{args.output_path}/dendogram.png', dpi=200)
    plt.close(fig)

    range_n_clusters = list(range(k_min, k_max + 1))
    if not range_n_clusters:
        # e.g. k_max was lowered below k_min because the dataset has too few rows
        sys.exit('Execution aborted: no cluster number to test (k_min = ' + str(k_min) +
                 ' > k_max = ' + str(k_max) + ')\n')

    scores = []
    all_labels = []

    for n_clusters in range_n_clusters:
        # Ward linkage always works on Euclidean distances; the `affinity`
        # argument (deprecated, removed in scikit-learn 1.4) is therefore omitted.
        model = AgglomerativeClustering(n_clusters=n_clusters, linkage='ward')
        cluster_labels = model.fit_predict(dataset)
        all_labels.append(cluster_labels)
        # the silhouette is undefined for a single cluster; score it as 0
        if n_clusters == 1:
            scores.append(0)
        else:
            scores.append(silhouette_score(dataset, cluster_labels))
        write_to_csv(dataset, cluster_labels, f'{args.output_path}/hierarchical_with_' + str(n_clusters) + '_clusters.tsv')

    best = max_index(scores) + k_min

    for i, cluster_labels in enumerate(all_labels):
        n_clusters = i + k_min
        prefix = '_BEST' if n_clusters == best else ''
        if silhouette == 'true':
            silhouette_draw(dataset, cluster_labels, n_clusters, f'{args.output_path}/silhouette_with_' + str(n_clusters) + prefix + '_clusters.png')
        if n_clusters == best:
            write_to_csv(dataset, cluster_labels, best_cluster)
495
496
497 ############################# main ###########################################
def main(args_in: Optional[List[str]] = None) -> None:
    """
    Initializes everything and runs the clustering selected by the command-line arguments.

    The input TSV is loaded and its first column dropped. The table is then
    transposed: each remaining input column becomes a row of X, labelled by its
    header. Features (columns of X) with any missing value are removed.
    When scaling is enabled, the remaining features are standardised, and
    constant ones (std <= 1e-8) are removed. k_max is lowered until it is
    smaller than the number of rows.

    Args:
        args_in (list, optional): command-line arguments; None means sys.argv[1:].

    Returns:
        None

    Raises:
        SystemExit: if the dataset is malformed or no feature survives filtering.
    """
    global args
    args = process_args(args_in)

    os.makedirs(args.output_path, exist_ok=True)

    # Data read: drop the first column, then turn each remaining column into a row
    X = read_dataset(args.input)
    X = X.iloc[:, 1:]
    X = X.to_dict(orient='list')
    X = rewrite_input(X)
    X = pd.DataFrame.from_dict(X, orient='index')

    # Remove every feature containing at least one missing value (None / NaN)
    X = X.dropna(axis=1, how='any')

    if args.scaling == "true":
        list_to_remove = []
        toll_std = 1e-8
        for i in X.columns:
            mean_i = X[i].mean()
            std_i = X[i].std()
            if std_i > toll_std:
                # scaling with mean 0 and std 1
                X[i] = (X[i] - mean_i) / std_i
            else:
                # constant feature: cannot be standardised and is useless for clustering
                list_to_remove.append(i)
        if list_to_remove:
            X = X.drop(columns=list_to_remove)

    if X.shape[1] == 0:
        sys.exit('Execution aborted: no features left after removing missing or constant values\n')

    if args.k_max is not None:
        numero_classi = X.shape[0]
        while args.k_max >= numero_classi:
            err = 'Skipping k = ' + str(args.k_max) + ' since it is >= number of classes of dataset'
            warning(err)
            args.k_max = args.k_max - 1

    if args.cluster_type == 'kmeans':
        kmeans(args.k_min, args.k_max, X, args.elbow, args.silhouette, args.best_cluster)
    elif args.cluster_type == 'dbscan':
        dbscan(X, args.eps, args.min_samples, args.best_cluster)
    elif args.cluster_type == 'hierarchy':
        hierachical_agglomerative(X, args.k_min, args.k_max, args.best_cluster, args.silhouette)
554
555 ##############################################################################
if __name__ == "__main__":
    # Script entry point: hand the command-line arguments (minus the program
    # name) to main, exactly what argparse would read by default.
    main(sys.argv[1:])