Mercurial > repos > bgruening > tabpfn
annotate main.py @ 5:49b4ee0d0965 draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
author | bgruening |
---|---|
date | Wed, 26 Mar 2025 16:32:51 +0000 |
parents | e7b4afedc471 |
children |
rev | line source |
---|---|
0
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
1 """ |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
2 Tabular data prediction using TabPFN |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
3 """ |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
4 import argparse |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
5 import time |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
6 |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
7 import matplotlib.pyplot as plt |
2
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
8 import numpy as np |
0
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
9 import pandas as pd |
2
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
10 from sklearn.metrics import ( |
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
11 average_precision_score, |
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
12 precision_recall_curve, |
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
13 r2_score, |
4
e7b4afedc471
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents:
3
diff
changeset
|
14 root_mean_squared_error, |
2
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
15 ) |
4
e7b4afedc471
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents:
3
diff
changeset
|
16 from sklearn.preprocessing import label_binarize |
2
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
17 from tabpfn import TabPFNClassifier, TabPFNRegressor |
0
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
18 |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
19 |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
def separate_features_labels(data):
    """
    Read a tab-separated table and split it into features and labels.

    The last column is taken as the label column; every column before it
    is a feature column.

    Args:
        data: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        A ``(features, labels)`` tuple: a DataFrame holding all columns
        except the last one, and a Series holding the last column.

    Raises:
        ValueError: If the table has fewer than two columns, i.e. there
            is no feature column left once the label column is removed.
    """
    df = pd.read_csv(data, sep="\t")
    # A single-column table usually means the file is not tab-separated;
    # fail here with a clear message instead of handing an empty feature
    # matrix to the model.
    if df.shape[1] < 2:
        raise ValueError(
            "Expected a tab-separated table with at least one feature column "
            f"followed by a label column, but found {df.shape[1]} column(s)."
        )
    features = df.iloc[:, :-1]
    labels = df.iloc[:, -1]
    return features, labels
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
25 |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
26 |
4
e7b4afedc471
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents:
3
diff
changeset
|
def classification_plot(y_true, y_scores):
    """
    Plot precision-recall curves and save them to ``output_plot.png``.

    With exactly two distinct values in ``y_true`` a single curve is drawn
    from the second score column. Otherwise one curve per class plus a
    micro-averaged curve is drawn.

    Args:
        y_true: True labels of the evaluated samples.
        y_scores: 2-D array of per-class scores, indexed as
            ``y_scores[:, i]`` for class ``i``.
            NOTE(review): assumes the columns follow the sorted order of
            the labels present in ``y_true`` - confirm against the caller.
    """
    fig = plt.figure(figsize=(8, 6))
    classes = np.unique(y_true)
    is_binary = len(classes) == 2
    if is_binary:
        # Column 1 holds the scores of the second (larger) class, so name
        # that class as the positive one. Without pos_label, scikit-learn
        # rejects labels that are not {0, 1} / {-1, 1}, and
        # average_precision_score would default to treating 1 as positive.
        positive = classes[1]
        # Compute precision-recall curve
        precision, recall, _ = precision_recall_curve(
            y_true, y_scores[:, 1], pos_label=positive
        )
        average_precision = average_precision_score(
            y_true, y_scores[:, 1], pos_label=positive
        )
        plt.plot(
            recall,
            precision,
            label=f"Precision-Recall Curve (AP={average_precision:.2f})",
        )
        plt.title("Precision-Recall Curve (binary classification)")
    else:
        # One-vs-rest indicator matrix: one column per class in y_true.
        y_true_bin = label_binarize(y_true, classes=classes)
        n_classes = y_true_bin.shape[1]
        class_labels = [f"Class {i}" for i in range(n_classes)]
        # Plot PR curve for each class
        for i in range(n_classes):
            precision, recall, _ = precision_recall_curve(
                y_true_bin[:, i], y_scores[:, i]
            )
            ap_score = average_precision_score(y_true_bin[:, i], y_scores[:, i])
            plt.plot(
                recall, precision, label=f"{class_labels[i]} (AP = {ap_score:.2f})"
            )
        # Compute micro-average PR curve over all (sample, class) pairs
        precision, recall, _ = precision_recall_curve(
            y_true_bin.ravel(), y_scores.ravel()
        )
        plt.plot(
            recall, precision, linestyle="--", color="black", label="Micro-average"
        )
        plt.title(
            "Precision-Recall Curve (Multiclass Classification)"
        )
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend(loc="lower left")
    plt.grid(True)
    plt.savefig("output_plot.png")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
68 |
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
69 |
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
def regression_plot(xval, yval, title, xlabel, ylabel):
    """
    Draw a scatter plot of two value series and save it to ``output_plot.png``.

    A dashed red ``y = x`` reference line is drawn across the range covered
    by the data.

    Args:
        xval: Values plotted on the x axis.
        yval: Values plotted on the y axis.
        title: Plot title.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
    """
    fig = plt.figure(figsize=(8, 6))
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.scatter(xval, yval, alpha=0.8)
    x_values = np.asarray(xval)
    y_values = np.asarray(yval)
    if x_values.size > 0 and y_values.size > 0:
        # The identity line must span the range of the plotted values, not
        # the number of samples, or it distorts the axes.
        low = min(np.min(x_values), np.min(y_values))
        high = max(np.max(x_values), np.max(y_values))
        plt.plot([low, high], [low, high], color="red", linestyle="--", label="y = x")
        # The legend is built only after the labelled line exists, so the
        # "y = x" entry actually shows up.
        plt.legend(loc="lower left")
    plt.savefig("output_plot.png")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
81 |
c081e5e1d7ce
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents:
0
diff
changeset
|
82 |
0
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
def train_evaluate(args):
    """
    Train TabPFN and predict

    Fits a TabPFN classifier or regressor on the training table, predicts
    the test table, writes the test features plus a ``predicted_labels``
    column to ``output_predicted_data`` (tab-separated) and prints the
    elapsed time. When the test table has labels, an evaluation plot is
    produced as well.

    Args:
        args: Mapping with the keys ``train_data`` and ``test_data``
            (tab-separated tables), ``testhaslabels`` (the string
            ``"true"`` when the last test column holds labels) and
            ``selected_task`` (``"Classification"`` selects the
            classifier; any other value selects the regressor).
    """
    # prepare train data
    tr_features, tr_labels = separate_features_labels(args["train_data"])
    # prepare test data
    if args["testhaslabels"] == "true":
        te_features, te_labels = separate_features_labels(args["test_data"])
    else:
        te_features = pd.read_csv(args["test_data"], sep="\t")
        te_labels = []
    has_labels = len(te_labels) > 0
    s_time = time.time()
    if args["selected_task"] == "Classification":
        classifier = TabPFNClassifier(random_state=42)
        classifier.fit(tr_features, tr_labels)
        y_eval = classifier.predict(te_features)
        if has_labels:
            # Class probabilities are only needed for the plot, so skip
            # this extra inference pass when there is nothing to evaluate.
            pred_probas_test = classifier.predict_proba(te_features)
            classification_plot(te_labels, pred_probas_test)
    else:
        regressor = TabPFNRegressor(random_state=42)
        regressor.fit(tr_features, tr_labels)
        y_eval = regressor.predict(te_features)
        if has_labels:
            score = root_mean_squared_error(te_labels, y_eval)
            r2_metric_score = r2_score(te_labels, y_eval)
            regression_plot(
                te_labels,
                y_eval,
                f"Scatter plot: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}",
                "True values",
                "Predicted values",
            )
    # Build a new frame instead of assigning onto te_features, which may be
    # a slice of the table that was read from disk.
    predicted = te_features.assign(predicted_labels=y_eval)
    predicted.to_csv("output_predicted_data", sep="\t", index=False)
    e_time = time.time()
    print(
        f"Time taken by TabPFN for training and prediction: {e_time - s_time} seconds"
    )
0
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
129 |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
130 |
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
if __name__ == "__main__":
    # Command-line entry point: all four options are mandatory and are
    # handed to train_evaluate() as a plain dict.
    cli_options = (
        (("-trdata", "--train_data"), "Train data"),
        (("-tedata", "--test_data"), "Test data"),
        (("-testhaslabels", "--testhaslabels"), "if test data contain labels"),
        (("-selectedtask", "--selected_task"), "Type of machine learning task"),
    )
    arg_parser = argparse.ArgumentParser()
    for option_flags, option_help in cli_options:
        arg_parser.add_argument(*option_flags, required=True, help=option_help)
    # get argument values
    args = vars(arg_parser.parse_args())
    train_evaluate(args)