comparison main.py @ 3:33d53eb476fd draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit c1c3b6d5abd35875890c45baacab073b5e749537
author bgruening
date Mon, 20 Jan 2025 15:45:17 +0000
parents c081e5e1d7ce
children e7b4afedc471
comparison
equal deleted inserted replaced
2:c081e5e1d7ce 3:33d53eb476fd
59 else: 59 else:
60 te_features = pd.read_csv(args["test_data"], sep="\t") 60 te_features = pd.read_csv(args["test_data"], sep="\t")
61 te_labels = [] 61 te_labels = []
62 s_time = time.time() 62 s_time = time.time()
63 if args["selected_task"] == "Classification": 63 if args["selected_task"] == "Classification":
64 classifier = TabPFNClassifier(device="cpu") 64 classifier = TabPFNClassifier()
65 classifier.fit(tr_features, tr_labels) 65 classifier.fit(tr_features, tr_labels)
66 y_eval = classifier.predict(te_features) 66 y_eval = classifier.predict(te_features)
67 pred_probas_test = classifier.predict_proba(te_features) 67 pred_probas_test = classifier.predict_proba(te_features)
68 if len(te_labels) > 0: 68 if len(te_labels) > 0:
69 precision, recall, thresholds = precision_recall_curve( 69 precision, recall, thresholds = precision_recall_curve(
79 "Precision-Recall Curve", 79 "Precision-Recall Curve",
80 "Recall", 80 "Recall",
81 "Precision", 81 "Precision",
82 ) 82 )
83 else: 83 else:
84 regressor = TabPFNRegressor(device="cpu") 84 regressor = TabPFNRegressor()
85 regressor.fit(tr_features, tr_labels) 85 regressor.fit(tr_features, tr_labels)
86 y_eval = regressor.predict(te_features) 86 y_eval = regressor.predict(te_features)
87 if len(te_labels) > 0: 87 if len(te_labels) > 0:
88 score = root_mean_squared_error(te_labels, y_eval) 88 score = root_mean_squared_error(te_labels, y_eval)
89 r2_metric_score = r2_score(te_labels, y_eval) 89 r2_metric_score = r2_score(te_labels, y_eval)