comparison larch_athena.py @ 1:2b3115342fef draft

planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_athena commit 1cf6d7160497ba58fe16a51f00d088a20934eba6
author muon-spectroscopy-computational-project
date Wed, 06 Dec 2023 13:03:55 +0000
parents ae2f265ecf8e
children a1e26990131c
comparison
equal deleted inserted replaced
0:ae2f265ecf8e 1:2b3115342fef
2 import json 2 import json
3 import os 3 import os
4 import re 4 import re
5 import sys 5 import sys
6 6
7 from common import read_group 7 from common import (
8 pre_edge_with_defaults, read_all_groups, read_group, xftf_with_defaults
9 )
8 10
9 from larch.io import ( 11 from larch.io import (
10 create_athena, 12 create_athena,
11 h5group, 13 h5group,
12 merge_groups, 14 merge_groups,
13 read_ascii, 15 read_ascii,
14 set_array_labels, 16 set_array_labels,
15 ) 17 )
16 from larch.symboltable import Group 18 from larch.symboltable import Group
17 from larch.xafs import autobk, pre_edge, rebin_xafs, xftf 19 from larch.xafs import rebin_xafs
18 20
19 import matplotlib 21 import matplotlib
20 import matplotlib.pyplot as plt 22 import matplotlib.pyplot as plt
21 23
22 import numpy as np 24 import numpy as np
25 class Reader: 27 class Reader:
26 def __init__( 28 def __init__(
27 self, 29 self,
28 energy_column: str, 30 energy_column: str,
29 mu_column: str, 31 mu_column: str,
30 xftf_params: dict,
31 data_format: str, 32 data_format: str,
32 extract_group: str = None, 33 extract_group: "dict[str, str]" = None,
33 ): 34 ):
34 self.energy_column = energy_column 35 self.energy_column = energy_column
35 self.mu_column = mu_column 36 self.mu_column = mu_column
36 self.xftf_params = xftf_params
37 self.data_format = data_format 37 self.data_format = data_format
38 self.extract_group = extract_group 38 self.extract_group = extract_group
39 39
40 def load_data( 40 def load_data(
41 self, 41 self,
70 70
71 def load_single_file( 71 def load_single_file(
72 self, 72 self,
73 filepath: str, 73 filepath: str,
74 is_zipped: bool = False, 74 is_zipped: bool = False,
75 ) -> "dict[str,Group]": 75 ) -> "tuple[dict, bool]":
76 if is_zipped: 76 if is_zipped:
77 return self.load_zipped_files() 77 return self.load_zipped_files()
78 78
79 print(f"Attempting to read from {filepath}") 79 print(f"Attempting to read from {filepath}")
80 if self.data_format == "athena": 80 if self.data_format == "athena":
81 group = read_group(filepath, self.extract_group, self.xftf_params) 81 if self.extract_group["extract_group"] == "single":
82 group = read_group(filepath, self.extract_group["group_name"])
83 return {"out": group}
84 elif self.extract_group["extract_group"] == "multiple":
85 groups = {}
86 for repeat in self.extract_group["multiple"]:
87 name = repeat["group_name"]
88 groups[name] = read_group(filepath, name)
89 return groups
90 else:
91 return read_all_groups(filepath)
92
82 else: 93 else:
83 # Try ascii anyway 94 # Try ascii anyway
84 try: 95 try:
85 group = self.load_ascii(filepath) 96 group = self.load_ascii(filepath)
86 if not group.array_labels: 97 if not group.array_labels:
88 # will just fail to load any data 99 # will just fail to load any data
89 group = self.load_h5(filepath) 100 group = self.load_h5(filepath)
90 except (UnicodeDecodeError, TypeError): 101 except (UnicodeDecodeError, TypeError):
91 # Indicates this isn't plaintext, try h5 102 # Indicates this isn't plaintext, try h5
92 group = self.load_h5(filepath) 103 group = self.load_h5(filepath)
93 return {"out": group} 104 pre_edge_with_defaults(group)
105 xftf_with_defaults(group)
106 return {"out": group}
94 107
95 def load_ascii(self, dat_file): 108 def load_ascii(self, dat_file):
96 with open(dat_file) as f: 109 with open(dat_file) as f:
97 labels = None 110 labels = None
98 last_line = None 111 last_line = None
154 labels = [label.lower() for label in xafs_group.array_labels] 167 labels = [label.lower() for label in xafs_group.array_labels]
155 print(f"Read columns: {labels}") 168 print(f"Read columns: {labels}")
156 169
157 if "energy" in labels: 170 if "energy" in labels:
158 print("'energy' present in column headers") 171 print("'energy' present in column headers")
159 elif self.energy_column is not None: 172 elif self.energy_column:
160 if self.energy_column.lower() in labels: 173 if self.energy_column.lower() in labels:
161 labels[labels.index(self.energy_column.lower())] = "energy" 174 labels[labels.index(self.energy_column.lower())] = "energy"
162 else: 175 else:
163 raise ValueError(f"{self.energy_column} not found in {labels}") 176 raise ValueError(f"{self.energy_column} not found in {labels}")
164 else: 177 else:
165 for i, label in enumerate(labels): 178 for i, label in enumerate(labels):
166 if label == "col1" or label.endswith("energy"): 179 if label in ("col1", "ef") or label.endswith("energy"):
167 labels[i] = "energy" 180 labels[i] = "energy"
168 break 181 break
169 182
170 if "mu" in labels: 183 if "mu" in labels:
171 print("'mu' present in column headers") 184 print("'mu' present in column headers")
172 elif self.mu_column is not None: 185 elif self.mu_column:
173 if self.mu_column.lower() in labels: 186 if self.mu_column.lower() in labels:
174 labels[labels.index(self.mu_column.lower())] = "mu" 187 labels[labels.index(self.mu_column.lower())] = "mu"
175 else: 188 else:
176 raise ValueError(f"{self.mu_column} not found in {labels}") 189 raise ValueError(f"{self.mu_column} not found in {labels}")
177 else: 190 else:
178 for i, label in enumerate(labels): 191 for i, label in enumerate(labels):
179 if label in ["col2", "xmu", "lni0it", "ffi0"]: 192 if label in ["col2", "xmu", "lni0it", "ffi0", "ff/i1"]:
180 labels[i] = "mu" 193 labels[i] = "mu"
181 break 194 break
182 195
183 if labels != xafs_group.array_labels: 196 if labels != xafs_group.array_labels:
184 print(f"Renaming columns to: {labels}") 197 print(f"Renaming columns to: {labels}")
187 return xafs_group 200 return xafs_group
188 201
189 202
190 def calibrate_energy( 203 def calibrate_energy(
191 xafs_group: Group, 204 xafs_group: Group,
192 energy_0: float, 205 calibration_e0: float = None,
193 energy_min: float, 206 energy_min: float = None,
194 energy_max: float, 207 energy_max: float = None,
195 energy_format: str,
196 ): 208 ):
197 if energy_0 is not None: 209 if calibration_e0 is not None:
198 print(f"Recalibrating energy edge from {xafs_group.e0} to {energy_0}") 210 print(f"Recalibrating edge from {xafs_group.e0} to {calibration_e0}")
199 xafs_group.energy = xafs_group.energy + energy_0 - xafs_group.e0 211 xafs_group.energy = xafs_group.energy + calibration_e0 - xafs_group.e0
200 xafs_group.e0 = energy_0 212 xafs_group.e0 = calibration_e0
201 213
202 if not (energy_min or energy_max): 214 if not (energy_min or energy_max):
203 return xafs_group 215 return xafs_group
204 216
205 if energy_min: 217 if energy_min is not None:
206 if energy_format == "relative":
207 energy_min += xafs_group.e0
208 index_min = np.searchsorted(xafs_group.energy, energy_min) 218 index_min = np.searchsorted(xafs_group.energy, energy_min)
209 else: 219 else:
210 index_min = 0 220 index_min = 0
211 221
212 if energy_max: 222 if energy_max is not None:
213 if energy_format == "relative":
214 energy_max += xafs_group.e0
215 index_max = np.searchsorted(xafs_group.energy, energy_max) 223 index_max = np.searchsorted(xafs_group.energy, energy_max)
216 else: 224 else:
217 index_max = len(xafs_group.energy) 225 index_max = len(xafs_group.energy)
218 226
219 print( 227 print(
238 return xafs_group 246 return xafs_group
239 247
240 248
241 def main( 249 def main(
242 xas_data: Group, 250 xas_data: Group,
243 input_values: dict, 251 do_calibrate: bool,
252 calibrate_settings: dict,
253 do_rebin: bool,
254 do_pre_edge: bool,
255 pre_edge_settings: dict,
256 do_xftf: bool,
257 xftf_settings: dict,
258 plot_graph: bool,
259 annotation: str,
244 path_key: str = "out", 260 path_key: str = "out",
245 ): 261 ):
246 energy_0 = input_values["variables"]["energy_0"] 262 if do_calibrate:
247 if energy_0 is None and hasattr(xas_data, "e0"): 263 print(f"Calibrating energy with {calibrate_settings}")
248 energy_0 = xas_data.e0 264 xas_data = calibrate_energy(xas_data, **calibrate_settings)
249 265 # After re-calibrating, will need to redo pre-edge with new range
250 energy_format = input_values["variables"]["energy_format"] 266 do_pre_edge = True
251 pre1 = input_values["variables"]["pre1"] 267
252 pre2 = input_values["variables"]["pre2"] 268 if do_rebin:
253 pre1 = validate_pre(pre1, energy_0, energy_format) 269 print("Re-binning data")
254 pre2 = validate_pre(pre2, energy_0, energy_format) 270 rebin_xafs(
255 271 energy=xas_data.energy,
256 pre_edge( 272 mu=xas_data.mu,
257 energy=xas_data.energy, 273 group=xas_data,
258 mu=xas_data.mu, 274 **pre_edge_settings,
259 group=xas_data, 275 )
260 e0=energy_0,
261 pre1=pre1,
262 pre2=pre2,
263 )
264
265 energy_min = input_values["variables"]["energy_min"]
266 energy_max = input_values["variables"]["energy_max"]
267 xas_data = calibrate_energy(
268 xas_data,
269 energy_0,
270 energy_min,
271 energy_max,
272 energy_format=energy_format,
273 )
274
275 if input_values["rebin"]:
276 print(xas_data.energy, xas_data.mu)
277 rebin_xafs(energy=xas_data.energy, mu=xas_data.mu, group=xas_data)
278 xas_data = xas_data.rebinned 276 xas_data = xas_data.rebinned
279 pre_edge(energy=xas_data.energy, mu=xas_data.mu, group=xas_data) 277 # After re-bin, will need to redo pre-edge
280 278 do_pre_edge = True
281 try: 279
282 autobk(xas_data) 280 if do_pre_edge:
283 except ValueError as e: 281 pre_edge_with_defaults(xas_data, pre_edge_settings)
284 raise ValueError( 282
285 f"autobk failed with energy={xas_data.energy}, mu={xas_data.mu}.\n" 283 if do_xftf:
286 "This may occur if the edge is not included in the above ranges." 284 xftf_with_defaults(xas_data, xftf_settings)
287 ) from e 285
288 xftf(xas_data, **xftf_params) 286 if plot_graph:
289
290 if input_values["plot_graph"]:
291 plot_edge_fits(f"edge/{path_key}.png", xas_data) 287 plot_edge_fits(f"edge/{path_key}.png", xas_data)
292 plot_flattened(f"flat/{path_key}.png", xas_data) 288 plot_flattened(f"flat/{path_key}.png", xas_data)
293 plot_derivative(f"derivative/{path_key}.png", xas_data) 289 plot_derivative(f"derivative/{path_key}.png", xas_data)
294 290
295 xas_project = create_athena(f"prj/{path_key}.prj") 291 xas_project = create_athena(f"prj/{path_key}.prj")
296 xas_project.add_group(xas_data) 292 xas_project.add_group(xas_data)
297 if input_values["annotation"]: 293 if annotation:
298 group = next(iter(xas_project.groups.values())) 294 group = next(iter(xas_project.groups.values()))
299 group.args["annotation"] = input_values["annotation"] 295 group.args["annotation"] = annotation
300 xas_project.save() 296 xas_project.save()
301 297
302 # Ensure that we do not run out of memory when running on large zips 298 # Ensure that we do not run out of memory when running on large zips
303 gc.collect() 299 gc.collect()
304
305
306 def validate_pre(pre, energy_0, energy_format):
307 if pre is not None and energy_format == "absolute":
308 if energy_0 is None:
309 raise ValueError(
310 "Edge energy must be set manually or be present in the "
311 "existing Athena project if using absolute format."
312 )
313 pre -= energy_0
314
315 return pre
316 300
317 301
318 def plot_derivative(plot_path: str, xafs_group: Group): 302 def plot_derivative(plot_path: str, xafs_group: Group):
319 plt.figure() 303 plt.figure()
320 plt.plot(xafs_group.energy, xafs_group.dmude) 304 plt.plot(xafs_group.energy, xafs_group.dmude)
361 is_zipped = bool( 345 is_zipped = bool(
362 input_values["merge_inputs"]["format"]["is_zipped"]["is_zipped"] 346 input_values["merge_inputs"]["format"]["is_zipped"]["is_zipped"]
363 ) 347 )
364 else: 348 else:
365 is_zipped = False 349 is_zipped = False
366 xftf_params = input_values["variables"]["xftf"] 350
367 extract_group = None 351 extract_group = None
368
369 if "extract_group" in input_values["merge_inputs"]["format"]: 352 if "extract_group" in input_values["merge_inputs"]["format"]:
370 extract_group = input_values["merge_inputs"]["format"]["extract_group"] 353 extract_group = input_values["merge_inputs"]["format"]["extract_group"]
371 354
372 energy_column = None 355 energy_column = None
373 mu_column = None 356 mu_column = None
377 mu_column = input_values["merge_inputs"]["format"]["mu_column"] 360 mu_column = input_values["merge_inputs"]["format"]["mu_column"]
378 361
379 reader = Reader( 362 reader = Reader(
380 energy_column=energy_column, 363 energy_column=energy_column,
381 mu_column=mu_column, 364 mu_column=mu_column,
382 xftf_params=xftf_params,
383 data_format=data_format, 365 data_format=data_format,
384 extract_group=extract_group, 366 extract_group=extract_group,
385 ) 367 )
386 keyed_data = reader.load_data( 368 keyed_data = reader.load_data(
387 dat_file=dat_file, 369 dat_file=dat_file,
388 merge_inputs=merge_inputs, 370 merge_inputs=merge_inputs,
389 is_zipped=is_zipped, 371 is_zipped=is_zipped,
390 ) 372 )
373
374 calibrate_items = input_values["processing"]["calibrate"].items()
375 calibrate_settings = {k: v for k, v in calibrate_items if v is not None}
376 do_calibrate = calibrate_settings.pop("calibrate") == "true"
377
378 do_rebin = input_values["processing"].pop("rebin")
379
380 pre_edge_items = input_values["processing"]["pre_edge"].items()
381 pre_edge_settings = {k: v for k, v in pre_edge_items if v is not None}
382 do_pre_edge = pre_edge_settings.pop("pre_edge") == "true"
383
384 xftf_items = input_values["processing"]["xftf"].items()
385 xftf_settings = {k: v for k, v in xftf_items if v is not None}
386 do_xftf = xftf_settings.pop("xftf") == "true"
387
388 plot_graph = input_values["plot_graph"]
389 annotation = input_values["annotation"]
390
391 for key, group in keyed_data.items(): 391 for key, group in keyed_data.items():
392 main( 392 main(
393 group, 393 group,
394 input_values=input_values, 394 do_calibrate=do_calibrate,
395 calibrate_settings=calibrate_settings,
396 do_rebin=do_rebin,
397 do_pre_edge=do_pre_edge,
398 pre_edge_settings=pre_edge_settings,
399 do_xftf=do_xftf,
400 xftf_settings=xftf_settings,
401 plot_graph=plot_graph,
402 annotation=annotation,
395 path_key=key, 403 path_key=key,
396 ) 404 )