comparison larch_athena.py @ 0:ae2f265ecf8e draft

planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_athena commit 5be486890442dedfb327289d597e1c8110240735
author muon-spectroscopy-computational-project
date Tue, 14 Nov 2023 15:34:40 +0000
parents
children 2b3115342fef
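"""Galaxy tool script for processing XAS data with larch: reads energy/mu
data from plaintext, NeXus/HDF5 or Athena project inputs (optionally
zipped), applies pre-edge normalisation, background subtraction and a
forward Fourier transform, and writes Athena project files plus optional
plots."""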
import gc
import json
import os
import re
import sys

from common import read_group

from larch.io import (
    create_athena,
    h5group,
    merge_groups,
    read_ascii,
    set_array_labels,
)
from larch.symboltable import Group
from larch.xafs import autobk, pre_edge, rebin_xafs, xftf

import matplotlib
import matplotlib.pyplot as plt

import numpy as np

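# Reader handles loading XAS data from the supported input formats
# (Athena project, plaintext ASCII, NeXus/HDF5), optionally from zipped
# collections, and normalises column labels to "energy" and "mu".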
class Reader:
    def __init__(
        self,
        energy_column: str,
        mu_column: str,
        xftf_params: dict,
        data_format: str,
        extract_group: str = None,
    ):
        self.energy_column = energy_column
        self.mu_column = mu_column
        self.xftf_params = xftf_params
        self.data_format = data_format
        self.extract_group = extract_group

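    # Returns a dict of groups keyed by output name: "out" for a single or
    # merged input, zero-padded indices when reading a zipped collection.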
    def load_data(
        self,
        dat_file: str,
        merge_inputs: bool,
        is_zipped: bool,
    ) -> "dict[str, Group]":
        if merge_inputs:
            out_group = self.merge_files(
                dat_files=dat_file, is_zipped=is_zipped
            )
            return {"out": out_group}
        else:
            return self.load_single_file(
                filepath=dat_file, is_zipped=is_zipped
            )

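    # Load every input (zipped archive or comma-separated list of paths)
    # and merge them into a single group on the energy/mu arrays.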
    def merge_files(
        self,
        dat_files: str,
        is_zipped: bool,
    ) -> Group:
        if is_zipped:
            all_groups = list(self.load_zipped_files().values())
        else:
            all_groups = []
            for filepath in dat_files.split(","):
                group = self.load_single_file(filepath)["out"]
                all_groups.append(group)

        return merge_groups(all_groups, xarray="energy", yarray="mu")

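    # Read a single file, choosing the parser from the requested format:
    # Athena projects via read_group, otherwise try plaintext ASCII first
    # and fall back to HDF5.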
    def load_single_file(
        self,
        filepath: str,
        is_zipped: bool = False,
    ) -> "dict[str, Group]":
        if is_zipped:
            return self.load_zipped_files()

        print(f"Attempting to read from {filepath}")
        if self.data_format == "athena":
            group = read_group(filepath, self.extract_group, self.xftf_params)
        else:
            # Try ascii anyway
            try:
                group = self.load_ascii(filepath)
                if not group.array_labels:
                    # In later versions of larch we won't get a TypeError,
                    # it will just fail to load any data
                    group = self.load_h5(filepath)
            except (UnicodeDecodeError, TypeError):
                # Indicates this isn't plaintext, try h5
                group = self.load_h5(filepath)
        return {"out": group}

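    # Plaintext files: scan the header for the last commented line before
    # the data and, if it is tab-separated, use it as the column labels
    # passed to larch's read_ascii.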
    def load_ascii(self, dat_file):
        with open(dat_file) as f:
            labels = None
            last_line = None
            line = f.readline()
            while line:
                if not line.startswith("#"):
                    if last_line is not None and last_line.find("\t") > 0:
                        labels = []
                        for label in last_line.split("\t"):
                            labels.append(label.strip())
                    break

                last_line = line
                line = f.readline()

        xas_data = read_ascii(filename=dat_file, labels=labels)
        xas_data = self.rename_cols(xas_data)
        return xas_data

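    # NeXus/HDF5 files: take the energy and ln(I0/It) arrays from the
    # qexafs instrument entries and build a new larch Group from them.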
    def load_h5(self, dat_file):
        h5_group = h5group(fname=dat_file, mode="r")
        energy = h5_group.entry1.instrument.qexafs_energy.qexafs_energy
        mu = h5_group.entry1.instrument.qexafs_counterTimer01.lnI0It
        xafs_group = Group(data=np.array([energy[:], mu[:]]))
        set_array_labels(xafs_group, ["energy", "mu"])
        return xafs_group

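    # Zipped inputs are expected to have been extracted into a local
    # "dat_files" directory already; walk it, sorting files numerically
    # where possible, and key each group by a zero-padded counter.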
    def load_zipped_files(self) -> "dict[str, Group]":
        def sorting_key(filename: str) -> str:
            return re.findall(r"\d+", filename)[-1]

        all_paths = list(os.walk("dat_files"))
        all_paths.sort(key=lambda x: x[0])
        file_total = sum([len(f) for _, _, f in all_paths])
        print(f"{file_total} files found")
        key_length = len(str(file_total))
        i = 0
        keyed_data = {}
        for dirpath, _, filenames in all_paths:
            try:
                filenames.sort(key=sorting_key)
            except IndexError as e:
                print(
                    "WARNING: Unable to sort files numerically, "
                    f"defaulting to sorting alphabetically:\n{e}"
                )
                filenames.sort()

            for filename in filenames:
                key = str(i).zfill(key_length)
                filepath = os.path.join(dirpath, filename)
                xas_data = self.load_single_file(filepath)
                keyed_data[key] = xas_data["out"]
                i += 1

        return keyed_data

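    # Map the raw column labels onto the "energy" and "mu" names larch
    # expects, using the user-supplied column names when given and common
    # defaults otherwise.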
    def rename_cols(self, xafs_group: Group) -> Group:
        labels = [label.lower() for label in xafs_group.array_labels]
        print(f"Read columns: {labels}")

        if "energy" in labels:
            print("'energy' present in column headers")
        elif self.energy_column is not None:
            if self.energy_column.lower() in labels:
                labels[labels.index(self.energy_column.lower())] = "energy"
            else:
                raise ValueError(f"{self.energy_column} not found in {labels}")
        else:
            for i, label in enumerate(labels):
                if label == "col1" or label.endswith("energy"):
                    labels[i] = "energy"
                    break

        if "mu" in labels:
            print("'mu' present in column headers")
        elif self.mu_column is not None:
            if self.mu_column.lower() in labels:
                labels[labels.index(self.mu_column.lower())] = "mu"
            else:
                raise ValueError(f"{self.mu_column} not found in {labels}")
        else:
            for i, label in enumerate(labels):
                if label in ["col2", "xmu", "lni0it", "ffi0"]:
                    labels[i] = "mu"
                    break

        if labels != xafs_group.array_labels:
            print(f"Renaming columns to: {labels}")
            return set_array_labels(xafs_group, labels)
        else:
            return xafs_group


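# Shift the energy scale so the edge sits at energy_0, then optionally crop
# all arrays to [energy_min, energy_max] (interpreted relative to e0 when
# energy_format is "relative").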
def calibrate_energy(
    xafs_group: Group,
    energy_0: float,
    energy_min: float,
    energy_max: float,
    energy_format: str,
):
    if energy_0 is not None:
        print(f"Recalibrating energy edge from {xafs_group.e0} to {energy_0}")
        xafs_group.energy = xafs_group.energy + energy_0 - xafs_group.e0
        xafs_group.e0 = energy_0

    if not (energy_min or energy_max):
        return xafs_group

    if energy_min:
        if energy_format == "relative":
            energy_min += xafs_group.e0
        index_min = np.searchsorted(xafs_group.energy, energy_min)
    else:
        index_min = 0

    if energy_max:
        if energy_format == "relative":
            energy_max += xafs_group.e0
        index_max = np.searchsorted(xafs_group.energy, energy_max)
    else:
        index_max = len(xafs_group.energy)

    print(
        f"Cropping energy range from {energy_min} to {energy_max}, "
        f"index {index_min} to {index_max}"
    )
    try:
        xafs_group.dmude = xafs_group.dmude[index_min:index_max]
        xafs_group.pre_edge = xafs_group.pre_edge[index_min:index_max]
        xafs_group.post_edge = xafs_group.post_edge[index_min:index_max]
        xafs_group.flat = xafs_group.flat[index_min:index_max]
    except AttributeError:
        pass

    xafs_group.energy = xafs_group.energy[index_min:index_max]
    xafs_group.mu = xafs_group.mu[index_min:index_max]

    # Sanity check
    if len(xafs_group.energy) == 0:
        raise ValueError("Energy cropping led to an empty array")

    return xafs_group


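# Process one group end to end: pre-edge normalisation, energy calibration
# and cropping, optional rebinning, background removal (autobk), forward
# Fourier transform (xftf), optional plots, and an Athena project written
# to prj/<path_key>.prj.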
def main(
    xas_data: Group,
    input_values: dict,
    path_key: str = "out",
):
    energy_0 = input_values["variables"]["energy_0"]
    if energy_0 is None and hasattr(xas_data, "e0"):
        energy_0 = xas_data.e0

    energy_format = input_values["variables"]["energy_format"]
    pre1 = input_values["variables"]["pre1"]
    pre2 = input_values["variables"]["pre2"]
    pre1 = validate_pre(pre1, energy_0, energy_format)
    pre2 = validate_pre(pre2, energy_0, energy_format)

    pre_edge(
        energy=xas_data.energy,
        mu=xas_data.mu,
        group=xas_data,
        e0=energy_0,
        pre1=pre1,
        pre2=pre2,
    )

    energy_min = input_values["variables"]["energy_min"]
    energy_max = input_values["variables"]["energy_max"]
    xas_data = calibrate_energy(
        xas_data,
        energy_0,
        energy_min,
        energy_max,
        energy_format=energy_format,
    )

    if input_values["rebin"]:
        print(xas_data.energy, xas_data.mu)
        rebin_xafs(energy=xas_data.energy, mu=xas_data.mu, group=xas_data)
        xas_data = xas_data.rebinned
        pre_edge(energy=xas_data.energy, mu=xas_data.mu, group=xas_data)

    try:
        autobk(xas_data)
    except ValueError as e:
        raise ValueError(
            f"autobk failed with energy={xas_data.energy}, mu={xas_data.mu}.\n"
            "This may occur if the edge is not included in the above ranges."
        ) from e
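    # Note: xftf_params is the module-level dict set in the __main__ block.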
    xftf(xas_data, **xftf_params)

    if input_values["plot_graph"]:
        plot_edge_fits(f"edge/{path_key}.png", xas_data)
        plot_flattened(f"flat/{path_key}.png", xas_data)
        plot_derivative(f"derivative/{path_key}.png", xas_data)

    xas_project = create_athena(f"prj/{path_key}.prj")
    xas_project.add_group(xas_data)
    if input_values["annotation"]:
        group = next(iter(xas_project.groups.values()))
        group.args["annotation"] = input_values["annotation"]
    xas_project.save()

    # Ensure that we do not run out of memory when running on large zips
    gc.collect()


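# pre1/pre2 are passed to pre_edge relative to e0, so values given in
# absolute energy must be converted and therefore require a known edge
# energy.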
def validate_pre(pre, energy_0, energy_format):
    if pre is not None and energy_format == "absolute":
        if energy_0 is None:
            raise ValueError(
                "Edge energy must be set manually or be present in the "
                "existing Athena project if using absolute format."
            )
        pre -= energy_0

    return pre


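# Plot helpers: each writes a single PNG and closes the figure so that
# memory does not build up when processing many groups.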
def plot_derivative(plot_path: str, xafs_group: Group):
    plt.figure()
    plt.plot(xafs_group.energy, xafs_group.dmude)
    plt.grid(color="r", linestyle=":", linewidth=1)
    plt.xlabel("Energy (eV)")
    plt.ylabel(r"Derivative normalised to x$\mu$(E)")
    plt.savefig(plot_path, format="png")
    plt.close("all")


def plot_edge_fits(plot_path: str, xafs_group: Group):
    plt.figure()
    plt.plot(xafs_group.energy, xafs_group.pre_edge, "g", label="pre-edge")
    plt.plot(xafs_group.energy, xafs_group.post_edge, "r", label="post-edge")
    plt.plot(xafs_group.energy, xafs_group.mu, "b", label="fit data")
    plt.grid(color="r", linestyle=":", linewidth=1)
    plt.xlabel("Energy (eV)")
    plt.ylabel(r"x$\mu$(E)")
    plt.title(r"pre-edge and post-edge fitting to $\mu$")
    plt.legend()
    plt.savefig(plot_path, format="png")
    plt.close("all")


def plot_flattened(plot_path: str, xafs_group: Group):
    plt.figure()
    plt.plot(xafs_group.energy, xafs_group.flat)
    plt.grid(color="r", linestyle=":", linewidth=1)
    plt.xlabel("Energy (eV)")
    plt.ylabel(r"normalised x$\mu$(E)")
    plt.savefig(plot_path, format="png")
    plt.close("all")


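# CLI entry point: argv[1] is the input data file (or comma-separated list
# of files / zip archive) and argv[2] is a JSON file of tool parameters.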
if __name__ == "__main__":
    # larch imports set this to an interactive backend, so need to change it
    matplotlib.use("Agg")

    dat_file = sys.argv[1]
    with open(sys.argv[2], "r", encoding="utf-8") as f:
        input_values = json.load(f)
    merge_inputs = input_values["merge_inputs"]["merge_inputs"]
    data_format = input_values["merge_inputs"]["format"]["format"]
    if "is_zipped" in input_values["merge_inputs"]["format"]:
        is_zipped = bool(
            input_values["merge_inputs"]["format"]["is_zipped"]["is_zipped"]
        )
    else:
        is_zipped = False
    xftf_params = input_values["variables"]["xftf"]
    extract_group = None

    if "extract_group" in input_values["merge_inputs"]["format"]:
        extract_group = input_values["merge_inputs"]["format"]["extract_group"]

    energy_column = None
    mu_column = None
    if "energy_column" in input_values["merge_inputs"]["format"]:
        energy_column = input_values["merge_inputs"]["format"]["energy_column"]
    if "mu_column" in input_values["merge_inputs"]["format"]:
        mu_column = input_values["merge_inputs"]["format"]["mu_column"]

    reader = Reader(
        energy_column=energy_column,
        mu_column=mu_column,
        xftf_params=xftf_params,
        data_format=data_format,
        extract_group=extract_group,
    )
    keyed_data = reader.load_data(
        dat_file=dat_file,
        merge_inputs=merge_inputs,
        is_zipped=is_zipped,
    )
    for key, group in keyed_data.items():
        main(
            group,
            input_values=input_values,
            path_key=key,
        )