comparison base_model_trainer.py @ 15:01e7c5481f13 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit f632803cda732005bdcf3ac3e8fe7a807a82c1d9
author goeckslab
date Mon, 19 Jan 2026 05:54:52 +0000
parents bf0df21a1ea3
children
comparison
equal deleted inserted replaced
14:edd515746388 15:01e7c5481f13
117 f"Target column number {self.target_col} is invalid. " 117 f"Target column number {self.target_col} is invalid. "
118 f"Please select a number between 1 and {num_cols}." 118 f"Please select a number between 1 and {num_cols}."
119 ) 119 )
120 120
121 self.target = names[target_index] 121 self.target = names[target_index]
122 sample_id_column = getattr(self, "sample_id_column", None)
123 if sample_id_column:
124 sample_id_column = sample_id_column.replace(".", "_")
125 self.sample_id_column = sample_id_column
126 else:
127 self.sample_id_column = None
128 self.sample_id_series = None
122 129
123 # Conditional drop: only if 'prediction_label' exists and is not 130 # Conditional drop: only if 'prediction_label' exists and is not
124 # the target 131 # the target
125 if "prediction_label" in self.data.columns and ( 132 if "prediction_label" in self.data.columns and (
126 self.data.columns[target_index] != "prediction_label" 133 self.data.columns[target_index] != "prediction_label"
152 159
153 # Update names after possible drop 160 # Update names after possible drop
154 names = self.data.columns.to_list() 161 names = self.data.columns.to_list()
155 LOG.info(f"Dataset columns after processing: {names}") 162 LOG.info(f"Dataset columns after processing: {names}")
156 163
157 self.features_name = [n for n in names if n != self.target] 164 sample_id_valid = False
158 self.plot_feature_names = self._select_plot_features(self.features_name) 165 if sample_id_column:
166 if sample_id_column not in self.data.columns:
167 LOG.warning(
168 "Sample ID column '%s' not found; proceeding without group-aware split.",
169 sample_id_column,
170 )
171 sample_id_column = None
172 self.sample_id_column = None
173 elif sample_id_column == self.target:
174 LOG.warning(
175 "Sample ID column '%s' matches target column; skipping group-aware split.",
176 sample_id_column,
177 )
178 sample_id_column = None
179 self.sample_id_column = None
180 else:
181 sample_id_valid = True
159 182
160 if self.test_file: 183 if self.test_file:
161 LOG.info(f"Loading test data from {self.test_file}") 184 LOG.info(f"Loading test data from {self.test_file}")
162 df_test = pd.read_csv( 185 df_test = pd.read_csv(
163 self.test_file, sep=None, engine="python" 186 self.test_file, sep=None, engine="python"
164 ) 187 )
165 df_test.columns = df_test.columns.str.replace(".", "_") 188 df_test.columns = df_test.columns.str.replace(".", "_")
166 self.test_data = df_test 189 self.test_data = df_test
190
191 if sample_id_valid and self.test_data is None:
192 train_size = getattr(self, "train_size", None)
193 if train_size is None:
194 train_size = 0.7
195 if train_size <= 0 or train_size >= 1:
196 LOG.warning(
197 "Invalid train_size=%s; skipping group-aware split.",
198 train_size,
199 )
200 else:
201 rng = np.random.RandomState(self.random_seed)
202
203 def _allocate_split_counts(n_total: int, probs: list) -> list:
204 if n_total <= 0:
205 return [0 for _ in probs]
206 counts = [0 for _ in probs]
207 active = [i for i, p in enumerate(probs) if p > 0]
208 remainder = n_total
209 if active and n_total >= len(active):
210 for i in active:
211 counts[i] = 1
212 remainder -= len(active)
213 if remainder > 0:
214 probs_arr = np.array(probs, dtype=float)
215 probs_arr = probs_arr / probs_arr.sum()
216 raw = remainder * probs_arr
217 floors = np.floor(raw).astype(int)
218 for i, value in enumerate(floors.tolist()):
219 counts[i] += value
220 leftover = remainder - int(floors.sum())
221 if leftover > 0 and active:
222 frac = raw - floors
223 order = sorted(active, key=lambda i: (-frac[i], i))
224 for i in range(leftover):
225 counts[order[i % len(order)]] += 1
226 return counts
227
228 def _choose_split(counts: list, targets: list, active: list) -> int:
229 remaining = [targets[i] - counts[i] for i in range(len(targets))]
230 best = max(active, key=lambda i: (remaining[i], -counts[i], -targets[i]))
231 if remaining[best] <= 0:
232 best = min(active, key=lambda i: counts[i])
233 return best
234
235 probs = [train_size, 1.0 - train_size]
236 targets = _allocate_split_counts(len(self.data), probs)
237 counts = [0, 0]
238 active = [0, 1]
239 train_idx = []
240 test_idx = []
241
242 group_series = self.data[sample_id_column].astype(object)
243 missing_mask = group_series.isna()
244 if missing_mask.any():
245 group_series = group_series.copy()
246 group_series.loc[missing_mask] = [
247 f"__missing__{idx}" for idx in group_series.index[missing_mask]
248 ]
249 group_to_indices = {}
250 for idx, group_id in group_series.items():
251 group_to_indices.setdefault(group_id, []).append(idx)
252
253 group_ids = sorted(group_to_indices.keys(), key=lambda x: str(x))
254 rng.shuffle(group_ids)
255
256 for group_id in group_ids:
257 split_idx = _choose_split(counts, targets, active)
258 counts[split_idx] += len(group_to_indices[group_id])
259 if split_idx == 0:
260 train_idx.extend(group_to_indices[group_id])
261 else:
262 test_idx.extend(group_to_indices[group_id])
263
264 missing_splits = []
265 if not train_idx:
266 missing_splits.append("train")
267 if not test_idx:
268 missing_splits.append("test")
269 if missing_splits:
270 LOG.warning(
271 "Group-aware split using '%s' produced empty %s set; "
272 "falling back to default split.",
273 sample_id_column,
274 " and ".join(missing_splits),
275 )
276 else:
277 self.test_data = self.data.loc[test_idx].reset_index(drop=True)
278 self.data = self.data.loc[train_idx].reset_index(drop=True)
279 LOG.info(
280 "Applied group-aware split using '%s' (train=%s, test=%s).",
281 sample_id_column,
282 len(train_idx),
283 len(test_idx),
284 )
285
286 if sample_id_valid:
287 self.sample_id_series = self.data[sample_id_column].copy()
288 if sample_id_column in self.data.columns:
289 self.data = self.data.drop(columns=[sample_id_column])
290 if self.test_data is not None and sample_id_column in self.test_data.columns:
291 self.test_data = self.test_data.drop(columns=[sample_id_column])
292
293 # Refresh feature lists after any sample-id column removal.
294 names = self.data.columns.to_list()
295 self.features_name = [n for n in names if n != self.target]
296 self.plot_feature_names = self._select_plot_features(self.features_name)
167 297
168 def _select_plot_features(self, all_features): 298 def _select_plot_features(self, all_features):
169 limit = getattr(self, "plot_feature_limit", 30) 299 limit = getattr(self, "plot_feature_limit", 30)
170 if not isinstance(limit, int) or limit <= 0: 300 if not isinstance(limit, int) or limit <= 0:
171 LOG.info( 301 LOG.info(
239 if val is not None: 369 if val is not None:
240 self.setup_params[attr] = val 370 self.setup_params[attr] = val
241 if getattr(self, "cross_validation_folds", None) is not None: 371 if getattr(self, "cross_validation_folds", None) is not None:
242 self.setup_params["fold"] = self.cross_validation_folds 372 self.setup_params["fold"] = self.cross_validation_folds
243 LOG.info(self.setup_params) 373 LOG.info(self.setup_params)
374
375 group_series = getattr(self, "sample_id_series", None)
376 if group_series is not None and getattr(self, "cross_validation", None) is not False:
377 n_groups = pd.Series(group_series).nunique(dropna=False)
378 fold_count = getattr(self, "cross_validation_folds", None)
379 if fold_count is not None and fold_count > n_groups:
380 LOG.warning(
381 "cross_validation_folds=%s exceeds unique groups=%s; "
382 "skipping group-aware CV.",
383 fold_count,
384 n_groups,
385 )
386 else:
387 self.setup_params["fold_strategy"] = "groupkfold"
388 self.setup_params["fold_groups"] = pd.Series(group_series).reset_index(drop=True)
389 LOG.info(
390 "Enabled group-aware CV with %s unique groups.",
391 n_groups,
392 )
244 393
245 if self.task_type == "classification": 394 if self.task_type == "classification":
246 from pycaret.classification import ClassificationExperiment 395 from pycaret.classification import ClassificationExperiment
247 396
248 self.exp = ClassificationExperiment() 397 self.exp = ClassificationExperiment()