comparison main.py @ 2:c081e5e1d7ce draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
author: bgruening
date: Fri, 17 Jan 2025 22:23:34 +0000
parents: 3dc3c7443c8e
children: 33d53eb476fd
comparison of main.py between revisions 1:5112462f2dd3 and 2:c081e5e1d7ce (lines removed in the new revision are prefixed with "-", added lines with "+"):

 """
 import argparse
 import time
 
 import matplotlib.pyplot as plt
+import numpy as np
 import pandas as pd
-from sklearn.metrics import accuracy_score, average_precision_score, precision_recall_curve
-from tabpfn import TabPFNClassifier
+from sklearn.metrics import (
+    average_precision_score,
+    precision_recall_curve,
+    r2_score,
+    root_mean_squared_error
+)
+from tabpfn import TabPFNClassifier, TabPFNRegressor
 
 
 def separate_features_labels(data):
     df = pd.read_csv(data, sep="\t")
     labels = df.iloc[:, -1]
     features = df.iloc[:, :-1]
     return features, labels
 
 
+def classification_plot(xval, yval, leg_label, title, xlabel, ylabel):
+    plt.figure(figsize=(8, 6))
+    plt.plot(xval, yval, label=leg_label)
+    plt.xlabel(xlabel)
+    plt.ylabel(ylabel)
+    plt.title(title)
+    plt.legend(loc="lower left")
+    plt.grid(True)
+    plt.savefig("output_plot.png")
+
+
+def regression_plot(xval, yval, title, xlabel, ylabel):
+    plt.figure(figsize=(8, 6))
+    plt.xlabel(xlabel)
+    plt.ylabel(ylabel)
+    plt.title(title)
+    plt.legend(loc="lower left")
+    plt.grid(True)
+    plt.scatter(xval, yval, alpha=0.8)
+    xticks = np.arange(len(xval))
+    plt.plot(xticks, xticks, color="red", linestyle="--", label="y = x")
+    plt.savefig("output_plot.png")
+
+
 def train_evaluate(args):
     """
-    Train TabPFN
+    Train TabPFN and predict
     """
+    # prepare train data
     tr_features, tr_labels = separate_features_labels(args["train_data"])
-    te_features, te_labels = separate_features_labels(args["test_data"])
-    classifier = TabPFNClassifier(device='cpu')
+    # prepare test data
+    if args["testhaslabels"] == "haslabels":
+        te_features, te_labels = separate_features_labels(args["test_data"])
+    else:
+        te_features = pd.read_csv(args["test_data"], sep="\t")
+        te_labels = []
     s_time = time.time()
-    classifier.fit(tr_features, tr_labels)
+    if args["selected_task"] == "Classification":
+        classifier = TabPFNClassifier(device="cpu")
+        classifier.fit(tr_features, tr_labels)
+        y_eval = classifier.predict(te_features)
+        pred_probas_test = classifier.predict_proba(te_features)
+        if len(te_labels) > 0:
+            precision, recall, thresholds = precision_recall_curve(
+                te_labels, pred_probas_test[:, 1]
+            )
+            average_precision = average_precision_score(
+                te_labels, pred_probas_test[:, 1]
+            )
+            classification_plot(
+                recall,
+                precision,
+                f"Precision-Recall Curve (AP={average_precision:.2f})",
+                "Precision-Recall Curve",
+                "Recall",
+                "Precision",
+            )
+    else:
+        regressor = TabPFNRegressor(device="cpu")
+        regressor.fit(tr_features, tr_labels)
+        y_eval = regressor.predict(te_features)
+        if len(te_labels) > 0:
+            score = root_mean_squared_error(te_labels, y_eval)
+            r2_metric_score = r2_score(te_labels, y_eval)
+            regression_plot(
+                te_labels,
+                y_eval,
+                f"Scatter plot: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}",
+                "True values",
+                "Predicted values",
+            )
     e_time = time.time()
-    print("Time taken by TabPFN for training: {} seconds".format(e_time - s_time))
-    y_eval = classifier.predict(te_features)
-    print('Accuracy', accuracy_score(te_labels, y_eval))
-    pred_probas_test = classifier.predict_proba(te_features)
+    print(
+        "Time taken by TabPFN for training and prediction: {} seconds".format(
+            e_time - s_time
+        )
+    )
     te_features["predicted_labels"] = y_eval
     te_features.to_csv("output_predicted_data", sep="\t", index=None)
-    precision, recall, thresholds = precision_recall_curve(te_labels, pred_probas_test[:, 1])
-    average_precision = average_precision_score(te_labels, pred_probas_test[:, 1])
-    plt.figure(figsize=(8, 6))
-    plt.plot(recall, precision, label=f'Precision-Recall Curve (AP={average_precision:.2f})')
-    plt.xlabel('Recall')
-    plt.ylabel('Precision')
-    plt.title('Precision-Recall Curve')
-    plt.legend(loc='lower left')
-    plt.grid(True)
-    plt.savefig("output_prec_recall_curve.png")
 
 
 if __name__ == "__main__":
     arg_parser = argparse.ArgumentParser()
     arg_parser.add_argument("-trdata", "--train_data", required=True, help="Train data")
     arg_parser.add_argument("-tedata", "--test_data", required=True, help="Test data")
+    arg_parser.add_argument(
+        "-testhaslabels",
+        "--testhaslabels",
+        required=True,
+        help="if test data contain labels",
+    )
+    arg_parser.add_argument(
+        "-selectedtask",
+        "--selected_task",
+        required=True,
+        help="Type of machine learning task",
+    )
     # get argument values
     args = vars(arg_parser.parse_args())
     train_evaluate(args)
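For context, a minimal sketch (not part of the changeset) of how the updated main.py might be invoked. The flag names and the sentinel values "haslabels" and "Classification" are taken from the argparse and branching code above; the file names train.tsv and test.tsv are hypothetical placeholders.

import subprocess

# Classification run on labelled test data. Any --selected_task value other than
# "Classification" dispatches to the TabPFNRegressor branch, and any
# --testhaslabels value other than "haslabels" makes the script treat the test
# file as unlabeled (no metrics or plot, predictions only).
subprocess.run(
    [
        "python", "main.py",
        "--train_data", "train.tsv",   # tab-separated, label in the last column
        "--test_data", "test.tsv",
        "--testhaslabels", "haslabels",
        "--selected_task", "Classification",
    ],
    check=True,
)

The script then writes the predictions to output_predicted_data and, when test labels are available, the evaluation plot to output_plot.png.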