comparison cobraxy-9688ad27287b/COBRAxy/marea_cluster.py @ 90:a48b2e06ebe7 draft

Uploaded
author luca_milaz
date Sun, 13 Oct 2024 11:35:56 +0000
parents
children
comparison
equal deleted inserted replaced
89:6ddfc81e97d1 90:a48b2e06ebe7
1 # -*- coding: utf-8 -*-
2 """
3 Created on Mon Jun 3 19:51:00 2019
4 @author: Narger
5 """
6
7 import sys
8 import argparse
9 import os
10 import numpy as np
11 import pandas as pd
12 from sklearn.datasets import make_blobs
13 from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
14 from sklearn.metrics import silhouette_samples, silhouette_score, cluster
15 import matplotlib
16 matplotlib.use('agg')
17 import matplotlib.pyplot as plt
18 import scipy.cluster.hierarchy as shc
19 import matplotlib.cm as cm
20 from typing import Optional, Dict, List
21
22 ################################# process args ###############################
def process_args(args :List[str]) -> argparse.Namespace:
    """
    Declare the command-line interface of the clustering tool and parse it.

    Parsing is delegated to ``argparse``'s default source (``sys.argv[1:]``);
    the ``args`` parameter is kept for interface compatibility but is not read.

    Args:
        args (list): Command-line arguments (not consulted, see above).

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(
        usage='%(prog)s [options]',
        description="process some value's genes to create class.",
    )

    yes_no = ['true', 'false']

    # (flags, keyword options) pairs, registered in this exact order
    option_specs = [
        (('-ol', '--out_log'), dict(help="Output log")),
        (('-in', '--input'), dict(type=str, help='input dataset')),
        (('-cy', '--cluster_type'), dict(type=str,
                                          choices=['kmeans', 'dbscan', 'hierarchy'],
                                          default='kmeans',
                                          help='choose clustering algorythm')),
        (('-k1', '--k_min'), dict(type=int, default=2,
                                   help='choose minimun cluster number to be generated')),
        (('-k2', '--k_max'), dict(type=int, default=7,
                                   help='choose maximum cluster number to be generated')),
        (('-el', '--elbow'), dict(type=str, default='false', choices=yes_no,
                                   help='choose if you want to generate an elbow plot for kmeans')),
        (('-si', '--silhouette'), dict(type=str, default='false', choices=yes_no,
                                        help='choose if you want silhouette plots')),
        (('-td', '--tool_dir'), dict(type=str, required=True,
                                      help='your tool directory')),
        (('-ms', '--min_samples'), dict(type=float,
                                         help='min samples for dbscan (optional)')),
        (('-ep', '--eps'), dict(type=float, help='eps for dbscan (optional)')),
        (('-bc', '--best_cluster'), dict(type=str, help='output of best cluster tsv')),
    ]

    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    return parser.parse_args()
93
94 ########################### warning ###########################################
def warning(s :str) -> None:
    """
    Append a warning message to the output log, if one was requested, and
    print it to stdout.

    The log path comes from the ``--out_log`` command-line option, which is
    re-parsed from ``sys.argv`` on every call.

    Args:
        s (str): The warning message to be logged and printed.

    Returns:
        None
    """
    args = process_args(sys.argv)
    # --out_log is optional; without this guard open(None, 'a') raises TypeError
    # and the warning itself would crash the run.
    if args.out_log:
        with open(args.out_log, 'a') as log:
            log.write(s + "\n\n")
    print(s)
109
110 ########################## read dataset ######################################
def read_dataset(dataset :str) -> pd.DataFrame:
    """
    Load a tab-separated dataset (first line is the header) into a DataFrame.

    Args:
        dataset (str): Path of the TSV file to load.

    Returns:
        pandas.DataFrame: The loaded table.

    Raises:
        SystemExit: If the file is empty (pandas' EmptyDataError is turned into
            an exit) or has fewer than 2 columns.
    """
    abort_msg = 'Execution aborted: wrong format of dataset\n'

    try:
        frame = pd.read_csv(dataset, sep='\t', header=0)
    except pd.errors.EmptyDataError:
        sys.exit(abort_msg)

    # at least an identifier column plus one value column is required
    if frame.shape[1] < 2:
        sys.exit(abort_msg)

    return frame
132
133 ############################ rewrite_input ###################################
def rewrite_input(dataset :Dict[str, list]) -> Dict[str, List[Optional[float]]]:
    """
    Normalise a column-oriented dict of lists (as produced by
    ``DataFrame.to_dict(orient='list')``) in place.

    The 'Reactions' entry, if present, is removed. Every remaining value becomes
    a float, except the literal string 'None', which becomes None.

    Note: the argument must be a dict, not a DataFrame (``DataFrame.pop`` does
    not accept a default value).

    Args:
        dataset (dict): Mapping column name -> list of cell values (strings or numbers).

    Returns:
        dict: The same dict object, with the converted lists.

    Raises:
        ValueError: If a value cannot be converted with ``float``.
    """
    # 'Reactions' holds the row identifiers, not values to be clustered
    dataset.pop('Reactions', None)

    # Re-assigning values of existing keys (no insertion/removal) is safe while iterating
    for key, column in dataset.items():
        dataset[key] = [None if value == 'None' else float(value) for value in column]

    return dataset
160
161 ############################## write to csv ##################################
def write_to_csv(dataset :pd.DataFrame, labels :List[str], name :str) -> None:
    """
    Export the class assigned to each sample as a two-column TSV file.

    Each row pairs an entry of ``dataset.index`` with its label shifted by one,
    so that integer cluster ids 0..k-1 are written as classes 1..k. Every cell
    is written as a string.

    Args:
        dataset (pandas.DataFrame): The clustered samples (only the index is used).
        labels (list): The cluster label of each sample, in index order.
        name (str): Destination path of the TSV file.

    Returns:
        None
    """
    shifted_classes = [label + 1 for label in labels]
    rows = list(zip(dataset.index, shifted_classes))

    table = pd.DataFrame(rows).astype(str)
    table.to_csv(name, sep='\t', index=False, header=['Patient_ID', 'Class'])
182
########################### find the index of the maximum in a list ###########
def max_index(lista :List[int]) -> int:
    """
    Return the index of the first occurrence of the largest value in a list.

    Values are compared with ``>``, so NaN entries are never selected. A list
    that is empty or all-NaN yields 0.

    Args:
        lista (list): The values to scan (e.g. silhouette scores).

    Returns:
        int: Index of the maximum value; 0 for an empty list.
    """
    # Start from -inf rather than -1 so values below -1 can be selected too
    best = float('-inf')
    best_index = 0
    for i, value in enumerate(lista):
        if value > best:
            best = value
            best_index = i

    return best_index
202
203 ################################ kmeans #####################################
def kmeans(k_min: int, k_max: int, dataset: pd.DataFrame, elbow: str, silhouette: str, best_cluster: str) -> None:
    """
    Run k-means for every k in [k_min, k_max] and export the results.

    For each k, the labels are written to
    'clustering/kmeans_with_<k>[_BEST]_clusters.tsv'. The k with the highest
    average silhouette score is tagged '_BEST' and its labels are also written
    to ``best_cluster``. Optionally, silhouette plots are drawn for every k, and
    an elbow plot of the inertia (with a k = 1 baseline) is drawn.

    Args:
        k_min (int): The minimum number of clusters to consider.
        k_max (int): The maximum number of clusters to consider.
        dataset (pandas.DataFrame): The samples to cluster (one row per sample).
        elbow (str): 'true' to generate the elbow plot; any other value disables it.
        silhouette (str): 'true' to generate silhouette plots; any other value disables them.
        best_cluster (str): Path of the TSV that receives the best clustering.

    Returns:
        None
    """
    os.makedirs('clustering', exist_ok=True)

    elbow = elbow == 'true'
    silhouette = silhouette == 'true'

    range_n_clusters = list(range(k_min, k_max + 1))
    # The first distortion is the k = 1 baseline shown in the elbow plot
    distortions = [KMeans(n_clusters=1, random_state=10).fit(dataset).inertia_]
    scores = []
    all_labels = []

    for n_clusters in range_n_clusters:
        clusterer = KMeans(n_clusters=n_clusters, random_state=10)
        cluster_labels = clusterer.fit_predict(dataset)
        all_labels.append(cluster_labels)

        # silhouette_score is undefined for a single cluster
        if n_clusters == 1:
            scores.append(0)
        else:
            scores.append(silhouette_score(dataset, cluster_labels))

        # fit_predict has already fitted the model: reuse its inertia instead of fitting again
        distortions.append(clusterer.inertia_)

    best = max_index(scores) + k_min

    for offset, labels in enumerate(all_labels):
        n_clusters = offset + k_min
        prefix = '_BEST' if n_clusters == best else ''

        write_to_csv(dataset, labels, 'clustering/kmeans_with_' + str(n_clusters) + prefix + '_clusters.tsv')

        if prefix:
            write_to_csv(dataset, labels, best_cluster)

        if silhouette:
            silhouette_draw(dataset, labels, n_clusters, 'clustering/silhouette_with_' + str(n_clusters) + prefix + '_clusters.png')

    if elbow:
        elbow_plot(distortions, k_min, k_max)
270
271
272
273
274 if silhouette:
275 silhouette_draw(dataset, all_labels[i], i + k_min, 'clustering/silhouette_with_' + str(i + k_min) + prefix + '_clusters.png')
276
277
278 if elbow:
279 elbow_plot(distortions, k_min,k_max)
280
281
282
283
284
285 ############################## elbow_plot ####################################
def elbow_plot(distortions: List[float], k_min: int, k_max: int) -> None:
    """
    Save an elbow plot of the clustering distortion to 'clustering/elbow_plot.png'.

    The elbow plot helps pick a number of clusters. Look for the k at which
    adding more clusters stops reducing the distortion substantially.

    Args:
        distortions (list): Distortion values for k = 1 followed by k = k_min .. k_max,
            i.e. ``k_max - k_min + 2`` values.
        k_min (int): The minimum number of clusters considered.
        k_max (int): The maximum number of clusters considered.

    Returns:
        None
    """
    plt.figure(0)
    # x positions: the k = 1 baseline followed by the tested range
    ks = [1] + list(range(k_min, k_max + 1))
    plt.plot(ks, distortions, marker='o')
    plt.xlabel('Number of clusters (k)')
    plt.ylabel('Distortion')

    fig = plt.gcf()
    fig.set_size_inches(18.5, 10.5, forward=True)
    fig.savefig('clustering/elbow_plot.png', dpi=100)
    # Release figure 0 so a later call does not draw on top of this curve
    plt.close(fig)
310
311
312 ############################## silhouette plot ###############################
def silhouette_draw(dataset: pd.DataFrame, labels: List[str], n_clusters: int, path:str) -> None:
    """
    Save a silhouette plot for one clustering result.

    The silhouette coefficient measures how similar a sample is to its own
    cluster compared with other clusters. It ranges from -1 to 1:
    - close to +1: the sample is well matched to its own cluster and poorly
      matched to neighbouring clusters;
    - close to 0: the sample lies near the boundary between two clusters;
    - close to -1: the sample may have been assigned to the wrong cluster.

    The average score is also reported through ``warning``.

    Args:
        dataset (pandas.DataFrame): The samples that were clustered.
        labels (list): Integer cluster label (0 .. n_clusters-1) of every sample;
            plain lists and numpy arrays are both accepted.
        n_clusters (int): The number of clusters; nothing is drawn when it is 1.
        path (str): Destination of the PNG image.

    Returns:
        None
    """
    # A silhouette is undefined for a single cluster
    if n_clusters == 1:
        return None

    # Ensure `labels == i` below is an element-wise mask (on a list it is a single False)
    labels = np.asarray(labels)

    silhouette_avg = silhouette_score(dataset, labels)
    warning("For n_clusters = " + str(n_clusters) +
            " The average silhouette_score is: " + str(silhouette_avg))

    plt.close('all')
    # A single plot (1 row, 1 column)
    fig, ax1 = plt.subplots(1, 1)
    fig.set_size_inches(18, 7)

    # Silhouette coefficients range over [-1, 1]
    ax1.set_xlim([-1, 1])
    # (n_clusters + 1) * 10 leaves blank space between the per-cluster silhouettes
    ax1.set_ylim([0, len(dataset) + (n_clusters + 1) * 10])

    sample_silhouette_values = silhouette_samples(dataset, labels)

    y_lower = 10
    for i in range(n_clusters):
        # Scores of the samples in cluster i, sorted to draw a smooth silhouette
        ith_cluster_silhouette_values = np.sort(sample_silhouette_values[labels == i])

        size_cluster_i = ith_cluster_silhouette_values.shape[0]
        y_upper = y_lower + size_cluster_i

        color = cm.nipy_spectral(float(i) / n_clusters)
        ax1.fill_betweenx(np.arange(y_lower, y_upper),
                          0, ith_cluster_silhouette_values,
                          facecolor=color, edgecolor=color, alpha=0.7)

        # Number clusters from 1 to match the 'Class' column written by write_to_csv (label + 1)
        ax1.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i + 1))

        # 10 blank rows before the next cluster
        y_lower = y_upper + 10

    ax1.set_title("The silhouette plot for the various clusters.")
    ax1.set_xlabel("The silhouette coefficient values")
    ax1.set_ylabel("Cluster label")

    # Vertical line at the average silhouette score of all samples
    ax1.axvline(x=silhouette_avg, color="red", linestyle="--")

    ax1.set_yticks([])  # Clear the y-axis labels / ticks
    # Ticks span the whole [-1, 1] axis so negative scores can be read too
    ax1.set_xticks([-1, -0.8, -0.6, -0.4, -0.2, 0, 0.2, 0.4, 0.6, 0.8, 1])

    plt.suptitle(("Silhouette analysis for clustering on sample data "
                  "with n_clusters = " + str(n_clusters) + "\nAverage silhouette_score = " + str(silhouette_avg)), fontsize=12, fontweight='bold')

    plt.savefig(path, bbox_inches='tight')
    plt.close(fig)
394
395 ######################## dbscan ##############################################
def dbscan(dataset: pd.DataFrame, eps: float, min_samples: float, best_cluster: str) -> None:
    """
    Cluster the dataset with DBSCAN and write the resulting classes to ``best_cluster``.

    DBSCAN groups together densely packed points. Points that belong to no
    dense region get label -1, which is written as class 0 because of the +1
    shift applied by ``write_to_csv``.

    Args:
        dataset (pandas.DataFrame): The samples to cluster (one row per sample).
        eps (float): Neighbourhood radius, or None to keep scikit-learn's default.
        min_samples (float): Minimum neighbourhood size for a core point, or None
            to keep scikit-learn's default. It is truncated to an int because
            DBSCAN only accepts integer counts.
        best_cluster (str): Path of the TSV that receives the clustering.

    Returns:
        None
    """
    os.makedirs('clustering', exist_ok=True)

    # Forward only the parameters that were supplied, so DBSCAN keeps its own
    # defaults for the others (never forward min_samples=None).
    params = {}
    if eps is not None:
        params['eps'] = eps
    if min_samples is not None:
        # The CLI parses min_samples as a float; DBSCAN requires an integer
        params['min_samples'] = int(min_samples)

    labels = DBSCAN(**params).fit(dataset).labels_

    write_to_csv(dataset, labels, best_cluster)
431
432
433 ########################## hierachical #######################################
def hierachical_agglomerative(dataset: pd.DataFrame, k_min: int, k_max: int, best_cluster: str, silhouette: str) -> None:
    """
    Run Ward agglomerative clustering for every k in [k_min, k_max] and export the results.

    The function does the following:
    - saves a dendrogram to 'clustering/dendogram.png';
    - writes the labels of every k to 'clustering/hierarchical_with_<k>_clusters.tsv';
    - picks the k with the highest average silhouette score and writes its
      labels to ``best_cluster``;
    - optionally draws one silhouette plot per k, tagging the best one '_BEST'.

    Args:
        dataset (pandas.DataFrame): The samples to cluster (one row per sample).
        k_min (int): The minimum number of clusters to consider.
        k_max (int): The maximum number of clusters to consider.
        best_cluster (str): Path of the TSV that receives the best clustering.
        silhouette (str): 'true' to generate silhouette plots; any other value disables them.

    Returns:
        None
    """
    os.makedirs('clustering', exist_ok=True)

    plt.figure(figsize=(10, 7))
    plt.title("Customer Dendograms")
    shc.dendrogram(shc.linkage(dataset, method='ward'), labels=dataset.index.values.tolist())
    fig = plt.gcf()
    fig.savefig('clustering/dendogram.png', dpi=200)
    plt.close(fig)

    range_n_clusters = list(range(k_min, k_max + 1))
    scores = []
    all_labels = []

    for n_clusters in range_n_clusters:
        # Ward linkage uses Euclidean distances, scikit-learn's default; the
        # `affinity` keyword was removed in scikit-learn 1.4, so it is not passed.
        cluster_labels = AgglomerativeClustering(n_clusters=n_clusters, linkage='ward').fit_predict(dataset)
        all_labels.append(cluster_labels)

        # Score every partition so the best one can actually be selected
        if n_clusters == 1:
            scores.append(0)
        else:
            scores.append(silhouette_score(dataset, cluster_labels))

        write_to_csv(dataset, cluster_labels, 'clustering/hierarchical_with_' + str(n_clusters) + '_clusters.tsv')

    best = max_index(scores) + k_min

    for offset, cluster_labels in enumerate(all_labels):
        n_clusters = offset + k_min
        prefix = '_BEST' if n_clusters == best else ''

        if silhouette == 'true':
            silhouette_draw(dataset, cluster_labels, n_clusters, 'clustering/silhouette_with_' + str(n_clusters) + prefix + '_clusters.png')

        if prefix:
            write_to_csv(dataset, cluster_labels, best_cluster)
486
487
488 ############################# main ###########################################
def main() -> None:
    """
    Entry point: parse the command line, load and clean the dataset, and run
    the selected clustering algorithm.

    Returns:
        None
    """
    os.makedirs('clustering', exist_ok=True)

    args = process_args(sys.argv)

    # Load the TSV, convert 'None' cells to missing values and transpose it, so
    # that each former column becomes a row (sample) to be clustered.
    X = read_dataset(args.input)
    X = pd.DataFrame.to_dict(X, orient='list')
    X = rewrite_input(X)
    X = pd.DataFrame.from_dict(X, orient='index')

    # Missing cells end up as NaN (or None in all-missing columns). The
    # clustering algorithms reject missing values, so drop every column
    # containing one.
    X = X.dropna(axis=1, how='any')

    # Each k must stay below the number of samples (required by silhouette_score)
    if args.k_max is not None:
        numero_classi = X.shape[0]
        while args.k_max >= numero_classi:
            err = 'Skipping k = ' + str(args.k_max) + ' since it is >= number of classes of dataset'
            warning(err)
            args.k_max = args.k_max - 1

    if args.cluster_type == 'kmeans':
        kmeans(args.k_min, args.k_max, X, args.elbow, args.silhouette, args.best_cluster)
    elif args.cluster_type == 'dbscan':
        dbscan(X, args.eps, args.min_samples, args.best_cluster)
    elif args.cluster_type == 'hierarchy':
        hierachical_agglomerative(X, args.k_min, args.k_max, args.best_cluster, args.silhouette)
531
532 ##############################################################################
# Only launch the clustering pipeline when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()