Mercurial > repos > bgruening > tabpfn
comparison main.py @ 3:33d53eb476fd draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit c1c3b6d5abd35875890c45baacab073b5e749537
author | bgruening |
---|---|
date | Mon, 20 Jan 2025 15:45:17 +0000 |
parents | c081e5e1d7ce |
children | e7b4afedc471 |
comparison
equal
deleted
inserted
replaced
2:c081e5e1d7ce | 3:33d53eb476fd |
---|---|
59 else: | 59 else: |
60 te_features = pd.read_csv(args["test_data"], sep="\t") | 60 te_features = pd.read_csv(args["test_data"], sep="\t") |
61 te_labels = [] | 61 te_labels = [] |
62 s_time = time.time() | 62 s_time = time.time() |
63 if args["selected_task"] == "Classification": | 63 if args["selected_task"] == "Classification": |
64 classifier = TabPFNClassifier(device="cpu") | 64 classifier = TabPFNClassifier() |
65 classifier.fit(tr_features, tr_labels) | 65 classifier.fit(tr_features, tr_labels) |
66 y_eval = classifier.predict(te_features) | 66 y_eval = classifier.predict(te_features) |
67 pred_probas_test = classifier.predict_proba(te_features) | 67 pred_probas_test = classifier.predict_proba(te_features) |
68 if len(te_labels) > 0: | 68 if len(te_labels) > 0: |
69 precision, recall, thresholds = precision_recall_curve( | 69 precision, recall, thresholds = precision_recall_curve( |
79 "Precision-Recall Curve", | 79 "Precision-Recall Curve", |
80 "Recall", | 80 "Recall", |
81 "Precision", | 81 "Precision", |
82 ) | 82 ) |
83 else: | 83 else: |
84 regressor = TabPFNRegressor(device="cpu") | 84 regressor = TabPFNRegressor() |
85 regressor.fit(tr_features, tr_labels) | 85 regressor.fit(tr_features, tr_labels) |
86 y_eval = regressor.predict(te_features) | 86 y_eval = regressor.predict(te_features) |
87 if len(te_labels) > 0: | 87 if len(te_labels) > 0: |
88 score = root_mean_squared_error(te_labels, y_eval) | 88 score = root_mean_squared_error(te_labels, y_eval) |
89 r2_metric_score = r2_score(te_labels, y_eval) | 89 r2_metric_score = r2_score(te_labels, y_eval) |