Mercurial > repos > muon-spectroscopy-computational-project > larch_select_paths
comparison larch_select_paths.py @ 1:7fdca938d90c draft
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_select_paths commit 1cf6d7160497ba58fe16a51f00d088a20934eba6
| author | muon-spectroscopy-computational-project |
|---|---|
| date | Wed, 06 Dec 2023 13:04:15 +0000 |
| parents | 2e827836f0ad |
| children | 204c4afe2f1e |
comparison
equal
deleted
inserted
replaced
| 0:2e827836f0ad | 1:7fdca938d90c |
|---|---|
| 1 import csv | 1 import csv |
| 2 import json | 2 import json |
| 3 import os | 3 import os |
| 4 import re | 4 import re |
| 5 import sys | 5 import sys |
| 6 from itertools import combinations | |
| 6 from zipfile import ZIP_DEFLATED, ZipFile | 7 from zipfile import ZIP_DEFLATED, ZipFile |
| 8 | |
| 9 | |
| 10 class CriteriaSelector: | |
| 11 def __init__(self, criteria: "dict[str, int|float]"): | |
| 12 self.max_number = criteria["max_number"] | |
| 13 self.max_path_length = criteria["max_path_length"] | |
| 14 self.min_amp_ratio = criteria["min_amplitude_ratio"] | |
| 15 self.max_degeneracy = criteria["max_degeneracy"] | |
| 16 self.path_count = 0 | |
| 17 | |
| 18 def evaluate(self, path_id: int, row: "list[str]") -> (bool, None): | |
| 19 if self.max_number and self.path_count >= self.max_number: | |
| 20 print(f"Reject path: {self.max_number} paths already reached") | |
| 21 return (False, None) | |
| 22 | |
| 23 r_effective = float(row[5].strip()) | |
| 24 if self.max_path_length and r_effective > self.max_path_length: | |
| 25 print(f"Reject path: {r_effective} > {self.max_path_length}") | |
| 26 return (False, None) | |
| 27 | |
| 28 amplitude_ratio = float(row[2].strip()) | |
| 29 if self.min_amp_ratio and (amplitude_ratio < self.min_amp_ratio): | |
| 30 print(f"Reject path: {amplitude_ratio} < {self.min_amp_ratio}") | |
| 31 return (False, None) | |
| 32 | |
| 33 degeneracy = float(row[3].strip()) | |
| 34 if self.max_degeneracy and degeneracy > self.max_degeneracy: | |
| 35 print(f"Reject path: {degeneracy} > {self.max_degeneracy}") | |
| 36 return (False, None) | |
| 37 | |
| 38 self.path_count += 1 | |
| 39 return (True, None) | |
| 40 | |
| 41 | |
| 42 class ManualSelector: | |
| 43 def __init__(self, selection: dict): | |
| 44 self.select_all = selection["selection"] == "all" | |
| 45 self.paths = selection["paths"] | |
| 46 self.path_values_ids = [path_value["id"] for path_value in self.paths] | |
| 47 | |
| 48 def evaluate(self, path_id: int, row: "list[str]") -> (bool, "None|dict"): | |
| 49 if path_id in self.path_values_ids: | |
| 50 return (True, self.paths[self.path_values_ids.index(path_id)]) | |
| 51 | |
| 52 if self.select_all or int(row[-1]): | |
| 53 return (True, None) | |
| 54 | |
| 55 return (False, None) | |
| 7 | 56 |
| 8 | 57 |
| 9 class GDSWriter: | 58 class GDSWriter: |
| 10 def __init__(self, default_variables: "dict[str, dict]"): | 59 def __init__(self, default_variables: "dict[str, dict]"): |
| 11 self.default_properties = { | 60 self.default_properties = { |
| 34 self.append_gds(name=name, value=value, vary=vary) | 83 self.append_gds(name=name, value=value, vary=vary) |
| 35 | 84 |
| 36 def append_gds( | 85 def append_gds( |
| 37 self, | 86 self, |
| 38 name: str, | 87 name: str, |
| 39 value: float = 0., | 88 value: float = 0.0, |
| 40 expr: str = None, | 89 expr: str = None, |
| 41 vary: bool = True, | 90 vary: bool = True, |
| 42 label: str = "", | 91 label: str = "", |
| 43 ): | 92 ): |
| 44 """Append a single GDS variable to the list of rows, later to be | 93 """Append a single GDS variable to the list of rows, later to be |
| 120 vary=self.default_properties[property_name]["vary"], | 169 vary=self.default_properties[property_name]["vary"], |
| 121 ) | 170 ) |
| 122 return auto_name | 171 return auto_name |
| 123 | 172 |
| 124 def write(self): | 173 def write(self): |
| 125 """Write GDS rows to file. | 174 """Write GDS rows to file.""" |
| 126 """ | |
| 127 with open("gds.csv", "w") as out: | 175 with open("gds.csv", "w") as out: |
| 128 out.writelines(self.rows) | 176 out.writelines(self.rows) |
| 129 | 177 |
| 130 | 178 |
| 131 class PathsWriter: | 179 class PathsWriter: |
| 133 self.rows = [ | 181 self.rows = [ |
| 134 f"{'id':>4s}, {'filename':>24s}, {'label':>24s}, {'s02':>3s}, " | 182 f"{'id':>4s}, {'filename':>24s}, {'label':>24s}, {'s02':>3s}, " |
| 135 f"{'e0':>4s}, {'sigma2':>24s}, {'deltar':>10s}\n" | 183 f"{'e0':>4s}, {'sigma2':>24s}, {'deltar':>10s}\n" |
| 136 ] | 184 ] |
| 137 self.gds_writer = GDSWriter(default_variables=default_variables) | 185 self.gds_writer = GDSWriter(default_variables=default_variables) |
| 186 self.all_combinations = [[0]] # 0 corresponds to the header row | |
| 138 | 187 |
| 139 def parse_feff_output( | 188 def parse_feff_output( |
| 140 self, | 189 self, |
| 141 paths_file: str, | 190 paths_file: str, |
| 142 selection: "dict[str, str|list]", | 191 selection: "dict[str, str|list]", |
| 149 selection (dict[str, str|list]): Dictionary indicating which paths | 198 selection (dict[str, str|list]): Dictionary indicating which paths |
| 150 to select, and how to define their variables. | 199 to select, and how to define their variables. |
| 151 directory_label (str, optional): Label to indicate paths from a | 200 directory_label (str, optional): Label to indicate paths from a |
| 152 separate directory. Defaults to "". | 201 separate directory. Defaults to "". |
| 153 """ | 202 """ |
| 154 paths = selection["paths"] | 203 combinations_list = [] |
| 155 path_values_ids = [path_value["id"] for path_value in paths] | 204 if selection["selection"] in {"criteria", "combinations"}: |
| 156 | 205 selector = CriteriaSelector(selection) |
| 206 else: | |
| 207 selector = ManualSelector(selection) | |
| 208 | |
| 209 selected_ids = self.select_rows(paths_file, directory_label, selector) | |
| 210 | |
| 211 if selection["selection"] == "combinations": | |
| 212 min_number = selection["min_combination_size"] | |
| 213 min_number = min(min_number, len(selected_ids)) | |
| 214 max_number = selection["max_combination_size"] | |
| 215 if not max_number or max_number > len(selected_ids): | |
| 216 max_number = len(selected_ids) | |
| 217 | |
| 218 for number_of_paths in range(min_number, max_number + 1): | |
| 219 for combination in combinations(selected_ids, number_of_paths): | |
| 220 combinations_list.append(combination) | |
| 221 | |
| 222 new_combinations = len(combinations_list) | |
| 223 print( | |
| 224 f"{new_combinations} combinations for {directory_label}:\n" | |
| 225 f"{combinations_list}" | |
| 226 ) | |
| 227 old_combinations_len = len(self.all_combinations) | |
| 228 self.all_combinations *= new_combinations | |
| 229 for i, combination in enumerate(self.all_combinations): | |
| 230 new_combinations = combinations_list[i // old_combinations_len] | |
| 231 self.all_combinations[i] = combination + list(new_combinations) | |
| 232 else: | |
| 233 for combination in self.all_combinations: | |
| 234 combination.extend(selected_ids) | |
| 235 | |
| 236 def select_rows( | |
| 237 self, | |
| 238 paths_file: str, | |
| 239 directory_label: str, | |
| 240 selector: "CriteriaSelector|ManualSelector", | |
| 241 ) -> "list[int]": | |
| 242 """Evaluate each row in turn to decide whether or not it should be | |
| 243 included in the final output. Does not account for combinations. | |
| 244 | |
| 245 Args: | |
| 246 paths_file (str): CSV summary filename. | |
| 247 directory_label (str): Label to indicate paths from a separate | |
| 248 directory. | |
| 249 selector (CriteriaSelector|ManualSelector): Object to evaluate | |
| 250 whether to select each path or not. | |
| 251 | |
| 252 Returns: | |
| 253 list[int]: The ids of the selected rows. | |
| 254 """ | |
| 255 row_ids = [] | |
| 157 with open(paths_file) as file: | 256 with open(paths_file) as file: |
| 158 reader = csv.reader(file) | 257 reader = csv.reader(file) |
| 159 for row in reader: | 258 for row in reader: |
| 160 id_match = re.search(r"\d+", row[0]) | 259 id_match = re.search(r"\d+", row[0]) |
| 161 if id_match: | 260 if id_match: |
| 162 path_id = int(id_match.group()) | 261 path_id = int(id_match.group()) |
| 163 filename = row[0].strip() | 262 selected, path_value = selector.evaluate( |
| 164 path_label = row[-2].strip() | 263 path_id=path_id, |
| 165 variables = {} | 264 row=row, |
| 166 | 265 ) |
| 167 if path_id in path_values_ids: | 266 if selected: |
| 168 path_value = paths[path_values_ids.index(path_id)] | 267 filename = row[0].strip() |
| 169 for property in self.gds_writer.default_properties: | 268 path_label = row[-2].strip() |
| 170 variables[property] = self.gds_writer.parse_gds( | 269 row_id = self.parse_row( |
| 171 property_name=property, | 270 directory_label, filename, path_label, path_value |
| 172 variable_name=path_value[property]["name"], | |
| 173 path_variable=path_value[property], | |
| 174 directory_label=directory_label, | |
| 175 path_label=path_label, | |
| 176 ) | |
| 177 self.parse_selected_path( | |
| 178 filename=filename, | |
| 179 path_label=path_label, | |
| 180 directory_label=directory_label, | |
| 181 **variables, | |
| 182 ) | 271 ) |
| 183 elif selection["selection"] == "all" or int(row[-1]): | 272 row_ids.append(row_id) |
| 184 path_value = None | 273 |
| 185 for property in self.gds_writer.default_properties: | 274 return row_ids |
| 186 variables[property] = self.gds_writer.parse_gds( | 275 |
| 187 property_name=property, | 276 def parse_row( |
| 188 directory_label=directory_label, | 277 self, |
| 189 path_label=path_label, | 278 directory_label: str, |
| 190 ) | 279 filename: str, |
| 191 self.parse_selected_path( | 280 path_label: str, |
| 192 filename=filename, | 281 path_value: "None|dict", |
| 193 path_label=path_label, | 282 ) -> int: |
| 194 directory_label=directory_label, | 283 """Parse row for GDS and path information. |
| 195 **variables, | 284 |
| 196 ) | 285 Args: |
| 286 directory_label (str): Label to indicate paths from a separate | |
| 287 directory. | |
| 288 filename (str): Filename for the FEFF path, extracted from row. | |
| 289 path_label (str): Label for the FEFF path, extracted from row. | |
| 290 path_value (None|dict): The values associated with the selected | |
| 291 FEFF path. May be None in which case defaults are used. | |
| 292 | |
| 293 Returns: | |
| 294 int: The id of the added row. | |
| 295 """ | |
| 296 variables = {} | |
| 297 if path_value is not None: | |
| 298 for property in self.gds_writer.default_properties: | |
| 299 variables[property] = self.gds_writer.parse_gds( | |
| 300 property_name=property, | |
| 301 variable_name=path_value[property]["name"], | |
| 302 path_variable=path_value[property], | |
| 303 directory_label=directory_label, | |
| 304 path_label=path_label, | |
| 305 ) | |
| 306 else: | |
| 307 for property in self.gds_writer.default_properties: | |
| 308 variables[property] = self.gds_writer.parse_gds( | |
| 309 property_name=property, | |
| 310 directory_label=directory_label, | |
| 311 path_label=path_label, | |
| 312 ) | |
| 313 | |
| 314 return self.parse_selected_path( | |
| 315 filename=filename, | |
| 316 path_label=path_label, | |
| 317 directory_label=directory_label, | |
| 318 **variables, | |
| 319 ) | |
| 197 | 320 |
| 198 def parse_selected_path( | 321 def parse_selected_path( |
| 199 self, | 322 self, |
| 200 filename: str, | 323 filename: str, |
| 201 path_label: str, | 324 path_label: str, |
| 202 directory_label: str = "", | 325 directory_label: str = "", |
| 203 s02: str = "s02", | 326 s02: str = "s02", |
| 204 e0: str = "e0", | 327 e0: str = "e0", |
| 205 sigma2: str = "sigma2", | 328 sigma2: str = "sigma2", |
| 206 deltar: str = "alpha*reff", | 329 deltar: str = "alpha*reff", |
| 207 ): | 330 ) -> int: |
| 208 """Format and append row representing a selected FEFF path. | 331 """Format and append row representing a selected FEFF path. |
| 209 | 332 |
| 210 Args: | 333 Args: |
| 211 filename (str): Name of the underlying FEFF path file, without | 334 filename (str): Name of the underlying FEFF path file, without |
| 212 parent directory. | 335 parent directory. |
| 218 e0 (str, optional): Energy shift variable name. Defaults to "e0". | 341 e0 (str, optional): Energy shift variable name. Defaults to "e0". |
| 219 sigma2 (str, optional): Mean squared displacement variable name. | 342 sigma2 (str, optional): Mean squared displacement variable name. |
| 220 Defaults to "sigma2". | 343 Defaults to "sigma2". |
| 221 deltar (str, optional): Change in path length variable. | 344 deltar (str, optional): Change in path length variable. |
| 222 Defaults to "alpha*reff". | 345 Defaults to "alpha*reff". |
| 346 | |
| 347 Returns: | |
| 348 int: The id of the added row. | |
| 223 """ | 349 """ |
| 224 if directory_label: | 350 if directory_label: |
| 225 filename = os.path.join(directory_label, filename) | 351 filename = os.path.join(directory_label, filename) |
| 226 label = f"{directory_label}.{path_label}" | 352 label = f"{directory_label}.{path_label}" |
| 227 else: | 353 else: |
| 228 filename = os.path.join("feff", filename) | 354 filename = os.path.join("feff", filename) |
| 229 label = path_label | 355 label = path_label |
| 230 | 356 |
| 357 row_id = len(self.rows) | |
| 231 self.rows.append( | 358 self.rows.append( |
| 232 f"{len(self.rows):>4d}, {filename:>24s}, {label:>24s}, " | 359 f"{row_id:>4d}, {filename:>24s}, {label:>24s}, " |
| 233 f"{s02:>3s}, {e0:>4s}, {sigma2:>24s}, {deltar:>10s}\n" | 360 f"{s02:>3s}, {e0:>4s}, {sigma2:>24s}, {deltar:>10s}\n" |
| 234 ) | 361 ) |
| 235 | 362 |
| 363 return row_id | |
| 364 | |
| 236 def write(self): | 365 def write(self): |
| 237 """Write selected path and GDS rows to file. | 366 """Write selected path and GDS rows to file.""" |
| 238 """ | |
| 239 self.gds_writer.write() | 367 self.gds_writer.write() |
| 240 with open("sp.csv", "w") as out: | 368 |
| 241 out.writelines(self.rows) | 369 if len(self.all_combinations) == 1: |
| 370 with open("sp.csv", "w") as out: | |
| 371 out.writelines(self.rows) | |
| 372 else: | |
| 373 for combination in self.all_combinations: | |
| 374 filename = "_".join([str(c) for c in combination[1:]]) | |
| 375 print(f"Writing combination {filename}") | |
| 376 with open(f"sp/{filename}.csv", "w") as out: | |
| 377 for row_id, row in enumerate(self.rows): | |
| 378 if row_id in combination: | |
| 379 out.write(row) | |
| 242 | 380 |
| 243 | 381 |
| 244 def main(input_values: dict): | 382 def main(input_values: dict): |
| 245 """Select paths and define GDS parameters. | 383 """Select paths and define GDS parameters. |
| 246 | 384 |
| 263 else: | 401 else: |
| 264 zfill_length = len(str(len(input_values["feff_outputs"]))) | 402 zfill_length = len(str(len(input_values["feff_outputs"]))) |
| 265 labels = set() | 403 labels = set() |
| 266 with ZipFile("merged.zip", "x", ZIP_DEFLATED) as zipfile_out: | 404 with ZipFile("merged.zip", "x", ZIP_DEFLATED) as zipfile_out: |
| 267 for i, feff_output in enumerate(input_values["feff_outputs"]): | 405 for i, feff_output in enumerate(input_values["feff_outputs"]): |
| 268 label = feff_output.pop("label") or str(i + 1).zfill( | 406 label = feff_output["label"] |
| 269 zfill_length | 407 if not label: |
| 270 ) | 408 label = str(i + 1).zfill(zfill_length) |
| 271 if label in labels: | 409 if label in labels: |
| 272 raise ValueError(f"Label '{label}' is not unique") | 410 raise ValueError(f"Label '{label}' is not unique") |
| 273 labels.add(label) | 411 labels.add(label) |
| 274 | 412 |
| 275 writer.parse_feff_output( | 413 writer.parse_feff_output( |
| 281 with ZipFile(feff_output["paths_zip"]) as z: | 419 with ZipFile(feff_output["paths_zip"]) as z: |
| 282 for zipinfo in z.infolist(): | 420 for zipinfo in z.infolist(): |
| 283 if zipinfo.filename != "feff/": | 421 if zipinfo.filename != "feff/": |
| 284 zipinfo.filename = zipinfo.filename[5:] | 422 zipinfo.filename = zipinfo.filename[5:] |
| 285 z.extract(member=zipinfo, path=label) | 423 z.extract(member=zipinfo, path=label) |
| 286 zipfile_out.write( | 424 filename = os.path.join(label, zipinfo.filename) |
| 287 os.path.join(label, zipinfo.filename) | 425 zipfile_out.write(filename) |
| 288 ) | |
| 289 | 426 |
| 290 writer.write() | 427 writer.write() |
| 291 | 428 |
| 292 | 429 |
| 293 if __name__ == "__main__": | 430 if __name__ == "__main__": |
