Mercurial > repos > bgruening > tabpfn
diff main.py @ 4:e7b4afedc471 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
author | bgruening |
---|---|
date | Tue, 11 Feb 2025 10:14:12 +0000 |
parents | 33d53eb476fd |
children | 49b4ee0d0965 |
line wrap: on
line diff
--- a/main.py Mon Jan 20 15:45:17 2025 +0000 +++ b/main.py Tue Feb 11 10:14:12 2025 +0000 @@ -11,8 +11,9 @@ average_precision_score, precision_recall_curve, r2_score, - root_mean_squared_error + root_mean_squared_error, ) +from sklearn.preprocessing import label_binarize from tabpfn import TabPFNClassifier, TabPFNRegressor @@ -23,12 +24,42 @@ return features, labels -def classification_plot(xval, yval, leg_label, title, xlabel, ylabel): +def classification_plot(y_true, y_scores): plt.figure(figsize=(8, 6)) - plt.plot(xval, yval, label=leg_label) - plt.xlabel(xlabel) - plt.ylabel(ylabel) - plt.title(title) + is_binary = len(np.unique(y_true)) == 2 + if is_binary: + # Compute precision-recall curve + precision, recall, _ = precision_recall_curve(y_true, y_scores[:, 1]) + average_precision = average_precision_score(y_true, y_scores[:, 1]) + plt.plot( + recall, + precision, + label=f"Precision-Recall Curve (AP={average_precision:.2f})", + ) + plt.title("Precision-Recall Curve (binary classification)") + else: + y_true_bin = label_binarize(y_true, classes=np.unique(y_true)) + n_classes = y_true_bin.shape[1] + class_labels = [f"Class {i}" for i in range(n_classes)] + # Plot PR curve for each class + for i in range(n_classes): + precision, recall, _ = precision_recall_curve( + y_true_bin[:, i], y_scores[:, i] + ) + ap_score = average_precision_score(y_true_bin[:, i], y_scores[:, i]) + plt.plot( + recall, precision, label=f"{class_labels[i]} (AP = {ap_score:.2f})" + ) + # Compute micro-average PR curve + precision, recall, _ = precision_recall_curve( + y_true_bin.ravel(), y_scores.ravel() + ) + plt.plot( + recall, precision, linestyle="--", color="black", label="Micro-average" + ) + plt.title("Precision-Recall Curve (Multiclass Classification)") + plt.xlabel("Recall") + plt.ylabel("Precision") plt.legend(loc="lower left") plt.grid(True) plt.savefig("output_plot.png") @@ -66,20 +97,7 @@ y_eval = classifier.predict(te_features) pred_probas_test = classifier.predict_proba(te_features) if len(te_labels) > 0: - precision, recall, thresholds = precision_recall_curve( - te_labels, pred_probas_test[:, 1] - ) - average_precision = average_precision_score( - te_labels, pred_probas_test[:, 1] - ) - classification_plot( - recall, - precision, - f"Precision-Recall Curve (AP={average_precision:.2f})", - "Precision-Recall Curve", - "Recall", - "Precision", - ) + classification_plot(te_labels, pred_probas_test) else: regressor = TabPFNRegressor() regressor.fit(tr_features, tr_labels)