Mercurial > repos > bgruening > tabpfn
annotate main.py @ 0:3dc3c7443c8e draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
| author | bgruening | 
|---|---|
| date | Wed, 15 Jan 2025 12:33:49 +0000 | 
| parents | |
| children | c081e5e1d7ce | 
| rev | line source | 
|---|---|
| 
0
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
1 """ | 
| 
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
2 Tabular data prediction using TabPFN | 
| 
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
3 """ | 
| 
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
4 import argparse | 
| 
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
5 import time | 
| 
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
6 | 
| 
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
7 import matplotlib.pyplot as plt | 
| 
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
8 import pandas as pd | 
| 
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
9 from sklearn.metrics import accuracy_score, average_precision_score, precision_recall_curve | 
| 
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
10 from tabpfn import TabPFNClassifier | 
| 
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
11 | 
| 
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
12 | 
| 
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
def separate_features_labels(data):
    """
    Load a tab-separated table and split it into features and labels.

    The last column of the table is taken as the label column; every
    other column is treated as a feature.

    :param data: path or file-like object pointing to a tab-separated
        table with a header row.
    :return: tuple ``(features, labels)`` where ``features`` is a
        DataFrame of all columns but the last and ``labels`` is the last
        column as a Series.
    """
    table = pd.read_csv(data, sep="\t")
    return table.iloc[:, :-1], table.iloc[:, -1]
| 
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
18 | 
| 
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
19 | 
| 
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
def train_evaluate(args):
    """
    Train TabPFN on the training table and evaluate it on the test table.

    :param args: dict of parsed command-line arguments; must contain
        ``"train_data"`` and ``"test_data"``, each a path (or file-like
        object) to a tab-separated table whose last column holds the labels.

    Side effects:
      - prints the training time and the test accuracy to stdout;
      - writes the test features plus a ``predicted_labels`` column to
        ``output_predicted_data`` (tab-separated, no index column);
      - saves a precision-recall curve to ``output_prec_recall_curve.png``.

    The precision-recall curve assumes a binary task: the probability of
    ``classifier.classes_[1]`` is used as the positive-class score.
    """
    tr_features, tr_labels = separate_features_labels(args["train_data"])
    te_features, te_labels = separate_features_labels(args["test_data"])
    classifier = TabPFNClassifier(device='cpu')

    # perf_counter is the monotonic clock intended for measuring durations.
    s_time = time.perf_counter()
    classifier.fit(tr_features, tr_labels)
    e_time = time.perf_counter()
    print("Time taken by TabPFN for training: {} seconds".format(e_time - s_time))

    y_eval = classifier.predict(te_features)
    print('Accuracy', accuracy_score(te_labels, y_eval))
    pred_probas_test = classifier.predict_proba(te_features)

    # te_features is an iloc slice of the loaded table; assign() builds a new
    # frame instead of adding a column to the slice (SettingWithCopyWarning).
    te_features.assign(predicted_labels=y_eval).to_csv(
        "output_predicted_data", sep="\t", index=None
    )

    # predict_proba columns follow classifier.classes_, so column 1 scores
    # classes_[1]. Passing it as pos_label keeps the curve correct for label
    # encodings other than {0, 1} (e.g. strings), where sklearn would raise.
    pos_label = classifier.classes_[1]
    pos_scores = pred_probas_test[:, 1]
    precision, recall, _ = precision_recall_curve(
        te_labels, pos_scores, pos_label=pos_label
    )
    average_precision = average_precision_score(
        te_labels, pos_scores, pos_label=pos_label
    )

    fig = plt.figure(figsize=(8, 6))
    plt.plot(recall, precision, label=f'Precision-Recall Curve (AP={average_precision:.2f})')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc='lower left')
    plt.grid(True)
    plt.savefig("output_prec_recall_curve.png")
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)
| 
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
46 | 
| 
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
47 | 
| 
 
3dc3c7443c8e
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit bce8b0297bff54e7e29a6106a7f385fd1318c0aa
 
bgruening 
parents:  
diff
changeset
 | 
if __name__ == "__main__":
    # Command-line entry point: collect the two required table paths and
    # hand them to train_evaluate as a plain dict.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-trdata", "--train_data", required=True, help="Train data"
    )
    parser.add_argument(
        "-tedata", "--test_data", required=True, help="Test data"
    )
    cli_args = parser.parse_args()
    train_evaluate(vars(cli_args))
