diff main.py @ 4:e7b4afedc471 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
author bgruening
date Tue, 11 Feb 2025 10:14:12 +0000
parents 33d53eb476fd
children 49b4ee0d0965
line wrap: on
line diff
--- a/main.py	Mon Jan 20 15:45:17 2025 +0000
+++ b/main.py	Tue Feb 11 10:14:12 2025 +0000
@@ -11,8 +11,9 @@
     average_precision_score,
     precision_recall_curve,
     r2_score,
-    root_mean_squared_error
+    root_mean_squared_error,
 )
+from sklearn.preprocessing import label_binarize
 from tabpfn import TabPFNClassifier, TabPFNRegressor
 
 
@@ -23,12 +24,42 @@
     return features, labels
 
 
-def classification_plot(xval, yval, leg_label, title, xlabel, ylabel):
+def classification_plot(y_true, y_scores):
     plt.figure(figsize=(8, 6))
-    plt.plot(xval, yval, label=leg_label)
-    plt.xlabel(xlabel)
-    plt.ylabel(ylabel)
-    plt.title(title)
+    is_binary = len(np.unique(y_true)) == 2
+    if is_binary:
+        # Compute precision-recall curve
+        precision, recall, _ = precision_recall_curve(y_true, y_scores[:, 1])
+        average_precision = average_precision_score(y_true, y_scores[:, 1])
+        plt.plot(
+            recall,
+            precision,
+            label=f"Precision-Recall Curve (AP={average_precision:.2f})",
+        )
+        plt.title("Precision-Recall Curve (binary classification)")
+    else:
+        y_true_bin = label_binarize(y_true, classes=np.unique(y_true))
+        n_classes = y_true_bin.shape[1]
+        class_labels = [f"Class {i}" for i in range(n_classes)]
+        # Plot PR curve for each class
+        for i in range(n_classes):
+            precision, recall, _ = precision_recall_curve(
+                y_true_bin[:, i], y_scores[:, i]
+            )
+            ap_score = average_precision_score(y_true_bin[:, i], y_scores[:, i])
+            plt.plot(
+                recall, precision, label=f"{class_labels[i]} (AP = {ap_score:.2f})"
+            )
+        # Compute micro-average PR curve
+        precision, recall, _ = precision_recall_curve(
+            y_true_bin.ravel(), y_scores.ravel()
+        )
+        plt.plot(
+            recall, precision, linestyle="--", color="black", label="Micro-average"
+        )
+        plt.title("Precision-Recall Curve (Multiclass Classification)")
+    plt.xlabel("Recall")
+    plt.ylabel("Precision")
     plt.legend(loc="lower left")
     plt.grid(True)
     plt.savefig("output_plot.png")
@@ -66,20 +97,7 @@
         y_eval = classifier.predict(te_features)
         pred_probas_test = classifier.predict_proba(te_features)
         if len(te_labels) > 0:
-            precision, recall, thresholds = precision_recall_curve(
-                te_labels, pred_probas_test[:, 1]
-            )
-            average_precision = average_precision_score(
-                te_labels, pred_probas_test[:, 1]
-            )
-            classification_plot(
-                recall,
-                precision,
-                f"Precision-Recall Curve (AP={average_precision:.2f})",
-                "Precision-Recall Curve",
-                "Recall",
-                "Precision",
-            )
+            classification_plot(te_labels, pred_probas_test)
     else:
         regressor = TabPFNRegressor()
         regressor.fit(tr_features, tr_labels)