comparison train_test_split.py @ 34:1fe00785190d draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 13:44:18 +0000
parents 4b359039f09f
children
comparison
equal deleted inserted replaced
33:5d5d9cc554f9 34:1fe00785190d
1 import argparse 1 import argparse
2 import json 2 import json
3 import warnings 3 import warnings
4 from distutils.version import LooseVersion as Version
4 5
5 import pandas as pd 6 import pandas as pd
7 from galaxy_ml import __version__ as galaxy_ml_version
6 from galaxy_ml.model_validations import train_test_split 8 from galaxy_ml.model_validations import train_test_split
7 from galaxy_ml.utils import get_cv, read_columns 9 from galaxy_ml.utils import get_cv, read_columns
8 10
9 11
10 def _get_single_cv_split(params, array, infile_labels=None, infile_groups=None): 12 def _get_single_cv_split(params, array, infile_labels=None, infile_groups=None):
67 col_index = target_input["col"][0] - 1 69 col_index = target_input["col"][0] - 1
68 df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True) 70 df = pd.read_csv(infile_labels, sep="\t", header=header, parse_dates=True)
69 y = df.iloc[:, col_index].values 71 y = df.iloc[:, col_index].values
70 72
71 # construct the cv splitter object 73 # construct the cv splitter object
72 splitter, groups = get_cv(params["mode_selection"]["cv_selector"]) 74 cv_selector = params["mode_selection"]["cv_selector"]
75 if Version(galaxy_ml_version) < Version("0.8.3"):
76 cv_selector.pop("n_stratification_bins", None)
77 splitter, groups = get_cv(cv_selector)
73 78
74 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups) 79 total_n_splits = splitter.get_n_splits(array.values, y=y, groups=groups)
75 if nth_split > total_n_splits: 80 if nth_split > total_n_splits:
76 raise ValueError( 81 raise ValueError(
77 "Total number of splits is {}, but got `nth_split` " 82 "Total number of splits is {}, but got `nth_split` "