Mercurial > repos > bgruening > tabpfn
annotate main.py @ 0:3dc3c7443c8e draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
| author | bgruening |
|---|---|
| date | Wed, 15 Jan 2025 12:33:49 +0000 |
| parents | |
| children | c081e5e1d7ce |
| rev | line source |
|---|---|
|
0
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
1 """ |
|
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
2 Tabular data prediction using TabPFN |
|
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
3 """ |
|
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
4 import argparse |
|
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
5 import time |
|
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
6 |
|
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
7 import matplotlib.pyplot as plt |
|
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
8 import pandas as pd |
|
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
9 from sklearn.metrics import accuracy_score, average_precision_score, precision_recall_curve |
|
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
10 from tabpfn import TabPFNClassifier |
|
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
11 |
|
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
12 |
|
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
def separate_features_labels(data):
    """Load a tab-separated table and split it into features and labels.

    The first row is read as the header (pandas default). The last column
    is taken as the label column; every column before it is a feature.

    Args:
        data: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        A ``(features, labels)`` tuple: ``features`` is a DataFrame holding
        all columns but the last, ``labels`` is a Series holding the last
        column.

    Raises:
        ValueError: If the table has fewer than two columns, i.e. there is
            no feature column left once the label column is removed.
    """
    df = pd.read_csv(data, sep="\t")
    # Fail early with a clear message instead of handing an empty feature
    # frame to the classifier.
    if df.shape[1] < 2:
        raise ValueError(
            "Expected at least one feature column and a label column, "
            f"but found {df.shape[1]} column(s)"
        )
    features = df.iloc[:, :-1]
    labels = df.iloc[:, -1]
    return features, labels
|
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
18 |
|
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
19 |
|
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
def _plot_precision_recall(te_labels, pred_probas, classes, out_path):
    """Draw precision-recall curve(s) and save the figure to ``out_path``.

    For two classes a single curve is drawn for ``classes[1]`` (the column
    the probabilities are read from). For any other number of classes one
    one-vs-rest curve is drawn per class.

    Args:
        te_labels: True labels of the test rows (pandas Series).
        pred_probas: 2-D array of class probabilities, one column per entry
            of ``classes``.
        classes: Class labels in the column order of ``pred_probas``.
        out_path: File name the figure is written to.
    """
    plt.figure(figsize=(8, 6))
    if len(classes) == 2:
        curves = [(1, 'Precision-Recall Curve')]
    else:
        curves = [(idx, f'Class {cls}') for idx, cls in enumerate(classes)]
    for idx, curve_name in curves:
        # Binarize against the class of this column so the metrics work for
        # any label values (strings, non-0/1 integers, more than two classes).
        is_positive = (te_labels == classes[idx]).astype(int)
        scores = pred_probas[:, idx]
        precision, recall, _ = precision_recall_curve(is_positive, scores)
        average_precision = average_precision_score(is_positive, scores)
        plt.plot(recall, precision, label=f'{curve_name} (AP={average_precision:.2f})')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc='lower left')
    plt.grid(True)
    plt.savefig(out_path)
    # Release the figure; pyplot keeps every open figure alive otherwise.
    plt.close()


def train_evaluate(args):
    """Fit TabPFN on the training table and evaluate it on the test table.

    Both tables are tab-separated with the label in the last column. The
    function prints the fit time and the test accuracy, writes the test
    features plus a ``predicted_labels`` column to ``output_predicted_data``
    (tab-separated, no index), and saves the precision-recall plot to
    ``output_prec_recall_curve.png``. Both files are written relative to the
    current working directory.

    Args:
        args: Mapping with the keys ``"train_data"`` and ``"test_data"``,
            each a path to a tab-separated file.
    """
    tr_features, tr_labels = separate_features_labels(args["train_data"])
    te_features, te_labels = separate_features_labels(args["test_data"])
    classifier = TabPFNClassifier(device='cpu')
    # perf_counter is monotonic, so the interval cannot be skewed by
    # system clock adjustments.
    s_time = time.perf_counter()
    classifier.fit(tr_features, tr_labels)
    e_time = time.perf_counter()
    print("Time taken by TabPFN for training: {} seconds".format(e_time - s_time))
    y_eval = classifier.predict(te_features)
    print('Accuracy', accuracy_score(te_labels, y_eval))
    # Probabilities are computed before the prediction column is appended,
    # so the model only ever sees the original feature columns.
    pred_probas_test = classifier.predict_proba(te_features)
    te_features["predicted_labels"] = y_eval
    te_features.to_csv("output_predicted_data", sep="\t", index=None)
    _plot_precision_recall(
        te_labels,
        pred_probas_test,
        classifier.classes_,
        "output_prec_recall_curve.png",
    )
|
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
46 |
|
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
47 |
|
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
bgruening
parents:
diff
changeset
|
if __name__ == "__main__":
    # Command-line entry point: both input tables are mandatory.
    parser = argparse.ArgumentParser()
    for short_flag, long_flag, description in (
        ("-trdata", "--train_data", "Train data"),
        ("-tedata", "--test_data", "Test data"),
    ):
        parser.add_argument(short_flag, long_flag, required=True, help=description)
    # train_evaluate takes a plain dict keyed by the long option names.
    train_evaluate(vars(parser.parse_args()))
