annotate main.py @ 4:e7b4afedc471 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
author bgruening
date Tue, 11 Feb 2025 10:14:12 +0000
parents 33d53eb476fd
children 49b4ee0d0965
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
0
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
1 """
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
2 Tabular data prediction using TabPFN
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
3 """
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
4 import argparse
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
5 import time
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
6
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
7 import matplotlib.pyplot as plt
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
8 import numpy as np
0
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
9 import pandas as pd
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
10 from sklearn.metrics import (
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
11 average_precision_score,
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
12 precision_recall_curve,
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
13 r2_score,
4
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
14 root_mean_squared_error,
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
15 )
4
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
16 from sklearn.preprocessing import label_binarize
2
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
17 from tabpfn import TabPFNClassifier, TabPFNRegressor
0
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
18
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
19
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
def separate_features_labels(data):
    """Load a tab-separated table and split it into features and target.

    :param data: path (or file-like object) of a TSV file with a header row,
        passed straight to ``pandas.read_csv``.
    :returns: ``(features, labels)`` where ``features`` is a DataFrame holding
        every column except the last one and ``labels`` is the last column.
    """
    table = pd.read_csv(data, sep="\t")
    target_pos = table.shape[1] - 1
    # Last column is the target; everything before it is a feature.
    return table.iloc[:, :target_pos], table.iloc[:, target_pos]
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
25
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
26
4
e7b4afedc471 planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit f514c9038f1aac1ef0ca40a9f4866b1ad0fc7747
bgruening
parents: 3
diff changeset
def classification_plot(y_true, y_scores, classes=None):
    """Plot precision-recall curves for predicted class probabilities.

    The figure is written to ``output_plot.png`` in the working directory.
    A binary problem gets a single curve. A problem with more than two classes
    gets one one-vs-rest curve per class plus a micro-averaged curve.

    :param y_true: true labels of the evaluated samples.
    :param y_scores: 2-D array of class probabilities, one column per class
        (as returned by ``predict_proba``).
    :param classes: class labels in the column order of ``y_scores``. For a
        fitted sklearn-style classifier this would be its ``classes_``.
        Defaults to the sorted unique values of ``y_true``. That default only
        lines up with the probability columns when every training class
        occurs in ``y_true``.
    """
    y_true = np.asarray(y_true)
    y_scores = np.asarray(y_scores)
    classes = np.unique(y_true) if classes is None else np.asarray(classes)

    plt.figure(figsize=(8, 6))
    if len(classes) == 2:
        # Column 1 of the probabilities scores classes[1]. Passing pos_label
        # explicitly lets sklearn accept labels other than {0, 1} / {-1, 1}
        # (e.g. string labels or {1, 2}), which otherwise raise ValueError.
        positive = classes[1]
        precision, recall, _ = precision_recall_curve(
            y_true, y_scores[:, 1], pos_label=positive
        )
        average_precision = average_precision_score(
            y_true, y_scores[:, 1], pos_label=positive
        )
        plt.plot(
            recall,
            precision,
            label=f"Precision-Recall Curve (AP={average_precision:.2f})",
        )
        plt.title("Precision-Recall Curve (binary classification)")
    else:
        # One-vs-rest indicator matrix whose columns follow `classes`.
        y_true_bin = label_binarize(y_true, classes=classes)
        # Plot PR curve for each class, labelled with the actual class value.
        for i, label in enumerate(classes):
            precision, recall, _ = precision_recall_curve(
                y_true_bin[:, i], y_scores[:, i]
            )
            ap_score = average_precision_score(y_true_bin[:, i], y_scores[:, i])
            plt.plot(
                recall, precision, label=f"Class {label} (AP = {ap_score:.2f})"
            )
        # Micro-average: pool every (sample, class) decision into one curve.
        precision, recall, _ = precision_recall_curve(
            y_true_bin.ravel(), y_scores.ravel()
        )
        plt.plot(
            recall, precision, linestyle="--", color="black", label="Micro-average"
        )
        plt.title("Precision-Recall Curve (Multiclass Classification)")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend(loc="lower left")
    plt.grid(True)
    plt.savefig("output_plot.png")
    # Release the figure; pyplot otherwise keeps it referenced.
    plt.close()
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
66
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
67
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
def regression_plot(xval, yval, title, xlabel, ylabel):
    """Save a true-vs-predicted scatter plot to ``output_plot.png``.

    A dashed ``y = x`` reference line spanning the data range is drawn, so
    points on the line are perfect predictions.

    :param xval: true target values (x axis).
    :param yval: predicted values (y axis), same length as ``xval``.
    :param title: plot title.
    :param xlabel: x-axis label.
    :param ylabel: y-axis label.
    """
    xval = np.asarray(xval, dtype=float)
    yval = np.asarray(yval, dtype=float)
    plt.figure(figsize=(8, 6))
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.scatter(xval, yval, alpha=0.8)
    if xval.size and yval.size:
        # Span the identity line over the value range of both axes, so it is
        # comparable with the scatter points (not over sample indices).
        lo = min(np.nanmin(xval), np.nanmin(yval))
        hi = max(np.nanmax(xval), np.nanmax(yval))
        plt.plot([lo, hi], [lo, hi], color="red", linestyle="--", label="y = x")
        # Build the legend only after a labelled artist exists; otherwise it
        # is empty and matplotlib warns.
        plt.legend(loc="lower left")
    plt.savefig("output_plot.png")
    # Release the figure; pyplot otherwise keeps it referenced.
    plt.close()
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
79
c081e5e1d7ce planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit e87b82b59dced736af2f0d9de045c916400b7bc2
bgruening
parents: 0
diff changeset
80
0
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
def train_evaluate(args):
    """Fit a TabPFN model on the training table and predict the test table.

    :param args: dict of parsed command-line values with keys ``train_data``,
        ``test_data``, ``testhaslabels`` (``"haslabels"`` when the test table
        ends with a label column) and ``selected_task`` (``"Classification"``;
        any other value selects regression).

    Side effects: prints the elapsed fit/predict time. When test labels are
    present, it saves an evaluation plot through the plotting helpers. It
    writes the test features plus a ``predicted_labels`` column to
    ``output_predicted_data`` (tab-separated).
    """
    train_x, train_y = separate_features_labels(args["train_data"])

    # Test data may or may not carry a trailing label column.
    if args["testhaslabels"] == "haslabels":
        test_x, test_y = separate_features_labels(args["test_data"])
    else:
        test_x, test_y = pd.read_csv(args["test_data"], sep="\t"), []

    started = time.time()
    is_classification = args["selected_task"] == "Classification"
    model = TabPFNClassifier() if is_classification else TabPFNRegressor()
    model.fit(train_x, train_y)
    predictions = model.predict(test_x)

    if is_classification:
        probabilities = model.predict_proba(test_x)
        if len(test_y) > 0:
            classification_plot(test_y, probabilities)
    elif len(test_y) > 0:
        rmse = root_mean_squared_error(test_y, predictions)
        r2 = r2_score(test_y, predictions)
        regression_plot(
            test_y,
            predictions,
            f"Scatter plot: True vs predicted values. RMSE={rmse:.2f}, R2={r2:.2f}",
            "True values",
            "Predicted values",
        )

    elapsed = time.time() - started
    print(f"Time taken by TabPFN for training and prediction: {elapsed} seconds")

    test_x["predicted_labels"] = predictions
    test_x.to_csv("output_predicted_data", sep="\t", index=None)
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
123
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
124
3dc3c7443c8e planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff changeset
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Input tables (tab-separated).
    parser.add_argument("-trdata", "--train_data", required=True, help="Train data")
    parser.add_argument("-tedata", "--test_data", required=True, help="Test data")
    # Run options.
    parser.add_argument(
        "-testhaslabels",
        "--testhaslabels",
        required=True,
        help="if test data contain labels",
    )
    parser.add_argument(
        "-selectedtask",
        "--selected_task",
        required=True,
        help="Type of machine learning task",
    )
    # Hand the parsed options over as a plain dict.
    train_evaluate(vars(parser.parse_args()))