# HG changeset patch
# User goeckslab
# Date 1752023581 0
# Node ID f4cb41f458fd37a005c0f3685c070950e4c7b4b7
# Parent a32ff7201629386941e484aa03623a8e7629604e
planemo upload for repository https://github.com/goeckslab/gleam commit b430f8b466655878c3bf63b053655fdbf039ddb0
diff -r a32ff7201629 -r f4cb41f458fd base_model_trainer.py
--- a/base_model_trainer.py Wed Jul 02 19:00:03 2025 +0000
+++ b/base_model_trainer.py Wed Jul 09 01:13:01 2025 +0000
@@ -127,9 +127,11 @@
and self.cross_validation is not None
and self.cross_validation is False
):
- self.setup_params["cross_validation"] = self.cross_validation
+ logging.info(
+ "cross_validation is set to False. This will disable cross-validation."
+ )
- if hasattr(self, "cross_validation") and self.cross_validation is not None:
+ if hasattr(self, "cross_validation") and self.cross_validation:
if hasattr(self, "cross_validation_folds"):
self.setup_params["fold"] = self.cross_validation_folds
@@ -182,10 +184,11 @@
)
if hasattr(self, "models") and self.models is not None:
- self.best_model = self.exp.compare_models(include=self.models)
+ self.best_model = self.exp.compare_models(include=self.models, cross_validation=self.cross_validation)
else:
- self.best_model = self.exp.compare_models()
+ self.best_model = self.exp.compare_models(cross_validation=self.cross_validation)
self.results = self.exp.pull()
+
if self.task_type == "classification":
self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
@@ -314,7 +317,7 @@
html_content += (
""
'
'
- "
Model Metrics from Cross-Validation Set
"
+ f"Model Metrics from {'Cross-Validation Set' if self.cross_validation else 'Validation set'}
"
f"Best Model: {model_name}
"
"The best model is selected by: Accuracy (Classification)"
" or R2 (Regression).
"
diff -r a32ff7201629 -r f4cb41f458fd feature_importance.py
--- a/feature_importance.py Wed Jul 02 19:00:03 2025 +0000
+++ b/feature_importance.py Wed Jul 09 01:13:01 2025 +0000
@@ -120,6 +120,9 @@
used_features = model.feature_name_
elif hasattr(model, "booster_") and hasattr(model.booster_, "feature_name"):
used_features = model.booster_.feature_name()
+ elif hasattr(model, "feature_names_in_"):
+ # scikitālearn's standard attribute for the names of features used during fit
+ used_features = list(model.feature_names_in_)
else:
used_features = X_transformed.columns
@@ -130,7 +133,14 @@
plot_X = X_shap
plot_title = f"SHAP Summary for {model_class_name} (TreeExplainer)"
else:
- sampled_X = X_transformed[used_features].sample(100, random_state=42)
+ logging.warning(f"len(X_transformed) = {len(X_transformed)}")
+ max_samples = 100
+ n_samples = min(max_samples, len(X_transformed))
+ sampled_X = X_transformed[used_features].sample(
+ n=n_samples,
+ replace=False,
+ random_state=42
+ )
explainer = shap.KernelExplainer(model.predict, sampled_X)
shap_values = explainer.shap_values(sampled_X)
plot_X = sampled_X
diff -r a32ff7201629 -r f4cb41f458fd pycaret_train.py
--- a/pycaret_train.py Wed Jul 02 19:00:03 2025 +0000
+++ b/pycaret_train.py Wed Jul 09 01:13:01 2025 +0000
@@ -29,6 +29,9 @@
parser.add_argument("--cross_validation", action="store_true",
default=None,
help="Perform cross-validation for PyCaret setup")
+ parser.add_argument("--no_cross_validation", action="store_true",
+ default=None,
+ help="Don't perform cross-validation for PyCaret setup")
parser.add_argument("--cross_validation_folds", type=int,
default=None,
help="Number of cross-validation folds \
@@ -62,11 +65,15 @@
args = parser.parse_args()
+ cross_validation = True
+ if args.no_cross_validation:
+ cross_validation = False
+
model_kwargs = {
"train_size": args.train_size,
"normalize": args.normalize,
"feature_selection": args.feature_selection,
- "cross_validation": args.cross_validation,
+ "cross_validation": cross_validation,
"cross_validation_folds": args.cross_validation_folds,
"remove_outliers": args.remove_outliers,
"remove_multicollinearity": args.remove_multicollinearity,
diff -r a32ff7201629 -r f4cb41f458fd test-data/expected_best_model_classification_customized_cross_off.csv
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/expected_best_model_classification_customized_cross_off.csv Wed Jul 09 01:13:01 2025 +0000
@@ -0,0 +1,3 @@
+Parameter,Value
+priors,
+var_smoothing,1e-09
diff -r a32ff7201629 -r f4cb41f458fd test-data/expected_model_classification_customized_cross_off.h5
Binary file test-data/expected_model_classification_customized_cross_off.h5 has changed