Mercurial > repos > bgruening > tabpfn
annotate main.py @ 3:33d53eb476fd draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit c1c3b6d5abd35875890c45baacab073b5e749537
author | bgruening |
---|---|
date | Mon, 20 Jan 2025 15:45:17 +0000 |
parents | c081e5e1d7ce |
children | e7b4afedc471 |
rev | line source |
---|---|
0
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
1 """ |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
2 Tabular data prediction using TabPFN |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
3 """ |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
4 import argparse |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
5 import time |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
6 |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
7 import matplotlib.pyplot as plt |
2
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
8 import numpy as np |
0
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
9 import pandas as pd |
2
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
10 from sklearn.metrics import ( |
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
11 average_precision_score, |
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
12 precision_recall_curve, |
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
13 r2_score, |
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
14 root_mean_squared_error |
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
15 ) |
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
16 from tabpfn import TabPFNClassifier, TabPFNRegressor |
0
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
17 |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
18 |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
def separate_features_labels(data):
    """
    Read a tab-separated table and split it into features and labels.

    :param data: path (or file-like object) accepted by ``pandas.read_csv``
    :return: tuple ``(features, labels)`` where ``labels`` is the last column
        (a Series) and ``features`` is every other column (a DataFrame)
    """
    table = pd.read_csv(data, sep="\t")
    # The target is always the right-most column; select positionally.
    return table.iloc[:, :-1], table.iloc[:, -1]
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
24 |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
25 |
2
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
def classification_plot(
    xval, yval, leg_label, title, xlabel, ylabel, output_path="output_plot.png"
):
    """
    Draw a single labelled line plot (used for the precision-recall curve)
    and save it to disk.

    :param xval: x coordinates of the curve
    :param yval: y coordinates of the curve
    :param leg_label: legend entry for the curve
    :param title: plot title
    :param xlabel: x-axis label
    :param ylabel: y-axis label
    :param output_path: destination image file; defaults to
        "output_plot.png" in the current working directory
    """
    fig = plt.figure(figsize=(8, 6))
    plt.plot(xval, yval, label=leg_label)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend(loc="lower left")
    plt.grid(True)
    plt.savefig(output_path)
    # pyplot keeps every figure it creates alive until closed; release it so
    # repeated calls do not accumulate open figures.
    plt.close(fig)
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
35 |
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
36 |
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
def regression_plot(xval, yval, title, xlabel, ylabel, output_path="output_plot.png"):
    """
    Scatter-plot true vs. predicted values with a dashed y = x reference
    line and save the figure to disk.

    :param xval: true target values (plotted on the x axis)
    :param yval: predicted values (plotted on the y axis)
    :param title: plot title
    :param xlabel: x-axis label
    :param ylabel: y-axis label
    :param output_path: destination image file; defaults to
        "output_plot.png" in the current working directory
    """
    fig = plt.figure(figsize=(8, 6))
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.scatter(xval, yval, alpha=0.8)
    # The reference diagonal must span the range of the plotted *values*.
    # (Drawing it over the sample indices 0..n-1 puts it in the wrong place
    # for any target not living roughly in [0, n).)
    values = np.concatenate(
        [
            np.ravel(np.asarray(xval, dtype=float)),
            np.ravel(np.asarray(yval, dtype=float)),
        ]
    )
    finite = values[np.isfinite(values)]
    if finite.size:
        lo, hi = finite.min(), finite.max()
        plt.plot([lo, hi], [lo, hi], color="red", linestyle="--", label="y = x")
        # The legend must be built after the labelled line exists, otherwise
        # it is empty. Upper-left stays clear of the diagonal, which runs
        # through the lower-left corner.
        plt.legend(loc="upper left")
    plt.savefig(output_path)
    # Release the figure so pyplot does not keep it alive after saving.
    plt.close(fig)
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
48 |
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
49 |
0
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
def train_evaluate(args):
    """
    Train TabPFN and predict

    :param args: dict of command-line values (built with ``vars()`` on the
        parsed arguments in ``__main__``) containing the keys ``train_data``,
        ``test_data``, ``testhaslabels`` and ``selected_task``.

    Writes ``output_predicted_data`` (tab-separated test features plus a
    ``predicted_labels`` column) to the working directory. When the test
    data carries labels it also writes an evaluation plot: a precision-recall
    curve for classification, or a true-vs-predicted scatter for regression.
    """
    # prepare train data (the last column is the target)
    tr_features, tr_labels = separate_features_labels(args["train_data"])
    # prepare test data
    if args["testhaslabels"] == "haslabels":
        te_features, te_labels = separate_features_labels(args["test_data"])
    else:
        te_features = pd.read_csv(args["test_data"], sep="\t")
        te_labels = []
    # perf_counter is monotonic, so the reported duration cannot be skewed
    # by wall-clock adjustments during the run.
    s_time = time.perf_counter()
    if args["selected_task"] == "Classification":
        classifier = TabPFNClassifier()
        classifier.fit(tr_features, tr_labels)
        y_eval = classifier.predict(te_features)
        pred_probas_test = classifier.predict_proba(te_features)
        if len(te_labels) > 0:
            classes = classifier.classes_
            if len(classes) == 2:
                # predict_proba column 1 is the probability of classes_[1];
                # name that class as the positive label so scores and labels
                # agree and non-{0, 1} labels (e.g. strings) are accepted.
                scores = pred_probas_test[:, 1]
                precision, recall, _ = precision_recall_curve(
                    te_labels, scores, pos_label=classes[1]
                )
                average_precision = average_precision_score(
                    te_labels, scores, pos_label=classes[1]
                )
                leg_label = f"Precision-Recall Curve (AP={average_precision:.2f})"
            else:
                # Multiclass: precision_recall_curve only accepts binary
                # targets, so micro-average over one-vs-rest indicator
                # columns, ordered by classifier.classes_ like predict_proba.
                from sklearn.preprocessing import label_binarize

                y_true = label_binarize(te_labels, classes=classes)
                precision, recall, _ = precision_recall_curve(
                    y_true.ravel(), pred_probas_test.ravel()
                )
                average_precision = average_precision_score(
                    y_true, pred_probas_test, average="micro"
                )
                leg_label = (
                    "Micro-averaged Precision-Recall Curve "
                    f"(AP={average_precision:.2f})"
                )
            classification_plot(
                recall,
                precision,
                leg_label,
                "Precision-Recall Curve",
                "Recall",
                "Precision",
            )
    else:
        regressor = TabPFNRegressor()
        regressor.fit(tr_features, tr_labels)
        y_eval = regressor.predict(te_features)
        if len(te_labels) > 0:
            score = root_mean_squared_error(te_labels, y_eval)
            r2_metric_score = r2_score(te_labels, y_eval)
            regression_plot(
                te_labels,
                y_eval,
                f"Scatter plot: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}",
                "True values",
                "Predicted values",
            )
    e_time = time.perf_counter()
    print(
        "Time taken by TabPFN for training and prediction: {} seconds".format(
            e_time - s_time
        )
    )
    te_features["predicted_labels"] = y_eval
    te_features.to_csv("output_predicted_data", sep="\t", index=None)
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
105 |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
106 |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
if __name__ == "__main__":
    # Command-line interface used by the Galaxy tool wrapper; every option
    # is mandatory and passed through to train_evaluate as a plain dict.
    parser = argparse.ArgumentParser()
    parser.add_argument("-trdata", "--train_data", required=True, help="Train data")
    parser.add_argument("-tedata", "--test_data", required=True, help="Test data")
    parser.add_argument(
        "-testhaslabels", "--testhaslabels", required=True, help="if test data contain labels"
    )
    parser.add_argument(
        "-selectedtask", "--selected_task", required=True, help="Type of machine learning task"
    )
    # Namespace -> dict, the shape train_evaluate indexes into.
    train_evaluate(vars(parser.parse_args()))