main.py @ 5:49b4ee0d0965 (draft, default, tip)

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1

author:   bgruening
date:     Wed, 26 Mar 2025 16:32:51 +0000
parents:  e7b4afedc471
children: none
""" Tabular data prediction using TabPFN """ import argparse import time import matplotlib.pyplot as plt import numpy as np import pandas as pd from sklearn.metrics import ( average_precision_score, precision_recall_curve, r2_score, root_mean_squared_error, ) from sklearn.preprocessing import label_binarize from tabpfn import TabPFNClassifier, TabPFNRegressor def separate_features_labels(data): df = pd.read_csv(data, sep="\t") labels = df.iloc[:, -1] features = df.iloc[:, :-1] return features, labels def classification_plot(y_true, y_scores): plt.figure(figsize=(8, 6)) is_binary = len(np.unique(y_true)) == 2 if is_binary: # Compute precision-recall curve precision, recall, _ = precision_recall_curve(y_true, y_scores[:, 1]) average_precision = average_precision_score(y_true, y_scores[:, 1]) plt.plot( recall, precision, label=f"Precision-Recall Curve (AP={average_precision:.2f})", ) plt.title("Precision-Recall Curve (binary classification)") else: y_true_bin = label_binarize(y_true, classes=np.unique(y_true)) n_classes = y_true_bin.shape[1] class_labels = [f"Class {i}" for i in range(n_classes)] # Plot PR curve for each class for i in range(n_classes): precision, recall, _ = precision_recall_curve( y_true_bin[:, i], y_scores[:, i] ) ap_score = average_precision_score(y_true_bin[:, i], y_scores[:, i]) plt.plot( recall, precision, label=f"{class_labels[i]} (AP = {ap_score:.2f})" ) # Compute micro-average PR curve precision, recall, _ = precision_recall_curve( y_true_bin.ravel(), y_scores.ravel() ) plt.plot( recall, precision, linestyle="--", color="black", label="Micro-average" ) plt.title( "Precision-Recall Curve (Multiclass Classification)" ) plt.xlabel("Recall") plt.ylabel("Precision") plt.legend(loc="lower left") plt.grid(True) plt.savefig("output_plot.png") def regression_plot(xval, yval, title, xlabel, ylabel): plt.figure(figsize=(8, 6)) plt.xlabel(xlabel) plt.ylabel(ylabel) plt.title(title) plt.legend(loc="lower left") plt.grid(True) plt.scatter(xval, yval, alpha=0.8) xticks = np.arange(len(xval)) plt.plot(xticks, xticks, color="red", linestyle="--", label="y = x") plt.savefig("output_plot.png") def train_evaluate(args): """ Train TabPFN and predict """ # prepare train data tr_features, tr_labels = separate_features_labels(args["train_data"]) # prepare test data if args["testhaslabels"] == "true": te_features, te_labels = separate_features_labels(args["test_data"]) else: te_features = pd.read_csv(args["test_data"], sep="\t") te_labels = [] s_time = time.time() if args["selected_task"] == "Classification": classifier = TabPFNClassifier(random_state=42) classifier.fit(tr_features, tr_labels) y_eval = classifier.predict(te_features) pred_probas_test = classifier.predict_proba(te_features) if len(te_labels) > 0: classification_plot(te_labels, pred_probas_test) te_features["predicted_labels"] = y_eval te_features.to_csv( "output_predicted_data", sep="\t", index=None ) else: regressor = TabPFNRegressor(random_state=42) regressor.fit(tr_features, tr_labels) y_eval = regressor.predict(te_features) if len(te_labels) > 0: score = root_mean_squared_error(te_labels, y_eval) r2_metric_score = r2_score(te_labels, y_eval) regression_plot( te_labels, y_eval, f"Scatter plot: True vs predicted values. 
RMSE={score:.2f}, R2={r2_metric_score:.2f}", "True values", "Predicted values", ) te_features["predicted_labels"] = y_eval te_features.to_csv( "output_predicted_data", sep="\t", index=None ) e_time = time.time() print( f"Time taken by TabPFN for training and prediction: {e_time - s_time} seconds" ) if __name__ == "__main__": arg_parser = argparse.ArgumentParser() arg_parser.add_argument("-trdata", "--train_data", required=True, help="Train data") arg_parser.add_argument("-tedata", "--test_data", required=True, help="Test data") arg_parser.add_argument( "-testhaslabels", "--testhaslabels", required=True, help="if test data contain labels", ) arg_parser.add_argument( "-selectedtask", "--selected_task", required=True, help="Type of machine learning task", ) # get argument values args = vars(arg_parser.parse_args()) train_evaluate(args)
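
For context, a minimal usage sketch (not part of the tool itself): it assumes the script above is saved as main.py on the import path, and the file names and toy values are hypothetical. The TSV layout, tab-separated features with the label in the last column, follows separate_features_labels above, and the argument keys mirror the argparse flags.

# Minimal usage sketch; file names and toy data are hypothetical.
import pandas as pd

from main import train_evaluate

# Toy data: two feature columns plus a final label column.
train = pd.DataFrame(
    {"f1": [0.1, 0.9, 0.2, 0.8], "f2": [1.0, 0.0, 1.1, 0.1], "label": [0, 1, 0, 1]}
)
test = pd.DataFrame({"f1": [0.15, 0.85], "f2": [1.05, 0.05], "label": [0, 1]})
train.to_csv("train.tsv", sep="\t", index=False)
test.to_csv("test.tsv", sep="\t", index=False)

# Equivalent to the command-line call:
#   python main.py -trdata train.tsv -tedata test.tsv \
#       -testhaslabels true -selectedtask Classification
train_evaluate(
    {
        "train_data": "train.tsv",
        "test_data": "test.tsv",
        "testhaslabels": "true",
        "selected_task": "Classification",
    }
)

This writes the predictions to output_predicted_data and, because the test file carries labels, a precision-recall plot to output_plot.png.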