Repository 'tabpfn'
hg clone https://radegast.galaxyproject.org/repos/bgruening/tabpfn

Changeset 10:f0c7f0bad621 (2026-04-22)
Previous changeset 9:ed78e1448387 (2026-04-20)
Commit message:
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/tabpfn commit 5f1f7b83ced6c25d0024de3fbcc63f3f9e25373f
modified:
main.py
b
diff -r ed78e1448387 -r f0c7f0bad621 main.py
--- a/main.py Mon Apr 20 08:08:59 2026 +0000
+++ b/main.py Wed Apr 22 21:52:07 2026 +0000
[
@@ -84,6 +84,8 @@
     """
     Train TabPFN and predict
     """
+    MAX_IGNORE_PRETRAINING_LIMITS_SAMPLES = 1000
+    SEED = 42
     # prepare train data
     tr_features, tr_labels = separate_features_labels(args["train_data"], args["train_header"])
     # prepare test data
@@ -94,7 +96,10 @@
         te_labels = []
     s_time = time.time()
     if args["selected_task"] == "Classification":
-        classifier = TabPFNClassifier(random_state=42, model_path=args["model_path"])
+        if tr_features.shape[0] <= MAX_IGNORE_PRETRAINING_LIMITS_SAMPLES:
+            classifier = TabPFNClassifier(random_state=SEED, model_path=args["model_path"])
+        else:
+            classifier = TabPFNClassifier(random_state=SEED, model_path=args["model_path"], ignore_pretraining_limits=True)
         classifier.fit(tr_features, tr_labels)
         y_eval = classifier.predict(te_features)
         pred_probas_test = classifier.predict_proba(te_features)
@@ -105,7 +110,10 @@
             "output_predicted_data", sep="\t", index=None
         )
     else:
-        regressor = TabPFNRegressor(random_state=42, model_path=args["model_path"])
+        if tr_features.shape[0] <= MAX_IGNORE_PRETRAINING_LIMITS_SAMPLES:
+            regressor = TabPFNRegressor(random_state=SEED, model_path=args["model_path"])
+        else:
+            regressor = TabPFNRegressor(random_state=SEED, model_path=args["model_path"], ignore_pretraining_limits=True)
         regressor.fit(tr_features, tr_labels)
         y_eval = regressor.predict(te_features)
         if len(te_labels) > 0: