diff COBRAxy/marea_cluster.py @ 93:7e703e546998 draft

Uploaded
author luca_milaz
date Sun, 13 Oct 2024 11:41:34 +0000 (3 months ago)
parents 41f35c2f0c7b
children 3fca9b568faf
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/COBRAxy/marea_cluster.py	Sun Oct 13 11:41:34 2024 +0000
@@ -0,0 +1,534 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Mon Jun 3 19:51:00 2019
+@author: Narger
+"""
+
+import sys
+import argparse
+import os
+import numpy as np
+import pandas as pd
+from sklearn.datasets import make_blobs
+from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
+from sklearn.metrics import silhouette_samples, silhouette_score, cluster
+import matplotlib
+matplotlib.use('agg')
+import matplotlib.pyplot as plt
+import scipy.cluster.hierarchy as shc   
+import matplotlib.cm as cm
+from typing import Optional, Dict, List
+
+################################# process args ###############################
+def process_args(args :List[str]) -> argparse.Namespace:
+    """
+    Processes command-line arguments.
+
+    Args:
+        args (list): List of command-line arguments.
+
+    Returns:
+        Namespace: An object containing parsed arguments.
+    """
+    parser = argparse.ArgumentParser(usage = '%(prog)s [options]',
+                                     description = 'process some value\'s' +
+                                     ' genes to create class.')
+
+    parser.add_argument('-ol', '--out_log', 
+                        help = "Output log")
+    
+    parser.add_argument('-in', '--input',
+                        type = str,
+                        help = 'input dataset')
+    
+    parser.add_argument('-cy', '--cluster_type',
+                        type = str,
+                        choices = ['kmeans', 'dbscan', 'hierarchy'],
+                        default = 'kmeans',
+                        help = 'choose clustering algorythm')
+    
+    parser.add_argument('-k1', '--k_min', 
+                        type = int,
+                        default = 2,
+                        help = 'choose minimun cluster number to be generated')
+    
+    parser.add_argument('-k2', '--k_max', 
+                        type = int,
+                        default = 7,
+                        help = 'choose maximum cluster number to be generated')
+    
+    parser.add_argument('-el', '--elbow', 
+                        type = str,
+                        default = 'false',
+                        choices = ['true', 'false'],
+                        help = 'choose if you want to generate an elbow plot for kmeans')
+    
+    parser.add_argument('-si', '--silhouette', 
+                        type = str,
+                        default = 'false',
+                        choices = ['true', 'false'],
+                        help = 'choose if you want silhouette plots')
+    
+    parser.add_argument('-td', '--tool_dir',
+                        type = str,
+                        required = True,
+                        help = 'your tool directory')
+                        
+    parser.add_argument('-ms', '--min_samples',
+                        type = float,
+                        help = 'min samples for dbscan (optional)')
+                        
+    parser.add_argument('-ep', '--eps',
+                        type = float,
+                        help = 'eps for dbscan (optional)')
+                        
+    parser.add_argument('-bc', '--best_cluster',
+                        type = str,
+                        help = 'output of best cluster tsv')
+    				
+    
+    
+    args = parser.parse_args()
+    return args
+
+########################### warning ###########################################
+def warning(s :str) -> None:
+    """
+    Log a warning message to an output log file and print it to the console.
+
+    Args:
+        s (str): The warning message to be logged and printed.
+    
+    Returns:
+      None
+    """
+    args = process_args(sys.argv)
+    with open(args.out_log, 'a') as log:
+        log.write(s + "\n\n")
+    print(s)
+
+########################## read dataset ######################################
+def read_dataset(dataset :str) -> pd.DataFrame:
+    """
+    Read dataset from a CSV file and return it as a Pandas DataFrame.
+
+    Args:
+        dataset (str): the path to the dataset to convert into a DataFrame
+
+    Returns:
+        pandas.DataFrame: The dataset loaded as a Pandas DataFrame.
+
+    Raises:
+        pandas.errors.EmptyDataError: If the dataset file is empty.
+        sys.exit: If the dataset file has the wrong format (e.g., fewer than 2 columns)
+    """
+    try:
+        dataset = pd.read_csv(dataset, sep = '\t', header = 0)
+    except pd.errors.EmptyDataError:
+        sys.exit('Execution aborted: wrong format of dataset\n')
+    if len(dataset.columns) < 2:
+        sys.exit('Execution aborted: wrong format of dataset\n')
+    return dataset
+
+############################ rewrite_input ###################################
+def rewrite_input(dataset :pd.DataFrame) -> Dict[str, List[Optional[float]]]:
+    """
+    Rewrite the dataset as a dictionary of lists instead of as a dictionary of dictionaries.
+
+    Args:
+        dataset (pandas.DataFrame): The dataset to be rewritten.
+
+    Returns:
+        dict: The rewritten dataset as a dictionary of lists.
+    """
+    #Riscrivo il dataset come dizionario di liste, 
+    #non come dizionario di dizionari
+    
+    dataset.pop('Reactions', None)
+    
+    for key, val in dataset.items():
+        l = []
+        for i in val:
+            if i == 'None':
+                l.append(None)
+            else:
+                l.append(float(i))
+   
+        dataset[key] = l
+    
+    return dataset
+
+############################## write to csv ##################################
+def write_to_csv (dataset :pd.DataFrame, labels :List[str], name :str) -> None:
+    """
+    Write dataset and predicted labels to a CSV file.
+
+    Args:
+        dataset (pandas.DataFrame): The dataset to be written.
+        labels (list): The predicted labels for each data point.
+        name (str): The name of the output CSV file.
+
+    Returns:
+        None
+    """
+    #labels = predict
+    predict = [x+1 for x in labels]
+  
+    classe = (pd.DataFrame(list(zip(dataset.index, predict)))).astype(str)
+
+    dest = name
+    classe.to_csv(dest, sep = '\t', index = False,
+                      header = ['Patient_ID', 'Class'])
+   
+########################### trova il massimo in lista ########################
+def max_index (lista :List[int]) -> int:
+    """
+    Find the index of the maximum value in a list.
+
+    Args:
+        lista (list): The list in which we search for the index of the maximum value.
+
+    Returns:
+        int: The index of the maximum value in the list.
+    """
+    best = -1
+    best_index = 0
+    for i in range(len(lista)):
+        if lista[i] > best:
+            best = lista [i]
+            best_index = i
+            
+    return best_index
+    
+################################ kmeans #####################################
+def kmeans (k_min: int, k_max: int, dataset: pd.DataFrame, elbow: str, silhouette: str, best_cluster: str) -> None:
+    """
+    Perform k-means clustering on the given dataset, which is an algorithm used to partition a dataset into groups (clusters) based on their characteristics.
+    The goal is to divide the data into homogeneous groups, where the elements within each group are similar to each other and different from the elements in other groups.
+
+    Args:
+        k_min (int): The minimum number of clusters to consider.
+        k_max (int): The maximum number of clusters to consider.
+        dataset (pandas.DataFrame): The dataset to perform clustering on.
+        elbow (str): Whether to generate an elbow plot for kmeans ('true' or 'false').
+        silhouette (str): Whether to generate silhouette plots ('true' or 'false').
+        best_cluster (str): The file path to save the output of the best cluster.
+
+    Returns:
+        None
+    """
+    if not os.path.exists('clustering'):
+        os.makedirs('clustering')
+    
+        
+    if elbow == 'true':
+        elbow = True
+    else:
+        elbow = False
+        
+    if silhouette == 'true':
+        silhouette = True
+    else:
+        silhouette = False
+        
+    range_n_clusters = [i for i in range(k_min, k_max+1)]
+    distortions = []
+    scores = []
+    all_labels = []
+    
+    clusterer = KMeans(n_clusters=1, random_state=10)
+    distortions.append(clusterer.fit(dataset).inertia_)
+    
+    
+    for n_clusters in range_n_clusters:
+        clusterer = KMeans(n_clusters=n_clusters, random_state=10)
+        cluster_labels = clusterer.fit_predict(dataset)
+        
+        all_labels.append(cluster_labels)
+        if n_clusters == 1:
+            silhouette_avg = 0
+        else:
+            silhouette_avg = silhouette_score(dataset, cluster_labels)
+        scores.append(silhouette_avg)
+        distortions.append(clusterer.fit(dataset).inertia_)
+        
+    best = max_index(scores) + k_min
+        
+    for i in range(len(all_labels)):
+        prefix = ''
+        if (i + k_min == best):
+            prefix = '_BEST'
+            
+        write_to_csv(dataset, all_labels[i], 'clustering/kmeans_with_' + str(i + k_min) + prefix + '_clusters.tsv')
+        
+        
+        if (prefix == '_BEST'):
+            labels = all_labels[i]
+            predict = [x+1 for x in labels]
+            classe = (pd.DataFrame(list(zip(dataset.index, predict)))).astype(str)
+            classe.to_csv(best_cluster, sep = '\t', index = False, header = ['Patient_ID', 'Class'])
+            
+          
+        
+       
+        if silhouette:
+            silhouette_draw(dataset, all_labels[i], i + k_min, 'clustering/silhouette_with_' + str(i + k_min) + prefix + '_clusters.png')
+        
+        
+    if elbow:
+        elbow_plot(distortions, k_min,k_max) 
+
+   
+    
+    
+
+############################## elbow_plot ####################################
+def elbow_plot (distortions: List[float], k_min: int, k_max: int) -> None:
+    """
+    Generate an elbow plot to visualize the distortion for different numbers of clusters.
+    The elbow plot is a graphical tool used in clustering analysis to help identifying the appropriate number of clusters by looking for the point where the rate of decrease
+    in distortion sharply decreases, indicating the optimal balance between model complexity and clustering quality.
+
+    Args:
+        distortions (list): List of distortion values for different numbers of clusters.
+        k_min (int): The minimum number of clusters considered.
+        k_max (int): The maximum number of clusters considered.
+
+    Returns:
+        None
+    """
+    plt.figure(0)
+    x = list(range(k_min, k_max + 1))
+    x.insert(0, 1)
+    plt.plot(x, distortions, marker = 'o')
+    plt.xlabel('Number of clusters (k)')
+    plt.ylabel('Distortion')
+    s = 'clustering/elbow_plot.png'
+    fig = plt.gcf()
+    fig.set_size_inches(18.5, 10.5, forward = True)
+    fig.savefig(s, dpi=100)
+    
+    
+############################## silhouette plot ###############################
+def silhouette_draw(dataset: pd.DataFrame, labels: List[str], n_clusters: int, path:str) -> None:
+    """
+    Generate a silhouette plot for the clustering results.
+    The silhouette coefficient is a measure used to evaluate the quality of clusters obtained from a clustering algorithmand it quantifies how similar an object is to its own cluster compared to other clusters.
+    The silhouette coefficient ranges from -1 to 1, where:
+    - A value close to +1 indicates that the object is well matched to its own cluster and poorly matched to neighboring clusters. This implies that the object is in a dense, well-separated cluster.
+    - A value close to 0 indicates that the object is close to the decision boundary between two neighboring clusters.
+    - A value close to -1 indicates that the object may have been assigned to the wrong cluster.
+
+    Args:
+        dataset (pandas.DataFrame): The dataset used for clustering.
+        labels (list): The cluster labels assigned to each data point.
+        n_clusters (int): The number of clusters.
+        path (str): The path to save the silhouette plot image.
+
+    Returns:
+        None
+    """
+    if n_clusters == 1:
+        return None
+        
+    silhouette_avg = silhouette_score(dataset, labels)
+    warning("For n_clusters = " + str(n_clusters) +
+          " The average silhouette_score is: " + str(silhouette_avg))
+           
+    plt.close('all')
+    # Create a subplot with 1 row and 2 columns
+    fig, (ax1) = plt.subplots(1, 1)
+    
+    fig.set_size_inches(18, 7)
+        
+    # The 1st subplot is the silhouette plot
+    # The silhouette coefficient can range from -1, 1 but in this example all
+    # lie within [-0.1, 1]
+    ax1.set_xlim([-1, 1])
+    # The (n_clusters+1)*10 is for inserting blank space between silhouette
+    # plots of individual clusters, to demarcate them clearly.
+    ax1.set_ylim([0, len(dataset) + (n_clusters + 1) * 10])
+    
+    # Compute the silhouette scores for each sample
+    sample_silhouette_values = silhouette_samples(dataset, labels)
+        
+    y_lower = 10
+    for i in range(n_clusters):
+        # Aggregate the silhouette scores for samples belonging to
+        # cluster i, and sort them
+        ith_cluster_silhouette_values = \
+        sample_silhouette_values[labels == i]
+        
+        ith_cluster_silhouette_values.sort()
+    
+        size_cluster_i = ith_cluster_silhouette_values.shape[0]
+        y_upper = y_lower + size_cluster_i
+    
+        color = cm.nipy_spectral(float(i) / n_clusters)
+        ax1.fill_betweenx(np.arange(y_lower, y_upper),
+                          0, ith_cluster_silhouette_values,
+                                     facecolor=color, edgecolor=color, alpha=0.7)
+        
+        # Label the silhouette plots with their cluster numbers at the middle
+        ax1.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
+        
+        # Compute the new y_lower for next plot
+        y_lower = y_upper + 10  # 10 for the 0 samples
+    
+        ax1.set_title("The silhouette plot for the various clusters.")
+        ax1.set_xlabel("The silhouette coefficient values")
+        ax1.set_ylabel("Cluster label")
+        
+        # The vertical line for average silhouette score of all the values
+        ax1.axvline(x=silhouette_avg, color="red", linestyle="--")
+    
+        ax1.set_yticks([])  # Clear the yaxis labels / ticks
+        ax1.set_xticks([-0.1, 0, 0.2, 0.4, 0.6, 0.8, 1])
+        
+        
+        plt.suptitle(("Silhouette analysis for clustering on sample data "
+                          "with n_clusters = " + str(n_clusters) + "\nAverage silhouette_score = " + str(silhouette_avg)), fontsize=12, fontweight='bold')
+            
+            
+        plt.savefig(path, bbox_inches='tight')
+            
+######################## dbscan ##############################################
+def dbscan(dataset: pd.DataFrame, eps: float, min_samples: float, best_cluster: str) -> None:
+    """
+    Perform DBSCAN clustering on the given dataset, which is a clustering algorithm that groups together closely packed points based on the notion of density.
+
+    Args:
+        dataset (pandas.DataFrame): The dataset to be clustered.
+        eps (float): The maximum distance between two samples for one to be considered as in the neighborhood of the other.
+        min_samples (float): The number of samples in a neighborhood for a point to be considered as a core point.
+        best_cluster (str): The file path to save the output of the best cluster.
+
+    Returns:
+        None
+    """
+    if not os.path.exists('clustering'):
+        os.makedirs('clustering')
+        
+    if eps is not None:
+        clusterer = DBSCAN(eps = eps, min_samples = min_samples)
+    else:
+        clusterer = DBSCAN()
+    
+    clustering = clusterer.fit(dataset)
+    
+    core_samples_mask = np.zeros_like(clustering.labels_, dtype=bool)
+    core_samples_mask[clustering.core_sample_indices_] = True
+    labels = clustering.labels_
+
+    # Number of clusters in labels, ignoring noise if present.
+    n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
+    
+    
+    labels = labels
+    predict = [x+1 for x in labels]
+    classe = (pd.DataFrame(list(zip(dataset.index, predict)))).astype(str)
+    classe.to_csv(best_cluster, sep = '\t', index = False, header = ['Patient_ID', 'Class'])
+  
+    
+########################## hierachical #######################################
+def hierachical_agglomerative(dataset: pd.DataFrame, k_min: int, k_max: int, best_cluster: str, silhouette: str) -> None:
+    """
+    Perform hierarchical agglomerative clustering on the given dataset.
+
+    Args:
+        dataset (pandas.DataFrame): The dataset to be clustered.
+        k_min (int): The minimum number of clusters to consider.
+        k_max (int): The maximum number of clusters to consider.
+        best_cluster (str): The file path to save the output of the best cluster.
+        silhouette (str): Whether to generate silhouette plots ('true' or 'false').
+
+    Returns:
+        None
+    """
+    if not os.path.exists('clustering'):
+        os.makedirs('clustering')
+    
+    plt.figure(figsize=(10, 7))  
+    plt.title("Customer Dendograms")  
+    shc.dendrogram(shc.linkage(dataset, method='ward'), labels=dataset.index.values.tolist())  
+    fig = plt.gcf()
+    fig.savefig('clustering/dendogram.png', dpi=200)
+    
+    range_n_clusters = [i for i in range(k_min, k_max+1)]
+
+    scores = []
+    labels = []
+    
+    n_classi = dataset.shape[0]
+    
+    for n_clusters in range_n_clusters:  
+        cluster = AgglomerativeClustering(n_clusters=n_clusters, affinity='euclidean', linkage='ward')  
+        cluster.fit_predict(dataset)  
+        cluster_labels = cluster.labels_
+        labels.append(cluster_labels)
+        write_to_csv(dataset, cluster_labels, 'clustering/hierarchical_with_' + str(n_clusters) + '_clusters.tsv')
+        
+    best = max_index(scores) + k_min
+    
+    for i in range(len(labels)):
+        prefix = ''
+        if (i + k_min == best):
+            prefix = '_BEST'
+        if silhouette == 'true':
+            silhouette_draw(dataset, labels[i], i + k_min, 'clustering/silhouette_with_' + str(i + k_min) + prefix + '_clusters.png')
+     
+    for i in range(len(labels)):
+        if (i + k_min == best):
+            labels = labels[i]
+            predict = [x+1 for x in labels]
+            classe = (pd.DataFrame(list(zip(dataset.index, predict)))).astype(str)
+            classe.to_csv(best_cluster, sep = '\t', index = False, header = ['Patient_ID', 'Class'])
+            
+    
+############################# main ###########################################
+def main() -> None:
+    """
+    Initializes everything and sets the program in motion based on the fronted input arguments.
+
+    Returns:
+        None
+    """
+    if not os.path.exists('clustering'):
+        os.makedirs('clustering')
+
+    args = process_args(sys.argv)
+    
+    #Data read
+    
+    X = read_dataset(args.input)
+    X = pd.DataFrame.to_dict(X, orient='list')
+    X = rewrite_input(X)
+    X = pd.DataFrame.from_dict(X, orient = 'index')
+    
+    for i in X.columns:
+        tmp = X[i][0]
+        if tmp == None:
+            X = X.drop(columns=[i])
+
+    ## NAN TO HANLDE
+            
+    if args.k_max != None:
+       numero_classi = X.shape[0]
+       while args.k_max >= numero_classi:
+          err = 'Skipping k = ' + str(args.k_max) + ' since it is >= number of classes of dataset'
+          warning(err)
+          args.k_max = args.k_max - 1
+    
+    
+    if args.cluster_type == 'kmeans':
+        kmeans(args.k_min, args.k_max, X, args.elbow, args.silhouette, args.best_cluster)
+    
+    if args.cluster_type == 'dbscan':
+        dbscan(X, args.eps, args.min_samples, args.best_cluster)
+        
+    if args.cluster_type == 'hierarchy':
+        hierachical_agglomerative(X, args.k_min, args.k_max, args.best_cluster, args.silhouette)
+        
+##############################################################################
+if __name__ == "__main__":
+    main()