Mercurial > repos > bgruening > sklearn_train_test_eval
comparison train_test_split.py @ 15:2eb5c017958d draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author | bgruening |
---|---|
date | Wed, 09 Aug 2023 13:15:27 +0000 |
parents | caf7d2b71a48 |
children |
comparison legend: equal | deleted | inserted | replaced
14:4d1637cac794 | 15:2eb5c017958d |
---|---|
1 import argparse | 1 import argparse |
2 import json | 2 import json |
3 import warnings | 3 import warnings |
4 from distutils.version import LooseVersion as Version | |
4 | 5 |
5 import pandas as pd | 6 import pandas as pd |
7 from galaxy_ml import __version__ as galaxy_ml_version | |
6 from galaxy_ml.model_validations import train_test_split | 8 from galaxy_ml.model_validations import train_test_split |
7 from galaxy_ml.utils import get_cv, read_columns | 9 from galaxy_ml.utils import get_cv, read_columns |
8 | 10 |
9 | 11 |
10 def _get_single_cv_split(params, array, infile_labels=None, infile_groups=None): | 12 def _get_single_cv_split(params, array, infile_labels=None, infile_groups=None): |
67 col_index = target_input["col"][0] - 1 | 69 col_index = target_input["col"][0] - 1 |
68 df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True) | 70 df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True) |
69 y = df.iloc[:, col_index].values | 71 y = df.iloc[:, col_index].values |
70 | 72 |
71 # construct the cv splitter object | 73 # construct the cv splitter object |
72 splitter, groups = get_cv(params["mode_selection"]["cv_selector"]) | 74 cv_selector = params["mode_selection"]["cv_selector"] |
75 if Version(galaxy_ml_version) < Version("0.8.3"): | |
76 cv_selector.pop("n_stratification_bins", None) | |
77 splitter, groups = get_cv(cv_selector) | |
73 | 78 |
74 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups) | 79 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups) |
75 if nth_split > total_n_splits: | 80 if nth_split > total_n_splits: |
76 raise ValueError( | 81 raise ValueError( |
77 "Total number of splits is {}, but got `nth_split` " | 82 "Total number of splits is {}, but got `nth_split` " |