diff COBRAxy/marea_cluster.py @ 147:3fca9b568faf draft

Uploaded
author bimib
date Wed, 06 Nov 2024 13:57:24 +0000
parents 41f35c2f0c7b
children 7f3552eaf774
line wrap: on
line diff
--- a/COBRAxy/marea_cluster.py	Wed Nov 06 10:12:52 2024 +0000
+++ b/COBRAxy/marea_cluster.py	Wed Nov 06 13:57:24 2024 +0000
@@ -20,7 +20,7 @@
 from typing import Optional, Dict, List
 
 ################################# process args ###############################
-def process_args(args :List[str]) -> argparse.Namespace:
+def process_args(args :List[str] = None) -> argparse.Namespace:
     """
     Processes command-line arguments.
 
@@ -86,9 +86,13 @@
                         type = str,
                         help = 'output of best cluster tsv')
     				
-    
+    parser.add_argument(
+        '-idop', '--output_path', 
+        type = str,
+        default='result',
+        help = 'output path for maps')
     
-    args = parser.parse_args()
+    args = parser.parse_args(args)
     return args
 
 ########################### warning ###########################################
@@ -217,8 +221,8 @@
     Returns:
         None
     """
-    if not os.path.exists('clustering'):
-        os.makedirs('clustering')
+    if not os.path.exists(args.output_path):
+        os.makedirs(args.output_path)
     
         
     if elbow == 'true':
@@ -259,7 +263,7 @@
         if (i + k_min == best):
             prefix = '_BEST'
             
-        write_to_csv(dataset, all_labels[i], 'clustering/kmeans_with_' + str(i + k_min) + prefix + '_clusters.tsv')
+        write_to_csv(dataset, all_labels[i], f'{args.output_path}/kmeans_with_' + str(i + k_min) + prefix + '_clusters.tsv')
         
         
         if (prefix == '_BEST'):
@@ -272,7 +276,7 @@
         
        
         if silhouette:
-            silhouette_draw(dataset, all_labels[i], i + k_min, 'clustering/silhouette_with_' + str(i + k_min) + prefix + '_clusters.png')
+            silhouette_draw(dataset, all_labels[i], i + k_min, f'{args.output_path}/silhouette_with_' + str(i + k_min) + prefix + '_clusters.png')
         
         
     if elbow:
@@ -303,7 +307,7 @@
     plt.plot(x, distortions, marker = 'o')
     plt.xlabel('Number of clusters (k)')
     plt.ylabel('Distortion')
-    s = 'clustering/elbow_plot.png'
+    s = f'{args.output_path}/elbow_plot.png'
     fig = plt.gcf()
     fig.set_size_inches(18.5, 10.5, forward = True)
     fig.savefig(s, dpi=100)
@@ -406,8 +410,8 @@
     Returns:
         None
     """
-    if not os.path.exists('clustering'):
-        os.makedirs('clustering')
+    if not os.path.exists(args.output_path):
+        os.makedirs(args.output_path)
         
     if eps is not None:
         clusterer = DBSCAN(eps = eps, min_samples = min_samples)
@@ -445,14 +449,14 @@
     Returns:
         None
     """
-    if not os.path.exists('clustering'):
-        os.makedirs('clustering')
+    if not os.path.exists(args.output_path):
+        os.makedirs(args.output_path)
     
     plt.figure(figsize=(10, 7))  
     plt.title("Customer Dendograms")  
     shc.dendrogram(shc.linkage(dataset, method='ward'), labels=dataset.index.values.tolist())  
     fig = plt.gcf()
-    fig.savefig('clustering/dendogram.png', dpi=200)
+    fig.savefig(f'{args.output_path}/dendogram.png', dpi=200)
     
     range_n_clusters = [i for i in range(k_min, k_max+1)]
 
@@ -466,7 +470,7 @@
         cluster.fit_predict(dataset)  
         cluster_labels = cluster.labels_
         labels.append(cluster_labels)
-        write_to_csv(dataset, cluster_labels, 'clustering/hierarchical_with_' + str(n_clusters) + '_clusters.tsv')
+        write_to_csv(dataset, cluster_labels, f'{args.output_path}/hierarchical_with_' + str(n_clusters) + '_clusters.tsv')
         
     best = max_index(scores) + k_min
     
@@ -475,7 +479,7 @@
         if (i + k_min == best):
             prefix = '_BEST'
         if silhouette == 'true':
-            silhouette_draw(dataset, labels[i], i + k_min, 'clustering/silhouette_with_' + str(i + k_min) + prefix + '_clusters.png')
+            silhouette_draw(dataset, labels[i], i + k_min, f'{args.output_path}/silhouette_with_' + str(i + k_min) + prefix + '_clusters.png')
      
     for i in range(len(labels)):
         if (i + k_min == best):
@@ -486,17 +490,18 @@
             
     
 ############################# main ###########################################
-def main() -> None:
+def main(args_in:List[str] = None) -> None:
     """
     Initializes everything and sets the program in motion based on the fronted input arguments.
 
     Returns:
         None
     """
-    if not os.path.exists('clustering'):
-        os.makedirs('clustering')
+    global args
+    args = process_args(args_in)
 
-    args = process_args(sys.argv)
+    if not os.path.exists(args.output_path):
+        os.makedirs(args.output_path)
     
     #Data read