goeckslab/tabular_learner: comparison of base_model_trainer.py @ 15:01e7c5481f13 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam commit f632803cda732005bdcf3ac3e8fe7a807a82c1d9
| author | goeckslab |
|---|---|
| date | Mon, 19 Jan 2026 05:54:52 +0000 |
| parents | bf0df21a1ea3 |
| children | |
| 14:edd515746388 | 15:01e7c5481f13 |
|---|---|
| 117 f"Target column number {self.target_col} is invalid. " | 117 f"Target column number {self.target_col} is invalid. " |
| 118 f"Please select a number between 1 and {num_cols}." | 118 f"Please select a number between 1 and {num_cols}." |
| 119 ) | 119 ) |
| 120 | 120 |
| 121 self.target = names[target_index] | 121 self.target = names[target_index] |
| 122 sample_id_column = getattr(self, "sample_id_column", None) | |
| 123 if sample_id_column: | |
| 124 sample_id_column = sample_id_column.replace(".", "_") | |
| 125 self.sample_id_column = sample_id_column | |
| 126 else: | |
| 127 self.sample_id_column = None | |
| 128 self.sample_id_series = None | |
| 122 | 129 |
| 123 # Conditional drop: only if 'prediction_label' exists and is not | 130 # Conditional drop: only if 'prediction_label' exists and is not |
| 124 # the target | 131 # the target |
| 125 if "prediction_label" in self.data.columns and ( | 132 if "prediction_label" in self.data.columns and ( |
| 126 self.data.columns[target_index] != "prediction_label" | 133 self.data.columns[target_index] != "prediction_label" |
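
The added lines in the hunk above (right-hand revision) normalize the user-supplied sample ID column name the same way the loader normalizes dataframe column names, replacing dots with underscores, so later lookups against `self.data.columns` match. A minimal standalone sketch of that normalization; the `patient.id` and `outcome` names are made up for illustration:

```python
import pandas as pd

# The loader renames dataframe columns with str.replace(".", "_"), so the
# user-supplied sample ID column name must be normalized the same way
# before it can be looked up in the dataframe's columns.
df = pd.DataFrame({"patient.id": ["a", "a", "b"], "outcome": [0, 1, 0]})
df.columns = df.columns.str.replace(".", "_", regex=False)  # regex=False: literal dot

sample_id_column = "patient.id"                       # value as supplied by the user
sample_id_column = sample_id_column.replace(".", "_")
assert sample_id_column in df.columns                 # now matches "patient_id"
```
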
| 152 | 159 |
| 153 # Update names after possible drop | 160 # Update names after possible drop |
| 154 names = self.data.columns.to_list() | 161 names = self.data.columns.to_list() |
| 155 LOG.info(f"Dataset columns after processing: {names}") | 162 LOG.info(f"Dataset columns after processing: {names}") |
| 156 | 163 |
| 157 self.features_name = [n for n in names if n != self.target] | 164 sample_id_valid = False |
| 158 self.plot_feature_names = self._select_plot_features(self.features_name) | 165 if sample_id_column: |
| 166 if sample_id_column not in self.data.columns: | |
| 167 LOG.warning( | |
| 168 "Sample ID column '%s' not found; proceeding without group-aware split.", | |
| 169 sample_id_column, | |
| 170 ) | |
| 171 sample_id_column = None | |
| 172 self.sample_id_column = None | |
| 173 elif sample_id_column == self.target: | |
| 174 LOG.warning( | |
| 175 "Sample ID column '%s' matches target column; skipping group-aware split.", | |
| 176 sample_id_column, | |
| 177 ) | |
| 178 sample_id_column = None | |
| 179 self.sample_id_column = None | |
| 180 else: | |
| 181 sample_id_valid = True | |
| 159 | 182 |
| 160 if self.test_file: | 183 if self.test_file: |
| 161 LOG.info(f"Loading test data from {self.test_file}") | 184 LOG.info(f"Loading test data from {self.test_file}") |
| 162 df_test = pd.read_csv( | 185 df_test = pd.read_csv( |
| 163 self.test_file, sep=None, engine="python" | 186 self.test_file, sep=None, engine="python" |
| 164 ) | 187 ) |
| 165 df_test.columns = df_test.columns.str.replace(".", "_") | 188 df_test.columns = df_test.columns.str.replace(".", "_") |
| 166 self.test_data = df_test | 189 self.test_data = df_test |
| 190 | |
| 191 if sample_id_valid and self.test_data is None: | |
| 192 train_size = getattr(self, "train_size", None) | |
| 193 if train_size is None: | |
| 194 train_size = 0.7 | |
| 195 if train_size <= 0 or train_size >= 1: | |
| 196 LOG.warning( | |
| 197 "Invalid train_size=%s; skipping group-aware split.", | |
| 198 train_size, | |
| 199 ) | |
| 200 else: | |
| 201 rng = np.random.RandomState(self.random_seed) | |
| 202 | |
| 203 def _allocate_split_counts(n_total: int, probs: list) -> list: | |
| 204 if n_total <= 0: | |
| 205 return [0 for _ in probs] | |
| 206 counts = [0 for _ in probs] | |
| 207 active = [i for i, p in enumerate(probs) if p > 0] | |
| 208 remainder = n_total | |
| 209 if active and n_total >= len(active): | |
| 210 for i in active: | |
| 211 counts[i] = 1 | |
| 212 remainder -= len(active) | |
| 213 if remainder > 0: | |
| 214 probs_arr = np.array(probs, dtype=float) | |
| 215 probs_arr = probs_arr / probs_arr.sum() | |
| 216 raw = remainder * probs_arr | |
| 217 floors = np.floor(raw).astype(int) | |
| 218 for i, value in enumerate(floors.tolist()): | |
| 219 counts[i] += value | |
| 220 leftover = remainder - int(floors.sum()) | |
| 221 if leftover > 0 and active: | |
| 222 frac = raw - floors | |
| 223 order = sorted(active, key=lambda i: (-frac[i], i)) | |
| 224 for i in range(leftover): | |
| 225 counts[order[i % len(order)]] += 1 | |
| 226 return counts | |
| 227 | |
| 228 def _choose_split(counts: list, targets: list, active: list) -> int: | |
| 229 remaining = [targets[i] - counts[i] for i in range(len(targets))] | |
| 230 best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i])) | |
| 231 if remaining[best] <= 0: | |
| 232 best = min(active, key=lambda i: counts[i]) | |
| 233 return best | |
| 234 | |
| 235 probs = [train_size, 1.0 - train_size] | |
| 236 targets = _allocate_split_counts(len(self.data), probs) | |
| 237 counts = [0, 0] | |
| 238 active = [0, 1] | |
| 239 train_idx = [] | |
| 240 test_idx = [] | |
| 241 | |
| 242 group_series = self.data[sample_id_column].astype(object) | |
| 243 missing_mask = group_series.isna() | |
| 244 if missing_mask.any(): | |
| 245 group_series = group_series.copy() | |
| 246 group_series.loc[missing_mask] = [ | |
| 247 f"__missing__{idx}" for idx in group_series.index[missing_mask] | |
| 248 ] | |
| 249 group_to_indices = {} | |
| 250 for idx, group_id in group_series.items(): | |
| 251 group_to_indices.setdefault(group_id, []).append(idx) | |
| 252 | |
| 253 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x)) | |
| 254 rng.shuffle(group_ids) | |
| 255 | |
| 256 for group_id in group_ids: | |
| 257 split_idx = _choose_split(counts, targets, active) | |
| 258 counts[split_idx] += len(group_to_indices[group_id]) | |
| 259 if split_idx == 0: | |
| 260 train_idx.extend(group_to_indices[group_id]) | |
| 261 else: | |
| 262 test_idx.extend(group_to_indices[group_id]) | |
| 263 | |
| 264 missing_splits = [] | |
| 265 if not train_idx: | |
| 266 missing_splits.append("train") | |
| 267 if not test_idx: | |
| 268 missing_splits.append("test") | |
| 269 if missing_splits: | |
| 270 LOG.warning( | |
| 271 "Group-aware split using '%s' produced empty %s set; " | |
| 272 "falling back to default split.", | |
| 273 sample_id_column, | |
| 274 " and ".join(missing_splits), | |
| 275 ) | |
| 276 else: | |
| 277 self.test_data = self.data.loc[test_idx].reset_index(drop=True) | |
| 278 self.data = self.data.loc[train_idx].reset_index(drop=True) | |
| 279 LOG.info( | |
| 280 "Applied group-aware split using '%s' (train=%s, test=%s).", | |
| 281 sample_id_column, | |
| 282 len(train_idx), | |
| 283 len(test_idx), | |
| 284 ) | |
| 285 | |
| 286 if sample_id_valid: | |
| 287 self.sample_id_series = self.data[sample_id_column].copy() | |
| 288 if sample_id_column in self.data.columns: | |
| 289 self.data = self.data.drop(columns=[sample_id_column]) | |
| 290 if self.test_data is not None and sample_id_column in self.test_data.columns: | |
| 291 self.test_data = self.test_data.drop(columns=[sample_id_column]) | |
| 292 | |
| 293 # Refresh feature lists after any sample-id column removal. | |
| 294 names = self.data.columns.to_list() | |
| 295 self.features_name = [n for n in names if n != self.target] | |
| 296 self.plot_feature_names = self._select_plot_features(self.features_name) | |
| 167 | 297 |
| 168 def _select_plot_features(self, all_features): | 298 def _select_plot_features(self, all_features): |
| 169 limit = getattr(self, "plot_feature_limit", 30) | 299 limit = getattr(self, "plot_feature_limit", 30) |
| 170 if not isinstance(limit, int) or limit <= 0: | 300 if not isinstance(limit, int) or limit <= 0: |
| 171 LOG.info( | 301 LOG.info( |
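
The hunk above adds a group-aware train/test split: every row sharing a sample ID lands in the same split, and a greedy allocator keeps the row counts close to the requested `train_size`. As a point of comparison only (this is not the trainer's own allocation logic), the same no-group-straddling guarantee can be sketched with scikit-learn's `GroupShuffleSplit`; the `group_aware_split` helper and the `patient_id` column are hypothetical:

```python
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit


def group_aware_split(df: pd.DataFrame, group_col: str,
                      train_size: float = 0.7, seed: int = 42):
    """Split rows so that every group (sample ID) stays in exactly one split."""
    splitter = GroupShuffleSplit(n_splits=1, train_size=train_size, random_state=seed)
    train_idx, test_idx = next(splitter.split(df, groups=df[group_col]))
    return (df.iloc[train_idx].reset_index(drop=True),
            df.iloc[test_idx].reset_index(drop=True))


df = pd.DataFrame({
    "patient_id": ["a", "a", "b", "b", "c", "d"],
    "x": [1, 2, 3, 4, 5, 6],
    "y": [0, 1, 0, 1, 0, 1],
})
train_df, test_df = group_aware_split(df, "patient_id")
# No patient appears in both splits.
assert not set(train_df["patient_id"]) & set(test_df["patient_id"])
```

One design difference worth noting: `GroupShuffleSplit` sizes the split by the fraction of groups, while the greedy allocator in the diff targets the fraction of rows, which matters when group sizes are very uneven.
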
| 239 if val is not None: | 369 if val is not None: |
| 240 self.setup_params[attr] = val | 370 self.setup_params[attr] = val |
| 241 if getattr(self, "cross_validation_folds", None) is not None: | 371 if getattr(self, "cross_validation_folds", None) is not None: |
| 242 self.setup_params["fold"] = self.cross_validation_folds | 372 self.setup_params["fold"] = self.cross_validation_folds |
| 243 LOG.info(self.setup_params) | 373 LOG.info(self.setup_params) |
| 374 | |
| 375 group_series = getattr(self, "sample_id_series", None) | |
| 376 if group_series is not None and getattr(self, "cross_validation", None) is not False: | |
| 377 n_groups = pd.Series(group_series).nunique(dropna=False) | |
| 378 fold_count = getattr(self, "cross_validation_folds", None) | |
| 379 if fold_count is not None and fold_count > n_groups: | |
| 380 LOG.warning( | |
| 381 "cross_validation_folds=%s exceeds unique groups=%s; " | |
| 382 "skipping group-aware CV.", | |
| 383 fold_count, | |
| 384 n_groups, | |
| 385 ) | |
| 386 else: | |
| 387 self.setup_params["fold_strategy"] = "groupkfold" | |
| 388 self.setup_params["fold_groups"] = pd.Series(group_series).reset_index(drop=True) | |
| 389 LOG.info( | |
| 390 "Enabled group-aware CV with %s unique groups.", | |
| 391 n_groups, | |
| 392 ) | |
| 244 | 393 |
| 245 if self.task_type == "classification": | 394 if self.task_type == "classification": |
| 246 from pycaret.classification import ClassificationExperiment | 395 from pycaret.classification import ClassificationExperiment |
| 247 | 396 |
| 248 self.exp = ClassificationExperiment() | 397 self.exp = ClassificationExperiment() |
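
The final hunk wires group-aware cross-validation into the PyCaret setup parameters by setting `fold_strategy="groupkfold"` and `fold_groups`, and skips it when the requested fold count exceeds the number of unique sample IDs. Below is a minimal sketch of how those parameters reach PyCaret, assuming `setup_params` is eventually unpacked into `exp.setup(...)`; the data, `outcome` target name, and fold count are invented for illustration:

```python
import pandas as pd
from pycaret.classification import ClassificationExperiment

# Hypothetical training split and its aligned sample IDs (groups).
train_df = pd.DataFrame({
    "feature": [0.1, 0.4, 0.2, 0.9, 0.7, 0.3, 0.8, 0.5,
                0.6, 0.2, 0.9, 0.1, 0.4, 0.7, 0.3, 0.8],
    "outcome": [0, 0, 1, 1, 0, 1, 1, 0,
                0, 1, 1, 0, 0, 1, 1, 0],
})
groups = pd.Series(["s1", "s1", "s2", "s2", "s3", "s3", "s4", "s4",
                    "s5", "s5", "s6", "s6", "s7", "s7", "s8", "s8"])

exp = ClassificationExperiment()
exp.setup(
    data=train_df,
    target="outcome",
    fold_strategy="groupkfold",                 # keep each sample's rows in one CV fold
    fold_groups=groups.reset_index(drop=True),  # mirrors the diff's fold_groups assignment
    fold=3,                                     # must not exceed the number of unique groups
    session_id=42,
)
```

With `fold_strategy="groupkfold"`, every sample's rows stay inside a single CV fold, which is why the hunk skips the group-aware strategy when `cross_validation_folds` exceeds the number of unique groups: GroupKFold cannot produce more folds than there are groups.
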
