comparison: larch_athena.py @ 0:ae2f265ecf8e (draft)
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_athena commit 5be486890442dedfb327289d597e1c8110240735
| author | muon-spectroscopy-computational-project |
|---|---|
| date | Tue, 14 Nov 2023 15:34:40 +0000 |
| parents | |
| children | 2b3115342fef |
Revisions compared: -1:000000000000 (file not present) to 0:ae2f265ecf8e. The file is new in this commit, so its full contents are shown below.

```python
import gc
import json
import os
import re
import sys

from common import read_group

from larch.io import (
    create_athena,
    h5group,
    merge_groups,
    read_ascii,
    set_array_labels,
)
from larch.symboltable import Group
from larch.xafs import autobk, pre_edge, rebin_xafs, xftf

import matplotlib
import matplotlib.pyplot as plt

import numpy as np


class Reader:
    def __init__(
        self,
        energy_column: str,
        mu_column: str,
        xftf_params: dict,
        data_format: str,
        extract_group: str = None,
    ):
        self.energy_column = energy_column
        self.mu_column = mu_column
        self.xftf_params = xftf_params
        self.data_format = data_format
        self.extract_group = extract_group

    def load_data(
        self,
        dat_file: str,
        merge_inputs: bool,
        is_zipped: bool,
    ) -> "dict[str, Group]":
        if merge_inputs:
            out_group = self.merge_files(
                dat_files=dat_file, is_zipped=is_zipped
            )
            return {"out": out_group}
        else:
            return self.load_single_file(
                filepath=dat_file, is_zipped=is_zipped
            )

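    # The returned dict maps output names to Groups: "out" for a single or
    # merged input, or zero-padded indices ("00", "01", ...) for zipped
    # input (see load_zipped_files). main() uses these keys as filenames.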
    def merge_files(
        self,
        dat_files: str,
        is_zipped: bool,
    ) -> Group:
        if is_zipped:
            all_groups = list(self.load_zipped_files().values())
        else:
            all_groups = []
            for filepath in dat_files.split(","):
                group = self.load_single_file(filepath)["out"]
                all_groups.append(group)

        return merge_groups(all_groups, xarray="energy", yarray="mu")

    def load_single_file(
        self,
        filepath: str,
        is_zipped: bool = False,
    ) -> "dict[str, Group]":
        if is_zipped:
            return self.load_zipped_files()

        print(f"Attempting to read from {filepath}")
        if self.data_format == "athena":
            group = read_group(filepath, self.extract_group, self.xftf_params)
        else:
            # Try ascii anyway
            try:
                group = self.load_ascii(filepath)
                if not group.array_labels:
                    # In later versions of larch this won't raise a TypeError,
                    # it will just fail to load any data
                    group = self.load_h5(filepath)
            except (UnicodeDecodeError, TypeError):
                # Indicates this isn't plaintext, try h5
                group = self.load_h5(filepath)
        return {"out": group}

    def load_ascii(self, dat_file):
        # Scan the commented ("#") header for the last tab-separated line
        # before the data starts and use it for column labels; otherwise
        # leave labels as None and let read_ascii guess them.
        with open(dat_file) as f:
            labels = None
            last_line = None
            line = f.readline()
            while line:
                if not line.startswith("#"):
                    if last_line is not None and last_line.find("\t") > 0:
                        labels = []
                        for label in last_line.split("\t"):
                            labels.append(label.strip())
                    break

                last_line = line
                line = f.readline()

        xas_data = read_ascii(filename=dat_file, labels=labels)
        xas_data = self.rename_cols(xas_data)
        return xas_data

    def load_h5(self, dat_file):
        h5_group = h5group(fname=dat_file, mode="r")
        energy = h5_group.entry1.instrument.qexafs_energy.qexafs_energy
        mu = h5_group.entry1.instrument.qexafs_counterTimer01.lnI0It
        xafs_group = Group(data=np.array([energy[:], mu[:]]))
        set_array_labels(xafs_group, ["energy", "mu"])
        return xafs_group

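    # NOTE: load_h5 assumes one fixed HDF5/NeXus layout, with arrays under
    # entry1/instrument/qexafs_energy and entry1/instrument/qexafs_counterTimer01
    # (QEXAFS-style acquisition); files with a different internal structure
    # will fail here (typically with an AttributeError).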
    def load_zipped_files(self) -> "dict[str, Group]":
        def sorting_key(filename: str) -> int:
            # Sort on the last run of digits in the filename so that e.g.
            # "scan_9" orders before "scan_10"
            return int(re.findall(r"\d+", filename)[-1])

        # Zipped inputs are expected to already be extracted into ./dat_files
        all_paths = list(os.walk("dat_files"))
        all_paths.sort(key=lambda x: x[0])
        file_total = sum(len(f) for _, _, f in all_paths)
        print(f"{file_total} files found")
        key_length = len(str(file_total))
        i = 0
        keyed_data = {}
        for dirpath, _, filenames in all_paths:
            try:
                filenames.sort(key=sorting_key)
            except IndexError as e:
                print(
                    "WARNING: Unable to sort files numerically, "
                    f"defaulting to sorting alphabetically:\n{e}"
                )
                filenames.sort()

            for filename in filenames:
                key = str(i).zfill(key_length)
                filepath = os.path.join(dirpath, filename)
                xas_data = self.load_single_file(filepath)
                keyed_data[key] = xas_data["out"]
                i += 1

        return keyed_data

    def rename_cols(self, xafs_group: Group) -> Group:
        labels = [label.lower() for label in xafs_group.array_labels]
        print(f"Read columns: {labels}")

        if "energy" in labels:
            print("'energy' present in column headers")
        elif self.energy_column is not None:
            if self.energy_column.lower() in labels:
                labels[labels.index(self.energy_column.lower())] = "energy"
            else:
                raise ValueError(f"{self.energy_column} not found in {labels}")
        else:
            for i, label in enumerate(labels):
                if label == "col1" or label.endswith("energy"):
                    labels[i] = "energy"
                    break

        if "mu" in labels:
            print("'mu' present in column headers")
        elif self.mu_column is not None:
            if self.mu_column.lower() in labels:
                labels[labels.index(self.mu_column.lower())] = "mu"
            else:
                raise ValueError(f"{self.mu_column} not found in {labels}")
        else:
            for i, label in enumerate(labels):
                if label in ["col2", "xmu", "lni0it", "ffi0"]:
                    labels[i] = "mu"
                    break

        if labels != xafs_group.array_labels:
            print(f"Renaming columns to: {labels}")
            return set_array_labels(xafs_group, labels)
        else:
            return xafs_group
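
    # Illustrative examples of the fallback matching above: labels
    # ["col1", "lni0it"] become ["energy", "mu"], and ["mono_energy", "ffi0"]
    # also become ["energy", "mu"] via endswith("energy") and the mu aliases.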


def calibrate_energy(
    xafs_group: Group,
    energy_0: float,
    energy_min: float,
    energy_max: float,
    energy_format: str,
):
    if energy_0 is not None:
        print(f"Recalibrating energy edge from {xafs_group.e0} to {energy_0}")
        xafs_group.energy = xafs_group.energy + energy_0 - xafs_group.e0
        xafs_group.e0 = energy_0

    if not (energy_min or energy_max):
        return xafs_group

    if energy_min:
        if energy_format == "relative":
            energy_min += xafs_group.e0
        index_min = np.searchsorted(xafs_group.energy, energy_min)
    else:
        index_min = 0

    if energy_max:
        if energy_format == "relative":
            energy_max += xafs_group.e0
        index_max = np.searchsorted(xafs_group.energy, energy_max)
    else:
        index_max = len(xafs_group.energy)

    print(
        f"Cropping energy range from {energy_min} to {energy_max}, "
        f"index {index_min} to {index_max}"
    )
    try:
        # Crop the fitted arrays too; these only exist if pre_edge has
        # already been run on this group
        xafs_group.dmude = xafs_group.dmude[index_min:index_max]
        xafs_group.pre_edge = xafs_group.pre_edge[index_min:index_max]
        xafs_group.post_edge = xafs_group.post_edge[index_min:index_max]
        xafs_group.flat = xafs_group.flat[index_min:index_max]
    except AttributeError:
        pass

    xafs_group.energy = xafs_group.energy[index_min:index_max]
    xafs_group.mu = xafs_group.mu[index_min:index_max]

    # Sanity check
    if len(xafs_group.energy) == 0:
        raise ValueError("Energy cropping led to an empty array")

    return xafs_group


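# main() runs the standard XAS reduction chain on one Group: pre_edge
# normalisation, optional calibration/cropping and rebinning, autobk
# background removal to chi(k), and the xftf Fourier transform to chi(R),
# before saving the plots and an Athena project named after path_key.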
def main(
    xas_data: Group,
    input_values: dict,
    path_key: str = "out",
):
    energy_0 = input_values["variables"]["energy_0"]
    if energy_0 is None and hasattr(xas_data, "e0"):
        energy_0 = xas_data.e0

    energy_format = input_values["variables"]["energy_format"]
    pre1 = input_values["variables"]["pre1"]
    pre2 = input_values["variables"]["pre2"]
    pre1 = validate_pre(pre1, energy_0, energy_format)
    pre2 = validate_pre(pre2, energy_0, energy_format)

    pre_edge(
        energy=xas_data.energy,
        mu=xas_data.mu,
        group=xas_data,
        e0=energy_0,
        pre1=pre1,
        pre2=pre2,
    )

    energy_min = input_values["variables"]["energy_min"]
    energy_max = input_values["variables"]["energy_max"]
    xas_data = calibrate_energy(
        xas_data,
        energy_0,
        energy_min,
        energy_max,
        energy_format=energy_format,
    )

    if input_values["rebin"]:
        print(xas_data.energy, xas_data.mu)
        rebin_xafs(energy=xas_data.energy, mu=xas_data.mu, group=xas_data)
        xas_data = xas_data.rebinned
        pre_edge(energy=xas_data.energy, mu=xas_data.mu, group=xas_data)

    try:
        autobk(xas_data)
    except ValueError as e:
        raise ValueError(
            f"autobk failed with energy={xas_data.energy}, mu={xas_data.mu}.\n"
            "This may occur if the edge is not included in the above ranges."
        ) from e
    # xftf_params is read from module scope, set when run as a script (below)
    xftf(xas_data, **xftf_params)

    if input_values["plot_graph"]:
        plot_edge_fits(f"edge/{path_key}.png", xas_data)
        plot_flattened(f"flat/{path_key}.png", xas_data)
        plot_derivative(f"derivative/{path_key}.png", xas_data)

    xas_project = create_athena(f"prj/{path_key}.prj")
    xas_project.add_group(xas_data)
    if input_values["annotation"]:
        group = next(iter(xas_project.groups.values()))
        group.args["annotation"] = input_values["annotation"]
    xas_project.save()

    # Ensure that we do not run out of memory when running on large zips
    gc.collect()


def validate_pre(pre, energy_0, energy_format):
    if pre is not None and energy_format == "absolute":
        if energy_0 is None:
            raise ValueError(
                "Edge energy must be set manually or be present in the "
                "existing Athena project if using absolute format."
            )
        pre -= energy_0

    return pre

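# Worked example (illustrative): with energy_format="absolute", pre1=7000.0
# and an Fe K edge at energy_0=7112.0, validate_pre returns 7000.0 - 7112.0
# = -112.0, i.e. a bound relative to e0, which is the form larch's pre_edge()
# expects for its pre1/pre2 arguments.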

def plot_derivative(plot_path: str, xafs_group: Group):
    plt.figure()
    plt.plot(xafs_group.energy, xafs_group.dmude)
    plt.grid(color="r", linestyle=":", linewidth=1)
    plt.xlabel("Energy (eV)")
    plt.ylabel(r"Derivative normalised to x$\mu$(E)")
    plt.savefig(plot_path, format="png")
    plt.close("all")


def plot_edge_fits(plot_path: str, xafs_group: Group):
    plt.figure()
    plt.plot(xafs_group.energy, xafs_group.pre_edge, "g", label="pre-edge")
    plt.plot(xafs_group.energy, xafs_group.post_edge, "r", label="post-edge")
    plt.plot(xafs_group.energy, xafs_group.mu, "b", label="fit data")
    plt.grid(color="r", linestyle=":", linewidth=1)
    plt.xlabel("Energy (eV)")
    plt.ylabel(r"x$\mu$(E)")
    plt.title(r"pre-edge and post-edge fitting to $\mu$")
    plt.legend()
    plt.savefig(plot_path, format="png")
    plt.close("all")


def plot_flattened(plot_path: str, xafs_group: Group):
    plt.figure()
    plt.plot(xafs_group.energy, xafs_group.flat)
    plt.grid(color="r", linestyle=":", linewidth=1)
    plt.xlabel("Energy (eV)")
    plt.ylabel(r"normalised x$\mu$(E)")
    plt.savefig(plot_path, format="png")
    plt.close("all")


if __name__ == "__main__":
    # larch imports set this to an interactive backend, so need to change it
    matplotlib.use("Agg")

    dat_file = sys.argv[1]
    with open(sys.argv[2], encoding="utf-8") as f:
        input_values = json.load(f)
    merge_inputs = input_values["merge_inputs"]["merge_inputs"]
    data_format = input_values["merge_inputs"]["format"]["format"]
    if "is_zipped" in input_values["merge_inputs"]["format"]:
        is_zipped = bool(
            input_values["merge_inputs"]["format"]["is_zipped"]["is_zipped"]
        )
    else:
        is_zipped = False
    xftf_params = input_values["variables"]["xftf"]
    extract_group = None

    if "extract_group" in input_values["merge_inputs"]["format"]:
        extract_group = input_values["merge_inputs"]["format"]["extract_group"]

    energy_column = None
    mu_column = None
    if "energy_column" in input_values["merge_inputs"]["format"]:
        energy_column = input_values["merge_inputs"]["format"]["energy_column"]
    if "mu_column" in input_values["merge_inputs"]["format"]:
        mu_column = input_values["merge_inputs"]["format"]["mu_column"]

    reader = Reader(
        energy_column=energy_column,
        mu_column=mu_column,
        xftf_params=xftf_params,
        data_format=data_format,
        extract_group=extract_group,
    )
    keyed_data = reader.load_data(
        dat_file=dat_file,
        merge_inputs=merge_inputs,
        is_zipped=is_zipped,
    )
    for key, group in keyed_data.items():
        main(
            group,
            input_values=input_values,
            path_key=key,
        )
```
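
Run as a script, the tool takes a data file and a JSON file of parameters: `python larch_athena.py input.dat inputs.json`. The output directories (`prj/`, plus `edge/`, `flat/` and `derivative/` when plotting) must already exist, since the script writes into them without creating them. Judging from the keys accessed in `__main__` and `main`, the inputs JSON has roughly the shape sketched below; the key names are taken from the code, while all values, and the presence of the optional keys (`is_zipped`, `extract_group`, `energy_column`, `mu_column`), are illustrative only.

```json
{
  "merge_inputs": {
    "merge_inputs": false,
    "format": {
      "format": "ascii",
      "is_zipped": {"is_zipped": false},
      "energy_column": "qexafs_energy",
      "mu_column": "lnI0It"
    }
  },
  "variables": {
    "energy_0": 7112.0,
    "energy_format": "relative",
    "energy_min": -100.0,
    "energy_max": 400.0,
    "pre1": null,
    "pre2": null,
    "xftf": {"kmin": 2, "kmax": 12, "kweight": 2, "dk": 1, "window": "hanning"}
  },
  "rebin": false,
  "plot_graph": true,
  "annotation": ""
}
```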
