Mercurial repository: muon-spectroscopy-computational-project / larch_athena
comparison: larch_athena.py @ 1:2b3115342fef (draft)

planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_athena commit 1cf6d7160497ba58fe16a51f00d088a20934eba6

author    muon-spectroscopy-computational-project
date      Wed, 06 Dec 2023 13:03:55 +0000
parents   ae2f265ecf8e
children  a1e26990131c
comparison of 0:ae2f265ecf8e (left column) with 1:2b3115342fef (right column); rows populated on only one side are lines removed from, or added in, the new revision
2 import json | 2 import json |
3 import os | 3 import os |
4 import re | 4 import re |
5 import sys | 5 import sys |
6 | 6 |
7 from common import read_group | 7 from common import ( |
8 pre_edge_with_defaults, read_all_groups, read_group, xftf_with_defaults | |
9 ) | |
8 | 10 |
9 from larch.io import ( | 11 from larch.io import ( |
10 create_athena, | 12 create_athena, |
11 h5group, | 13 h5group, |
12 merge_groups, | 14 merge_groups, |
13 read_ascii, | 15 read_ascii, |
14 set_array_labels, | 16 set_array_labels, |
15 ) | 17 ) |
16 from larch.symboltable import Group | 18 from larch.symboltable import Group |
17 from larch.xafs import autobk, pre_edge, rebin_xafs, xftf | 19 from larch.xafs import rebin_xafs |
18 | 20 |
19 import matplotlib | 21 import matplotlib |
20 import matplotlib.pyplot as plt | 22 import matplotlib.pyplot as plt |
21 | 23 |
22 import numpy as np | 24 import numpy as np |
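The new imports replace the inline pre_edge, autobk and xftf calls with helpers from the repository's common module, which is not part of this diff. A minimal sketch of what such wrappers could look like, assuming they simply forward optional user overrides to the underlying Larch functions (the real common.py may differ, and the placement of autobk is a guess):

```python
from larch.symboltable import Group
from larch.xafs import autobk, pre_edge, xftf


def pre_edge_with_defaults(group: Group, settings: dict = None):
    # Sketch only: normalise mu(E), forwarding any user-supplied overrides
    # and otherwise relying on Larch's own defaults.
    pre_edge(group, **(settings or {}))
    autobk(group)  # background removal; calling it here is an assumption


def xftf_with_defaults(group: Group, settings: dict = None):
    # Sketch only: forward Fourier transform of chi(k) with optional overrides.
    xftf(group, **(settings or {}))
```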
25 class Reader: | 27 class Reader: |
26 def __init__( | 28 def __init__( |
27 self, | 29 self, |
28 energy_column: str, | 30 energy_column: str, |
29 mu_column: str, | 31 mu_column: str, |
30 xftf_params: dict, | |
31 data_format: str, | 32 data_format: str, |
32 extract_group: str = None, | 33 extract_group: "dict[str, str]" = None, |
33 ): | 34 ): |
34 self.energy_column = energy_column | 35 self.energy_column = energy_column |
35 self.mu_column = mu_column | 36 self.mu_column = mu_column |
36 self.xftf_params = xftf_params | |
37 self.data_format = data_format | 37 self.data_format = data_format |
38 self.extract_group = extract_group | 38 self.extract_group = extract_group |
39 | 39 |
40 def load_data( | 40 def load_data( |
41 self, | 41 self, |
70 | 70 |
71 def load_single_file( | 71 def load_single_file( |
72 self, | 72 self, |
73 filepath: str, | 73 filepath: str, |
74 is_zipped: bool = False, | 74 is_zipped: bool = False, |
75 ) -> "dict[str,Group]": | 75 ) -> "tuple[dict, bool]": |
76 if is_zipped: | 76 if is_zipped: |
77 return self.load_zipped_files() | 77 return self.load_zipped_files() |
78 | 78 |
79 print(f"Attempting to read from {filepath}") | 79 print(f"Attempting to read from {filepath}") |
80 if self.data_format == "athena": | 80 if self.data_format == "athena": |
81 group = read_group(filepath, self.extract_group, self.xftf_params) | 81 if self.extract_group["extract_group"] == "single": |
82 group = read_group(filepath, self.extract_group["group_name"]) | |
83 return {"out": group} | |
84 elif self.extract_group["extract_group"] == "multiple": | |
85 groups = {} | |
86 for repeat in self.extract_group["multiple"]: | |
87 name = repeat["group_name"] | |
88 groups[name] = read_group(filepath, name) | |
89 return groups | |
90 else: | |
91 return read_all_groups(filepath) | |
92 | |
82 else: | 93 else: |
83 # Try ascii anyway | 94 # Try ascii anyway |
84 try: | 95 try: |
85 group = self.load_ascii(filepath) | 96 group = self.load_ascii(filepath) |
86 if not group.array_labels: | 97 if not group.array_labels: |
88 # will just fail to load any data | 99 # will just fail to load any data |
89 group = self.load_h5(filepath) | 100 group = self.load_h5(filepath) |
90 except (UnicodeDecodeError, TypeError): | 101 except (UnicodeDecodeError, TypeError): |
91 # Indicates this isn't plaintext, try h5 | 102 # Indicates this isn't plaintext, try h5 |
92 group = self.load_h5(filepath) | 103 group = self.load_h5(filepath) |
93 return {"out": group} | 104 pre_edge_with_defaults(group) |
105 xftf_with_defaults(group) | |
106 return {"out": group} | |
94 | 107 |
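load_single_file now branches on an extract_group dict rather than a single group name. A standalone illustration of that dispatch; the key names are inferred from the code above and the group names are invented:

```python
# Hypothetical extract_group values (key names taken from the code above,
# group names invented).
single = {"extract_group": "single", "group_name": "merged_scan"}
multiple = {
    "extract_group": "multiple",
    "multiple": [{"group_name": "scan_1"}, {"group_name": "scan_2"}],
}

for settings in (single, multiple, {"extract_group": "all"}):
    mode = settings["extract_group"]
    if mode == "single":
        names = [settings["group_name"]]  # read_group on one named group
    elif mode == "multiple":
        names = [r["group_name"] for r in settings["multiple"]]  # one per repeat
    else:
        names = ["<every group in the project>"]  # falls through to read_all_groups
    print(mode, "->", names)
```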
95 def load_ascii(self, dat_file): | 108 def load_ascii(self, dat_file): |
96 with open(dat_file) as f: | 109 with open(dat_file) as f: |
97 labels = None | 110 labels = None |
98 last_line = None | 111 last_line = None |
154 labels = [label.lower() for label in xafs_group.array_labels] | 167 labels = [label.lower() for label in xafs_group.array_labels] |
155 print(f"Read columns: {labels}") | 168 print(f"Read columns: {labels}") |
156 | 169 |
157 if "energy" in labels: | 170 if "energy" in labels: |
158 print("'energy' present in column headers") | 171 print("'energy' present in column headers") |
159 elif self.energy_column is not None: | 172 elif self.energy_column: |
160 if self.energy_column.lower() in labels: | 173 if self.energy_column.lower() in labels: |
161 labels[labels.index(self.energy_column.lower())] = "energy" | 174 labels[labels.index(self.energy_column.lower())] = "energy" |
162 else: | 175 else: |
163 raise ValueError(f"{self.energy_column} not found in {labels}") | 176 raise ValueError(f"{self.energy_column} not found in {labels}") |
164 else: | 177 else: |
165 for i, label in enumerate(labels): | 178 for i, label in enumerate(labels): |
166 if label == "col1" or label.endswith("energy"): | 179 if label in ("col1", "ef") or label.endswith("energy"): |
167 labels[i] = "energy" | 180 labels[i] = "energy" |
168 break | 181 break |
169 | 182 |
170 if "mu" in labels: | 183 if "mu" in labels: |
171 print("'mu' present in column headers") | 184 print("'mu' present in column headers") |
172 elif self.mu_column is not None: | 185 elif self.mu_column: |
173 if self.mu_column.lower() in labels: | 186 if self.mu_column.lower() in labels: |
174 labels[labels.index(self.mu_column.lower())] = "mu" | 187 labels[labels.index(self.mu_column.lower())] = "mu" |
175 else: | 188 else: |
176 raise ValueError(f"{self.mu_column} not found in {labels}") | 189 raise ValueError(f"{self.mu_column} not found in {labels}") |
177 else: | 190 else: |
178 for i, label in enumerate(labels): | 191 for i, label in enumerate(labels): |
179 if label in ["col2", "xmu", "lni0it", "ffi0"]: | 192 if label in ["col2", "xmu", "lni0it", "ffi0", "ff/i1"]: |
180 labels[i] = "mu" | 193 labels[i] = "mu" |
181 break | 194 break |
182 | 195 |
183 if labels != xafs_group.array_labels: | 196 if labels != xafs_group.array_labels: |
184 print(f"Renaming columns to: {labels}") | 197 print(f"Renaming columns to: {labels}") |
187 return xafs_group | 200 return xafs_group |
188 | 201 |
189 | 202 |
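The recognised column headers for the energy and absorption columns are extended in this revision ("ef" and "ff/i1"). A self-contained rerun of the fallback renaming logic on made-up headers:

```python
# Made-up beamline headers; the loops mirror the renaming logic above.
labels = ["col1", "ffi0", "i0"]

for i, label in enumerate(labels):
    if label in ("col1", "ef") or label.endswith("energy"):
        labels[i] = "energy"
        break

for i, label in enumerate(labels):
    if label in ["col2", "xmu", "lni0it", "ffi0", "ff/i1"]:
        labels[i] = "mu"
        break

print(labels)  # ['energy', 'mu', 'i0']
```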
190 def calibrate_energy( | 203 def calibrate_energy( |
191 xafs_group: Group, | 204 xafs_group: Group, |
192 energy_0: float, | 205 calibration_e0: float = None, |
193 energy_min: float, | 206 energy_min: float = None, |
194 energy_max: float, | 207 energy_max: float = None, |
195 energy_format: str, | |
196 ): | 208 ): |
197 if energy_0 is not None: | 209 if calibration_e0 is not None: |
198 print(f"Recalibrating energy edge from {xafs_group.e0} to {energy_0}") | 210 print(f"Recalibrating edge from {xafs_group.e0} to {calibration_e0}") |
199 xafs_group.energy = xafs_group.energy + energy_0 - xafs_group.e0 | 211 xafs_group.energy = xafs_group.energy + calibration_e0 - xafs_group.e0 |
200 xafs_group.e0 = energy_0 | 212 xafs_group.e0 = calibration_e0 |
201 | 213 |
202 if not (energy_min or energy_max): | 214 if not (energy_min or energy_max): |
203 return xafs_group | 215 return xafs_group |
204 | 216 |
205 if energy_min: | 217 if energy_min is not None: |
206 if energy_format == "relative": | |
207 energy_min += xafs_group.e0 | |
208 index_min = np.searchsorted(xafs_group.energy, energy_min) | 218 index_min = np.searchsorted(xafs_group.energy, energy_min) |
209 else: | 219 else: |
210 index_min = 0 | 220 index_min = 0 |
211 | 221 |
212 if energy_max: | 222 if energy_max is not None: |
213 if energy_format == "relative": | |
214 energy_max += xafs_group.e0 | |
215 index_max = np.searchsorted(xafs_group.energy, energy_max) | 223 index_max = np.searchsorted(xafs_group.energy, energy_max) |
216 else: | 224 else: |
217 index_max = len(xafs_group.energy) | 225 index_max = len(xafs_group.energy) |
218 | 226 |
219 print( | 227 print( |
238 return xafs_group | 246 return xafs_group |
239 | 247 |
240 | 248 |
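calibrate_energy now expects absolute energy_min/energy_max values and uses np.searchsorted to find the cropping indices (the relative-to-e0 conversion and the energy_format argument are removed). A small standalone check of that behaviour, with invented values:

```python
import numpy as np

# Sorted energy grid in eV (values invented for illustration).
energy = np.array([7100.0, 7105.0, 7110.0, 7115.0, 7120.0])

index_min = np.searchsorted(energy, 7106.0)  # -> 2
index_max = np.searchsorted(energy, 7117.0)  # -> 4

print(energy[index_min:index_max])  # [7110. 7115.]
```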
241 def main( | 249 def main( |
242 xas_data: Group, | 250 xas_data: Group, |
243 input_values: dict, | 251 do_calibrate: bool, |
252 calibrate_settings: dict, | |
253 do_rebin: bool, | |
254 do_pre_edge: bool, | |
255 pre_edge_settings: dict, | |
256 do_xftf: bool, | |
257 xftf_settings: dict, | |
258 plot_graph: bool, | |
259 annotation: str, | |
244 path_key: str = "out", | 260 path_key: str = "out", |
245 ): | 261 ): |
246 energy_0 = input_values["variables"]["energy_0"] | 262 if do_calibrate: |
247 if energy_0 is None and hasattr(xas_data, "e0"): | 263 print(f"Calibrating energy with {calibrate_settings}") |
248 energy_0 = xas_data.e0 | 264 xas_data = calibrate_energy(xas_data, **calibrate_settings) |
249 | 265 # After re-calibrating, will need to redo pre-edge with new range |
250 energy_format = input_values["variables"]["energy_format"] | 266 do_pre_edge = True |
251 pre1 = input_values["variables"]["pre1"] | 267 |
252 pre2 = input_values["variables"]["pre2"] | 268 if do_rebin: |
253 pre1 = validate_pre(pre1, energy_0, energy_format) | 269 print("Re-binning data") |
254 pre2 = validate_pre(pre2, energy_0, energy_format) | 270 rebin_xafs( |
255 | 271 energy=xas_data.energy, |
256 pre_edge( | 272 mu=xas_data.mu, |
257 energy=xas_data.energy, | 273 group=xas_data, |
258 mu=xas_data.mu, | 274 **pre_edge_settings, |
259 group=xas_data, | 275 ) |
260 e0=energy_0, | |
261 pre1=pre1, | |
262 pre2=pre2, | |
263 ) | |
264 | |
265 energy_min = input_values["variables"]["energy_min"] | |
266 energy_max = input_values["variables"]["energy_max"] | |
267 xas_data = calibrate_energy( | |
268 xas_data, | |
269 energy_0, | |
270 energy_min, | |
271 energy_max, | |
272 energy_format=energy_format, | |
273 ) | |
274 | |
275 if input_values["rebin"]: | |
276 print(xas_data.energy, xas_data.mu) | |
277 rebin_xafs(energy=xas_data.energy, mu=xas_data.mu, group=xas_data) | |
278 xas_data = xas_data.rebinned | 276 xas_data = xas_data.rebinned |
279 pre_edge(energy=xas_data.energy, mu=xas_data.mu, group=xas_data) | 277 # After re-bin, will need to redo pre-edge |
280 | 278 do_pre_edge = True |
281 try: | 279 |
282 autobk(xas_data) | 280 if do_pre_edge: |
283 except ValueError as e: | 281 pre_edge_with_defaults(xas_data, pre_edge_settings) |
284 raise ValueError( | 282 |
285 f"autobk failed with energy={xas_data.energy}, mu={xas_data.mu}.\n" | 283 if do_xftf: |
286 "This may occur if the edge is not included in the above ranges." | 284 xftf_with_defaults(xas_data, xftf_settings) |
287 ) from e | 285 |
288 xftf(xas_data, **xftf_params) | 286 if plot_graph: |
289 | |
290 if input_values["plot_graph"]: | |
291 plot_edge_fits(f"edge/{path_key}.png", xas_data) | 287 plot_edge_fits(f"edge/{path_key}.png", xas_data) |
292 plot_flattened(f"flat/{path_key}.png", xas_data) | 288 plot_flattened(f"flat/{path_key}.png", xas_data) |
293 plot_derivative(f"derivative/{path_key}.png", xas_data) | 289 plot_derivative(f"derivative/{path_key}.png", xas_data) |
294 | 290 |
295 xas_project = create_athena(f"prj/{path_key}.prj") | 291 xas_project = create_athena(f"prj/{path_key}.prj") |
296 xas_project.add_group(xas_data) | 292 xas_project.add_group(xas_data) |
297 if input_values["annotation"]: | 293 if annotation: |
298 group = next(iter(xas_project.groups.values())) | 294 group = next(iter(xas_project.groups.values())) |
299 group.args["annotation"] = input_values["annotation"] | 295 group.args["annotation"] = annotation |
300 xas_project.save() | 296 xas_project.save() |
301 | 297 |
302 # Ensure that we do not run out of memory when running on large zips | 298 # Ensure that we do not run out of memory when running on large zips |
303 gc.collect() | 299 gc.collect() |
304 | |
305 | |
306 def validate_pre(pre, energy_0, energy_format): | |
307 if pre is not None and energy_format == "absolute": | |
308 if energy_0 is None: | |
309 raise ValueError( | |
310 "Edge energy must be set manually or be present in the " | |
311 "existing Athena project if using absolute format." | |
312 ) | |
313 pre -= energy_0 | |
314 | |
315 return pre | |
316 | 300 |
317 | 301 |
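main() now takes explicit do_* flags plus per-step settings dicts instead of the whole input_values structure. Its keyword arguments have roughly the shape sketched below; every value is a placeholder, not a tool default:

```python
# Placeholder keyword arguments for the refactored main(); combine with a
# loaded Group as main(group, **main_kwargs, path_key=key).
main_kwargs = dict(
    do_calibrate=True,
    calibrate_settings={"calibration_e0": 7112.0},
    do_rebin=False,
    do_pre_edge=True,
    pre_edge_settings={"pre1": -150.0, "pre2": -50.0},
    do_xftf=True,
    xftf_settings={"kmin": 3.0, "kmax": 12.0},
    plot_graph=False,
    annotation="Fe foil, room temperature",
)
```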
318 def plot_derivative(plot_path: str, xafs_group: Group): | 302 def plot_derivative(plot_path: str, xafs_group: Group): |
319 plt.figure() | 303 plt.figure() |
320 plt.plot(xafs_group.energy, xafs_group.dmude) | 304 plt.plot(xafs_group.energy, xafs_group.dmude) |
361 is_zipped = bool( | 345 is_zipped = bool( |
362 input_values["merge_inputs"]["format"]["is_zipped"]["is_zipped"] | 346 input_values["merge_inputs"]["format"]["is_zipped"]["is_zipped"] |
363 ) | 347 ) |
364 else: | 348 else: |
365 is_zipped = False | 349 is_zipped = False |
366 xftf_params = input_values["variables"]["xftf"] | 350 |
367 extract_group = None | 351 extract_group = None |
368 | |
369 if "extract_group" in input_values["merge_inputs"]["format"]: | 352 if "extract_group" in input_values["merge_inputs"]["format"]: |
370 extract_group = input_values["merge_inputs"]["format"]["extract_group"] | 353 extract_group = input_values["merge_inputs"]["format"]["extract_group"] |
371 | 354 |
372 energy_column = None | 355 energy_column = None |
373 mu_column = None | 356 mu_column = None |
377 mu_column = input_values["merge_inputs"]["format"]["mu_column"] | 360 mu_column = input_values["merge_inputs"]["format"]["mu_column"] |
378 | 361 |
379 reader = Reader( | 362 reader = Reader( |
380 energy_column=energy_column, | 363 energy_column=energy_column, |
381 mu_column=mu_column, | 364 mu_column=mu_column, |
382 xftf_params=xftf_params, | |
383 data_format=data_format, | 365 data_format=data_format, |
384 extract_group=extract_group, | 366 extract_group=extract_group, |
385 ) | 367 ) |
386 keyed_data = reader.load_data( | 368 keyed_data = reader.load_data( |
387 dat_file=dat_file, | 369 dat_file=dat_file, |
388 merge_inputs=merge_inputs, | 370 merge_inputs=merge_inputs, |
389 is_zipped=is_zipped, | 371 is_zipped=is_zipped, |
390 ) | 372 ) |
373 | |
374 calibrate_items = input_values["processing"]["calibrate"].items() | |
375 calibrate_settings = {k: v for k, v in calibrate_items if v is not None} | |
376 do_calibrate = calibrate_settings.pop("calibrate") == "true" | |
377 | |
378 do_rebin = input_values["processing"].pop("rebin") | |
379 | |
380 pre_edge_items = input_values["processing"]["pre_edge"].items() | |
381 pre_edge_settings = {k: v for k, v in pre_edge_items if v is not None} | |
382 do_pre_edge = pre_edge_settings.pop("pre_edge") == "true" | |
383 | |
384 xftf_items = input_values["processing"]["xftf"].items() | |
385 xftf_settings = {k: v for k, v in xftf_items if v is not None} | |
386 do_xftf = xftf_settings.pop("xftf") == "true" | |
387 | |
388 plot_graph = input_values["plot_graph"] | |
389 annotation = input_values["annotation"] | |
390 | |
391 for key, group in keyed_data.items(): | 391 for key, group in keyed_data.items(): |
392 main( | 392 main( |
393 group, | 393 group, |
394 input_values=input_values, | 394 do_calibrate=do_calibrate, |
395 calibrate_settings=calibrate_settings, | |
396 do_rebin=do_rebin, | |
397 do_pre_edge=do_pre_edge, | |
398 pre_edge_settings=pre_edge_settings, | |
399 do_xftf=do_xftf, | |
400 xftf_settings=xftf_settings, | |
401 plot_graph=plot_graph, | |
402 annotation=annotation, | |
395 path_key=key, | 403 path_key=key, |
396 ) | 404 ) |
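The __main__ block now derives the do_* flags and settings dicts from a "processing" section of the JSON inputs. A rough example of how that parsing behaves; the key names are inferred from the code above and the values are invented rather than taken from the tool's schema:

```python
# Invented "processing" section mirroring the parsing in the __main__ block.
processing = {
    "calibrate": {"calibrate": "true", "calibration_e0": 7112.0, "energy_min": None},
    "rebin": False,
    "pre_edge": {"pre_edge": "true", "pre1": -150.0, "pre2": None},
    "xftf": {"xftf": "false"},
}

calibrate_settings = {
    k: v for k, v in processing["calibrate"].items() if v is not None
}
do_calibrate = calibrate_settings.pop("calibrate") == "true"

print(do_calibrate, calibrate_settings)
# True {'calibration_e0': 7112.0}
```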