comparison main.py @ 5:49b4ee0d0965 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
author bgruening
date Wed, 26 Mar 2025 16:32:51 +0000
parents e7b4afedc471
children
comparison
equal deleted inserted replaced
4:e7b4afedc471 5:49b4ee0d0965
55 y_true_bin.ravel(), y_scores.ravel() 55 y_true_bin.ravel(), y_scores.ravel()
56 ) 56 )
57 plt.plot( 57 plt.plot(
58 recall, precision, linestyle="--", color="black", label="Micro-average" 58 recall, precision, linestyle="--", color="black", label="Micro-average"
59 ) 59 )
60 plt.title("Precision-Recall Curve (Multiclass Classification)") 60 plt.title(
61 "Precision-Recall Curve (Multiclass Classification)"
62 )
61 plt.xlabel("Recall") 63 plt.xlabel("Recall")
62 plt.ylabel("Precision") 64 plt.ylabel("Precision")
63 plt.legend(loc="lower left") 65 plt.legend(loc="lower left")
64 plt.grid(True) 66 plt.grid(True)
65 plt.savefig("output_plot.png") 67 plt.savefig("output_plot.png")
83 Train TabPFN and predict 85 Train TabPFN and predict
84 """ 86 """
85 # prepare train data 87 # prepare train data
86 tr_features, tr_labels = separate_features_labels(args["train_data"]) 88 tr_features, tr_labels = separate_features_labels(args["train_data"])
87 # prepare test data 89 # prepare test data
88 if args["testhaslabels"] == "haslabels": 90 if args["testhaslabels"] == "true":
89 te_features, te_labels = separate_features_labels(args["test_data"]) 91 te_features, te_labels = separate_features_labels(args["test_data"])
90 else: 92 else:
91 te_features = pd.read_csv(args["test_data"], sep="\t") 93 te_features = pd.read_csv(args["test_data"], sep="\t")
92 te_labels = [] 94 te_labels = []
93 s_time = time.time() 95 s_time = time.time()
94 if args["selected_task"] == "Classification": 96 if args["selected_task"] == "Classification":
95 classifier = TabPFNClassifier() 97 classifier = TabPFNClassifier(random_state=42)
96 classifier.fit(tr_features, tr_labels) 98 classifier.fit(tr_features, tr_labels)
97 y_eval = classifier.predict(te_features) 99 y_eval = classifier.predict(te_features)
98 pred_probas_test = classifier.predict_proba(te_features) 100 pred_probas_test = classifier.predict_proba(te_features)
99 if len(te_labels) > 0: 101 if len(te_labels) > 0:
100 classification_plot(te_labels, pred_probas_test) 102 classification_plot(te_labels, pred_probas_test)
103 te_features["predicted_labels"] = y_eval
104 te_features.to_csv(
105 "output_predicted_data", sep="\t", index=None
106 )
101 else: 107 else:
102 regressor = TabPFNRegressor() 108 regressor = TabPFNRegressor(random_state=42)
103 regressor.fit(tr_features, tr_labels) 109 regressor.fit(tr_features, tr_labels)
104 y_eval = regressor.predict(te_features) 110 y_eval = regressor.predict(te_features)
105 if len(te_labels) > 0: 111 if len(te_labels) > 0:
106 score = root_mean_squared_error(te_labels, y_eval) 112 score = root_mean_squared_error(te_labels, y_eval)
107 r2_metric_score = r2_score(te_labels, y_eval) 113 r2_metric_score = r2_score(te_labels, y_eval)
110 y_eval, 116 y_eval,
111 f"Scatter plot: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}", 117 f"Scatter plot: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}",
112 "True values", 118 "True values",
113 "Predicted values", 119 "Predicted values",
114 ) 120 )
121 te_features["predicted_labels"] = y_eval
122 te_features.to_csv(
123 "output_predicted_data", sep="\t", index=None
124 )
115 e_time = time.time() 125 e_time = time.time()
116 print( 126 print(
117 "Time taken by TabPFN for training and prediction: {} seconds".format( 127 f"Time taken by TabPFN for training and prediction: {e_time - s_time} seconds"
118 e_time - s_time
119 )
120 ) 128 )
121 te_features["predicted_labels"] = y_eval
122 te_features.to_csv("output_predicted_data", sep="\t", index=None)
123 129
124 130
125 if __name__ == "__main__": 131 if __name__ == "__main__":
126 arg_parser = argparse.ArgumentParser() 132 arg_parser = argparse.ArgumentParser()
127 arg_parser.add_argument("-trdata", "--train_data", required=True, help="Train data") 133 arg_parser.add_argument("-trdata", "--train_data", required=True, help="Train data")