annotate main.py @ 5:49b4ee0d0965 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
author bgruening
date Wed, 26 Mar 2025 16:32:51 +0000
parents e7b4afedc471
children
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
0
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
1 """
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
2 Tabular data prediction using TabPFN
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
3 """
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
4 import argparse
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
5 import time
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
6
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
7 import matplotlib.pyplot as plt
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
8 import numpy as np
0
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
9 import pandas as pd
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
10 from sklearn.metrics import (
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
11 average_precision_score,
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
12 precision_recall_curve,
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
13 r2_score,
4
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
14 root_mean_squared_error,
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
15 )
4
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
16 from sklearn.preprocessing import label_binarize
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
17 from tabpfn import TabPFNClassifier, TabPFNRegressor
0
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
18
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
19
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
20 def separate_features_labels(data):
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
21 df = pd.read_csv(data, sep="\t")
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
22 labels = df.iloc[:, -1]
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
23 features = df.iloc[:, :-1]
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
24 return features, labels
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
25
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
26
4
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
27 def classification_plot(y_true, y_scores):
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
28 plt.figure(figsize=(8, 6))
4
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
29 is_binary = len(np.unique(y_true)) == 2
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
30 if is_binary:
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
31 # Compute precision-recall curve
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
32 precision, recall, _ = precision_recall_curve(y_true, y_scores[:, 1])
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
33 average_precision = average_precision_score(y_true, y_scores[:, 1])
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
34 plt.plot(
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
35 recall,
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
36 precision,
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
37 label=f"Precision-Recall Curve (AP={average_precision:.2f})",
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
38 )
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
39 plt.title("Precision-Recall Curve (binary classification)")
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
40 else:
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
41 y_true_bin = label_binarize(y_true, classes=np.unique(y_true))
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
42 n_classes = y_true_bin.shape[1]
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
43 class_labels = [f"Class {i}" for i in range(n_classes)]
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
44 # Plot PR curve for each class
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
45 for i in range(n_classes):
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
46 precision, recall, _ = precision_recall_curve(
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
47 y_true_bin[:, i], y_scores[:, i]
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
48 )
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
49 ap_score = average_precision_score(y_true_bin[:, i], y_scores[:, i])
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
50 plt.plot(
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
51 recall, precision, label=f"{class_labels[i]} (AP = {ap_score:.2f})"
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
52 )
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
53 # Compute micro-average PR curve
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
54 precision, recall, _ = precision_recall_curve(
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
55 y_true_bin.ravel(), y_scores.ravel()
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
56 )
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
57 plt.plot(
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
58 recall, precision, linestyle="--", color="black", label="Micro-average"
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
59 )
5
49b4ee0d0965 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
bgruening
parents: 4
diff changeset
60 plt.title(
49b4ee0d0965 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
bgruening
parents: 4
diff changeset
61 "Precision-Recall Curve (Multiclass Classification)"
49b4ee0d0965 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
bgruening
parents: 4
diff changeset
62 )
4
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
63 plt.xlabel("Recall")
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
64 plt.ylabel("Precision")
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
65 plt.legend(loc="lower left")
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
66 plt.grid(True)
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
67 plt.savefig("output_plot.png")
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
68
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
69
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
70 def regression_plot(xval, yval, title, xlabel, ylabel):
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
71 plt.figure(figsize=(8, 6))
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
72 plt.xlabel(xlabel)
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
73 plt.ylabel(ylabel)
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
74 plt.title(title)
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
75 plt.legend(loc="lower left")
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
76 plt.grid(True)
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
77 plt.scatter(xval, yval, alpha=0.8)
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
78 xticks = np.arange(len(xval))
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
79 plt.plot(xticks, xticks, color="red", linestyle="--", label="y = x")
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
80 plt.savefig("output_plot.png")
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
81
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
82
0
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
83 def train_evaluate(args):
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
84 """
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
85 Train TabPFN and predict
0
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
86 """
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
87 # prepare train data
0
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
88 tr_features, tr_labels = separate_features_labels(args["train_data"])
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
89 # prepare test data
5
49b4ee0d0965 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
bgruening
parents: 4
diff changeset
90 if args["testhaslabels"] == "true":
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
91 te_features, te_labels = separate_features_labels(args["test_data"])
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
92 else:
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
93 te_features = pd.read_csv(args["test_data"], sep="\t")
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
94 te_labels = []
0
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
95 s_time = time.time()
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
96 if args["selected_task"] == "Classification":
5
49b4ee0d0965 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
bgruening
parents: 4
diff changeset
97 classifier = TabPFNClassifier(random_state=42)
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
98 classifier.fit(tr_features, tr_labels)
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
99 y_eval = classifier.predict(te_features)
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
100 pred_probas_test = classifier.predict_proba(te_features)
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
101 if len(te_labels) > 0:
4
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
102 classification_plot(te_labels, pred_probas_test)
5
49b4ee0d0965 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
bgruening
parents: 4
diff changeset
103 te_features["predicted_labels"] = y_eval
49b4ee0d0965 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
bgruening
parents: 4
diff changeset
104 te_features.to_csv(
49b4ee0d0965 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
bgruening
parents: 4
diff changeset
105 "output_predicted_data", sep="\t", index=None
49b4ee0d0965 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
bgruening
parents: 4
diff changeset
106 )
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
107 else:
5
49b4ee0d0965 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
bgruening
parents: 4
diff changeset
108 regressor = TabPFNRegressor(random_state=42)
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
109 regressor.fit(tr_features, tr_labels)
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
110 y_eval = regressor.predict(te_features)
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
111 if len(te_labels) > 0:
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
112 score = root_mean_squared_error(te_labels, y_eval)
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
113 r2_metric_score = r2_score(te_labels, y_eval)
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
114 regression_plot(
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
115 te_labels,
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
116 y_eval,
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
117 f"Scatter plot: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}",
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
118 "True values",
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
119 "Predicted values",
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
120 )
5
49b4ee0d0965 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
bgruening
parents: 4
diff changeset
121 te_features["predicted_labels"] = y_eval
49b4ee0d0965 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
bgruening
parents: 4
diff changeset
122 te_features.to_csv(
49b4ee0d0965 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
bgruening
parents: 4
diff changeset
123 "output_predicted_data", sep="\t", index=None
49b4ee0d0965 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
bgruening
parents: 4
diff changeset
124 )
0
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
125 e_time = time.time()
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
126 print(
5
49b4ee0d0965 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit cefdfdc13838de5108e13f54ecd69babb44009a1
bgruening
parents: 4
diff changeset
127 f"Time taken by TabPFN for training and prediction: {e_time - s_time} seconds"
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
128 )
0
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
129
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
130
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
131 if __name__ == "__main__":
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
132 arg_parser = argparse.ArgumentParser()
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
133 arg_parser.add_argument("-trdata", "--train_data", required=True, help="Train data")
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
134 arg_parser.add_argument("-tedata", "--test_data", required=True, help="Test data")
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
135 arg_parser.add_argument(
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
136 "-testhaslabels",
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
137 "--testhaslabels",
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
138 required=True,
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
139 help="if test data contain labels",
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
140 )
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
141 arg_parser.add_argument(
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
142 "-selectedtask",
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
143 "--selected_task",
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
144 required=True,
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
145 help="Type of machine learning task",
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
146 )
0
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
147 # get argument values
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
148 args = vars(arg_parser.parse_args())
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
149 train_evaluate(args)