Mercurial > repos > bgruening > tabpfn
diff main.py @ 0:3dc3c7443c8e draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
author | bgruening |
---|---|
date | Wed, 15 Jan 2025 12:33:49 +0000 |
parents | |
children | c081e5e1d7ce |
line wrap: on
line diff
# Reconstructed from the diff hunk "--- /dev/null +++ b/main.py" (file added by this changeset).
"""
Tabular data prediction using TabPFN
"""
import argparse
import time

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import accuracy_score, average_precision_score, precision_recall_curve
from tabpfn import TabPFNClassifier


def separate_features_labels(data):
    """
    Read a tab-separated table and split it into features and labels.

    The last column is used as the label column; all preceding columns
    are returned as features.

    Args:
        data: Path (or any object accepted by ``pandas.read_csv``) of a
            tab-separated file with a header row.

    Returns:
        A ``(features, labels)`` tuple: a DataFrame with every column but
        the last, and a Series holding the last column.
    """
    df = pd.read_csv(data, sep="\t")
    labels = df.iloc[:, -1]
    features = df.iloc[:, :-1]
    return features, labels


def _plot_precision_recall(classes, true_labels, probabilities, out_path):
    """
    Save precision-recall curve(s) for the given class probabilities.

    With exactly two classes a single curve for the second class is drawn
    (the same curve the script always produced for 0/1 labels). For any
    other number of classes one one-vs-rest curve per class is drawn.

    NOTE(review): assumes column ``i`` of ``probabilities`` corresponds to
    ``classes[i]`` (the scikit-learn ``classes_`` convention) - confirm
    for the installed TabPFN version.

    Args:
        classes: Class labels in probability-column order.
        true_labels: Ground-truth labels of the test rows.
        probabilities: 2-D array of per-class probabilities.
        out_path: File name the figure is written to.
    """
    binary = len(classes) == 2
    columns = [1] if binary else range(len(classes))
    fig = plt.figure(figsize=(8, 6))
    for col in columns:
        # An explicit 0/1 mask makes the curve independent of the label
        # dtype (strings, arbitrary integers) and of sklearn's pos_label
        # inference, which only accepts {0, 1} / {-1, 1} labels.
        is_positive = (true_labels == classes[col]).astype(int)
        scores = probabilities[:, col]
        precision, recall, _ = precision_recall_curve(is_positive, scores)
        average_precision = average_precision_score(is_positive, scores)
        if binary:
            label = f'Precision-Recall Curve (AP={average_precision:.2f})'
        else:
            label = f'Class {classes[col]} (AP={average_precision:.2f})'
        plt.plot(recall, precision, label=label)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc='lower left')
    plt.grid(True)
    plt.savefig(out_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def train_evaluate(args):
    """
    Train TabPFN on the training table and evaluate it on the test table.

    Prints the fitting time and the test accuracy, writes the test features
    plus a ``predicted_labels`` column to ``output_predicted_data`` (TSV),
    and saves the precision-recall plot to ``output_prec_recall_curve.png``.

    Args:
        args: Mapping with the keys ``"train_data"`` and ``"test_data"``,
            each naming a tab-separated file whose last column is the label.
    """
    tr_features, tr_labels = separate_features_labels(args["train_data"])
    te_features, te_labels = separate_features_labels(args["test_data"])
    classifier = TabPFNClassifier(device='cpu')
    # perf_counter is monotonic, so the reported duration cannot be skewed
    # by wall-clock adjustments.
    s_time = time.perf_counter()
    classifier.fit(tr_features, tr_labels)
    e_time = time.perf_counter()
    print("Time taken by TabPFN for training: {} seconds".format(e_time - s_time))
    y_eval = classifier.predict(te_features)
    print('Accuracy', accuracy_score(te_labels, y_eval))
    pred_probas_test = classifier.predict_proba(te_features)
    # assign() builds a new frame instead of writing into the iloc slice
    # returned by separate_features_labels, which can raise pandas'
    # SettingWithCopyWarning.
    predicted = te_features.assign(predicted_labels=y_eval)
    predicted.to_csv("output_predicted_data", sep="\t", index=False)
    _plot_precision_recall(
        classifier.classes_,
        te_labels,
        pred_probas_test,
        "output_prec_recall_curve.png",
    )


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-trdata", "--train_data", required=True, help="Train data")
    arg_parser.add_argument("-tedata", "--test_data", required=True, help="Test data")
    # get argument values
    args = vars(arg_parser.parse_args())
    train_evaluate(args)