Mercurial > repos > bgruening > tabpfn
comparison main.py @ 5:49b4ee0d0965 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
author | bgruening |
---|---|
date | Wed, 26 Mar 2025 16:32:51 +0000 |
parents | e7b4afedc471 |
children |
comparison
equal
deleted
inserted
replaced
4:e7b4afedc471 | 5:49b4ee0d0965 |
---|---|
55 y_true_bin.ravel(), y_scores.ravel() | 55 y_true_bin.ravel(), y_scores.ravel() |
56 ) | 56 ) |
57 plt.plot( | 57 plt.plot( |
58 recall, precision, linestyle="--", color="black", label="Micro-average" | 58 recall, precision, linestyle="--", color="black", label="Micro-average" |
59 ) | 59 ) |
60 plt.title("Precision-Recall Curve (Multiclass Classification)") | 60 plt.title( |
61 "Precision-Recall Curve (Multiclass Classification)" | |
62 ) | |
61 plt.xlabel("Recall") | 63 plt.xlabel("Recall") |
62 plt.ylabel("Precision") | 64 plt.ylabel("Precision") |
63 plt.legend(loc="lower left") | 65 plt.legend(loc="lower left") |
64 plt.grid(True) | 66 plt.grid(True) |
65 plt.savefig("output_plot.png") | 67 plt.savefig("output_plot.png") |
83 Train TabPFN and predict | 85 Train TabPFN and predict |
84 """ | 86 """ |
85 # prepare train data | 87 # prepare train data |
86 tr_features, tr_labels = separate_features_labels(args["train_data"]) | 88 tr_features, tr_labels = separate_features_labels(args["train_data"]) |
87 # prepare test data | 89 # prepare test data |
88 if args["testhaslabels"] == "haslabels": | 90 if args["testhaslabels"] == "true": |
89 te_features, te_labels = separate_features_labels(args["test_data"]) | 91 te_features, te_labels = separate_features_labels(args["test_data"]) |
90 else: | 92 else: |
91 te_features = pd.read_csv(args["test_data"], sep="\t") | 93 te_features = pd.read_csv(args["test_data"], sep="\t") |
92 te_labels = [] | 94 te_labels = [] |
93 s_time = time.time() | 95 s_time = time.time() |
94 if args["selected_task"] == "Classification": | 96 if args["selected_task"] == "Classification": |
95 classifier = TabPFNClassifier() | 97 classifier = TabPFNClassifier(random_state=42) |
96 classifier.fit(tr_features, tr_labels) | 98 classifier.fit(tr_features, tr_labels) |
97 y_eval = classifier.predict(te_features) | 99 y_eval = classifier.predict(te_features) |
98 pred_probas_test = classifier.predict_proba(te_features) | 100 pred_probas_test = classifier.predict_proba(te_features) |
99 if len(te_labels) > 0: | 101 if len(te_labels) > 0: |
100 classification_plot(te_labels, pred_probas_test) | 102 classification_plot(te_labels, pred_probas_test) |
103 te_features["predicted_labels"] = y_eval | |
104 te_features.to_csv( | |
105 "output_predicted_data", sep="\t", index=None | |
106 ) | |
101 else: | 107 else: |
102 regressor = TabPFNRegressor() | 108 regressor = TabPFNRegressor(random_state=42) |
103 regressor.fit(tr_features, tr_labels) | 109 regressor.fit(tr_features, tr_labels) |
104 y_eval = regressor.predict(te_features) | 110 y_eval = regressor.predict(te_features) |
105 if len(te_labels) > 0: | 111 if len(te_labels) > 0: |
106 score = root_mean_squared_error(te_labels, y_eval) | 112 score = root_mean_squared_error(te_labels, y_eval) |
107 r2_metric_score = r2_score(te_labels, y_eval) | 113 r2_metric_score = r2_score(te_labels, y_eval) |
110 y_eval, | 116 y_eval, |
111 f"Scatter plot: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}", | 117 f"Scatter plot: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}", |
112 "True values", | 118 "True values", |
113 "Predicted values", | 119 "Predicted values", |
114 ) | 120 ) |
121 te_features["predicted_labels"] = y_eval | |
122 te_features.to_csv( | |
123 "output_predicted_data", sep="\t", index=None | |
124 ) | |
115 e_time = time.time() | 125 e_time = time.time() |
116 print( | 126 print( |
117 "Time taken by TabPFN for training and prediction: {} seconds".format( | 127 f"Time taken by TabPFN for training and prediction: {e_time - s_time} seconds" |
118 e_time - s_time | |
119 ) | |
120 ) | 128 ) |
121 te_features["predicted_labels"] = y_eval | |
122 te_features.to_csv("output_predicted_data", sep="\t", index=None) | |
123 | 129 |
124 | 130 |
125 if __name__ == "__main__": | 131 if __name__ == "__main__": |
126 arg_parser = argparse.ArgumentParser() | 132 arg_parser = argparse.ArgumentParser() |
127 arg_parser.add_argument("-trdata", "--train_data", required=True, help="Train data") | 133 arg_parser.add_argument("-trdata", "--train_data", required=True, help="Train data") |