annotate main.py @ 3:33d53eb476fd draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit c1c3b6d5abd35875890c45baacab073b5e749537
author bgruening
date Mon, 20 Jan 2025 15:45:17 +0000
parents c081e5e1d7ce
children e7b4afedc471
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
0
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
1 """
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
2 Tabular data prediction using TabPFN
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
3 """
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
4 import argparse
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
5 import time
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
6
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
7 import matplotlib.pyplot as plt
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
8 import numpy as np
0
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
9 import pandas as pd
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
10 from sklearn.metrics import (
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
11 average_precision_score,
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
12 precision_recall_curve,
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
13 r2_score,
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
14 root_mean_squared_error
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
15 )
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
16 from tabpfn import TabPFNClassifier, TabPFNRegressor
0
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
17
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
18
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
def separate_features_labels(data):
    """
    Read a tab-separated table and split it into features and labels.

    Args:
        data: path or file-like object handed straight to ``pd.read_csv``.

    Returns:
        A ``(features, labels)`` tuple: every column except the last one as a
        DataFrame, and the last column as a Series.
    """
    table = pd.read_csv(data, sep="\t")
    # The label is taken positionally: it is always the final column.
    return table.iloc[:, :-1], table.iloc[:, -1]
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
24
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
25
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
def classification_plot(xval, yval, leg_label, title, xlabel, ylabel):
    """
    Draw one labelled curve and save the figure as "output_plot.png".

    Args:
        xval: x coordinates of the curve.
        yval: y coordinates of the curve.
        leg_label: legend entry for the curve.
        title: figure title.
        xlabel: x-axis label.
        ylabel: y-axis label.
    """
    fig, axes = plt.subplots(figsize=(8, 6))
    axes.plot(xval, yval, label=leg_label)
    axes.set_title(title)
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.grid(True)
    axes.legend(loc="lower left")
    # The file name is fixed; the surrounding tool picks it up from the
    # working directory.
    fig.savefig("output_plot.png")
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
35
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
36
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
def regression_plot(xval, yval, title, xlabel, ylabel):
    """
    Draw a scatter plot with a "y = x" reference line and save it as
    "output_plot.png".

    Args:
        xval: values for the x axis (the caller passes the true targets).
        yval: values for the y axis (the caller passes the predictions).
        title: figure title.
        xlabel: x-axis label.
        ylabel: y-axis label.
    """
    plt.figure(figsize=(8, 6))
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.scatter(xval, yval, alpha=0.8)

    # The identity line has to span the range of the plotted *values*; using
    # sample indices (0..n-1) puts it in the wrong place whenever the targets
    # do not happen to live in that interval.
    plotted = np.concatenate(
        [
            np.asarray(xval, dtype=float).ravel(),
            np.asarray(yval, dtype=float).ravel(),
        ]
    )
    plotted = plotted[np.isfinite(plotted)]
    if plotted.size:
        low, high = plotted.min(), plotted.max()
        plt.plot([low, high], [low, high], color="red", linestyle="--", label="y = x")
        # The legend must be created after the labelled artist exists,
        # otherwise matplotlib has nothing to put in it.
        plt.legend(loc="lower left")

    plt.savefig("output_plot.png")
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
48
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
49
0
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
def train_evaluate(args):
    """
    Train TabPFN and predict

    Args:
        args: mapping of command-line values with the keys
            "train_data" / "test_data" (tab-separated files),
            "testhaslabels" (the value "haslabels" means the last test column
            holds the true labels) and "selected_task" ("Classification"
            selects the classifier, any other value the regressor).

    Side effects:
        Writes the test features plus a "predicted_labels" column to
        "output_predicted_data". When test labels are available, the plotting
        helpers additionally write "output_plot.png". Prints the elapsed time.
    """
    # prepare train data
    tr_features, tr_labels = separate_features_labels(args["train_data"])
    # prepare test data
    if args["testhaslabels"] == "haslabels":
        te_features, te_labels = separate_features_labels(args["test_data"])
    else:
        te_features = pd.read_csv(args["test_data"], sep="\t")
        te_labels = []
    has_labels = len(te_labels) > 0

    s_time = time.time()
    if args["selected_task"] == "Classification":
        classifier = TabPFNClassifier()
        classifier.fit(tr_features, tr_labels)
        y_eval = classifier.predict(te_features)
        if has_labels:
            # Probabilities are only needed for the curve, so skip the extra
            # inference pass when there is nothing to compare against.
            pred_probas_test = classifier.predict_proba(te_features)
            if pred_probas_test.shape[1] == 2:
                # Column 1 of predict_proba belongs to classes_[1]; naming it
                # as the positive label keeps the metrics working for binary
                # labels that are not encoded as 0/1.
                positive_label = classifier.classes_[1]
                positive_scores = pred_probas_test[:, 1]
                precision, recall, _ = precision_recall_curve(
                    te_labels, positive_scores, pos_label=positive_label
                )
                average_precision = average_precision_score(
                    te_labels, positive_scores, pos_label=positive_label
                )
                classification_plot(
                    recall,
                    precision,
                    f"Precision-Recall Curve (AP={average_precision:.2f})",
                    "Precision-Recall Curve",
                    "Recall",
                    "Precision",
                )
            else:
                # precision_recall_curve only handles binary targets; do not
                # let the optional plot abort the run before the predictions
                # are written.
                print(
                    "Precision-recall plot skipped: it requires exactly 2 "
                    "classes, model has {}".format(pred_probas_test.shape[1])
                )
    else:
        regressor = TabPFNRegressor()
        regressor.fit(tr_features, tr_labels)
        y_eval = regressor.predict(te_features)
        if has_labels:
            score = root_mean_squared_error(te_labels, y_eval)
            r2_metric_score = r2_score(te_labels, y_eval)
            regression_plot(
                te_labels,
                y_eval,
                f"Scatter plot: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}",
                "True values",
                "Predicted values",
            )
    e_time = time.time()
    print(
        "Time taken by TabPFN for training and prediction: {} seconds".format(
            e_time - s_time
        )
    )
    # Work on a copy: te_features may be a slice of the parsed table, and
    # adding a column to a slice triggers pandas' chained-assignment warning.
    predictions = te_features.copy()
    predictions["predicted_labels"] = y_eval
    predictions.to_csv("output_predicted_data", sep="\t", index=False)
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
105
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
106
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
if __name__ == "__main__":
    # (short flag, long flag, help text); every option is mandatory.
    cli_options = (
        ("-trdata", "--train_data", "Train data"),
        ("-tedata", "--test_data", "Test data"),
        ("-testhaslabels", "--testhaslabels", "if test data contain labels"),
        ("-selectedtask", "--selected_task", "Type of machine learning task"),
    )
    parser = argparse.ArgumentParser()
    for short_flag, long_flag, description in cli_options:
        parser.add_argument(short_flag, long_flag, required=True, help=description)
    # train_evaluate expects a plain dict keyed by the long option names.
    train_evaluate(vars(parser.parse_args()))