Mercurial > repos > bgruening > tabpfn
diff main.py @ 2:c081e5e1d7ce draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
author | bgruening |
---|---|
date | Fri, 17 Jan 2025 22:23:34 +0000 |
parents | 3dc3c7443c8e |
children | 33d53eb476fd |
line wrap: on
line diff
--- a/main.py Wed Jan 15 20:34:13 2025 +0000 +++ b/main.py Fri Jan 17 22:23:34 2025 +0000 @@ -5,9 +5,15 @@ import time import matplotlib.pyplot as plt +import numpy as np import pandas as pd -from sklearn.metrics import accuracy_score, average_precision_score, precision_recall_curve -from tabpfn import TabPFNClassifier +from sklearn.metrics import ( + average_precision_score, + precision_recall_curve, + r2_score, + root_mean_squared_error +) +from tabpfn import TabPFNClassifier, TabPFNRegressor def separate_features_labels(data): @@ -17,38 +23,103 @@ return features, labels +def classification_plot(xval, yval, leg_label, title, xlabel, ylabel): + plt.figure(figsize=(8, 6)) + plt.plot(xval, yval, label=leg_label) + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.title(title) + plt.legend(loc="lower left") + plt.grid(True) + plt.savefig("output_plot.png") + + +def regression_plot(xval, yval, title, xlabel, ylabel): + plt.figure(figsize=(8, 6)) + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.title(title) + plt.legend(loc="lower left") + plt.grid(True) + plt.scatter(xval, yval, alpha=0.8) + xticks = np.arange(len(xval)) + plt.plot(xticks, xticks, color="red", linestyle="--", label="y = x") + plt.savefig("output_plot.png") + + def train_evaluate(args): """ - Train TabPFN + Train TabPFN and predict """ + # prepare train data tr_features, tr_labels = separate_features_labels(args["train_data"]) - te_features, te_labels = separate_features_labels(args["test_data"]) - classifier = TabPFNClassifier(device='cpu') + # prepare test data + if args["testhaslabels"] == "haslabels": + te_features, te_labels = separate_features_labels(args["test_data"]) + else: + te_features = pd.read_csv(args["test_data"], sep="\t") + te_labels = [] s_time = time.time() - classifier.fit(tr_features, tr_labels) + if args["selected_task"] == "Classification": + classifier = TabPFNClassifier(device="cpu") + classifier.fit(tr_features, tr_labels) + y_eval = classifier.predict(te_features) + pred_probas_test = classifier.predict_proba(te_features) + if len(te_labels) > 0: + precision, recall, thresholds = precision_recall_curve( + te_labels, pred_probas_test[:, 1] + ) + average_precision = average_precision_score( + te_labels, pred_probas_test[:, 1] + ) + classification_plot( + recall, + precision, + f"Precision-Recall Curve (AP={average_precision:.2f})", + "Precision-Recall Curve", + "Recall", + "Precision", + ) + else: + regressor = TabPFNRegressor(device="cpu") + regressor.fit(tr_features, tr_labels) + y_eval = regressor.predict(te_features) + if len(te_labels) > 0: + score = root_mean_squared_error(te_labels, y_eval) + r2_metric_score = r2_score(te_labels, y_eval) + regression_plot( + te_labels, + y_eval, + f"Scatter plot: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}", + "True values", + "Predicted values", + ) e_time = time.time() - print("Time taken by TabPFN for training: {} seconds".format(e_time - s_time)) - y_eval = classifier.predict(te_features) - print('Accuracy', accuracy_score(te_labels, y_eval)) - pred_probas_test = classifier.predict_proba(te_features) + print( + "Time taken by TabPFN for training and prediction: {} seconds".format( + e_time - s_time + ) + ) te_features["predicted_labels"] = y_eval te_features.to_csv("output_predicted_data", sep="\t", index=None) - precision, recall, thresholds = precision_recall_curve(te_labels, pred_probas_test[:, 1]) - average_precision = average_precision_score(te_labels, pred_probas_test[:, 1]) - plt.figure(figsize=(8, 6)) - plt.plot(recall, precision, label=f'Precision-Recall Curve (AP={average_precision:.2f})') - plt.xlabel('Recall') - plt.ylabel('Precision') - plt.title('Precision-Recall Curve') - plt.legend(loc='lower left') - plt.grid(True) - plt.savefig("output_prec_recall_curve.png") if __name__ == "__main__": arg_parser = argparse.ArgumentParser() arg_parser.add_argument("-trdata", "--train_data", required=True, help="Train data") arg_parser.add_argument("-tedata", "--test_data", required=True, help="Test data") + arg_parser.add_argument( + "-testhaslabels", + "--testhaslabels", + required=True, + help="if test data contain labels", + ) + arg_parser.add_argument( + "-selectedtask", + "--selected_task", + required=True, + help="Type of machine learning task", + ) # get argument values args = vars(arg_parser.parse_args()) train_evaluate(args)