Mercurial > repos > muon-spectroscopy-computational-project > larch_athena
comparison larch_athena.py @ 0:ae2f265ecf8e draft
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_athena commit 5be486890442dedfb327289d597e1c8110240735
author | muon-spectroscopy-computational-project |
---|---|
date | Tue, 14 Nov 2023 15:34:40 +0000 |
parents | |
children | 2b3115342fef |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:ae2f265ecf8e |
---|---|
1 import gc | |
2 import json | |
3 import os | |
4 import re | |
5 import sys | |
6 | |
7 from common import read_group | |
8 | |
9 from larch.io import ( | |
10 create_athena, | |
11 h5group, | |
12 merge_groups, | |
13 read_ascii, | |
14 set_array_labels, | |
15 ) | |
16 from larch.symboltable import Group | |
17 from larch.xafs import autobk, pre_edge, rebin_xafs, xftf | |
18 | |
19 import matplotlib | |
20 import matplotlib.pyplot as plt | |
21 | |
22 import numpy as np | |
23 | |
24 | |
class Reader:
    """Load XAFS data groups from Athena projects, plaintext column files
    or HDF5 files, optionally zipped, optionally merged into one group.
    """

    def __init__(
        self,
        energy_column: str,
        mu_column: str,
        xftf_params: dict,
        data_format: str,
        extract_group: str = None,
    ):
        """
        Args:
            energy_column: Column header to rename to "energy", or None
                to auto-detect in `rename_cols`.
            mu_column: Column header to rename to "mu", or None to
                auto-detect in `rename_cols`.
            xftf_params: Fourier transform settings, forwarded to
                `read_group` when loading Athena projects.
            data_format: "athena" to read Athena projects; any other
                value tries ascii first, falling back to HDF5.
            extract_group: Name of the group to extract from an Athena
                project, forwarded to `read_group` (may be None).
        """
        self.energy_column = energy_column
        self.mu_column = mu_column
        self.xftf_params = xftf_params
        self.data_format = data_format
        self.extract_group = extract_group

    def load_data(
        self,
        dat_file: str,
        merge_inputs: bool,
        is_zipped: bool,
    ) -> "dict[str, Group]":
        """Load one or more files, keyed by output name.

        Args:
            dat_file: Path to the input, or comma-separated paths when
                merging unzipped inputs.
            merge_inputs: If True, merge all inputs into a single group
                under the key "out".
            is_zipped: If True, inputs are read from the extracted
                "dat_files" directory tree.

        Returns:
            Mapping from output key to loaded Group.
        """
        if merge_inputs:
            out_group = self.merge_files(
                dat_files=dat_file, is_zipped=is_zipped
            )
            return {"out": out_group}
        else:
            return self.load_single_file(
                filepath=dat_file, is_zipped=is_zipped
            )

    def merge_files(
        self,
        dat_files: str,
        is_zipped: bool,
    ) -> "Group":
        """Load all inputs and merge them on their energy/mu arrays.

        Args:
            dat_files: Comma-separated file paths (ignored when zipped).
            is_zipped: If True, load everything under "dat_files" instead.

        Returns:
            A single merged Group.
        """
        if is_zipped:
            all_groups = list(self.load_zipped_files().values())
        else:
            all_groups = []
            for filepath in dat_files.split(","):
                group = self.load_single_file(filepath)["out"]
                all_groups.append(group)

        return merge_groups(all_groups, xarray="energy", yarray="mu")

    def load_single_file(
        self,
        filepath: str,
        is_zipped: bool = False,
    ) -> "dict[str,Group]":
        """Load a single file (or delegate to the zip loader).

        Athena projects are read directly; otherwise ascii is attempted
        first and HDF5 is used as the fallback.
        """
        if is_zipped:
            return self.load_zipped_files()

        print(f"Attempting to read from {filepath}")
        if self.data_format == "athena":
            group = read_group(filepath, self.extract_group, self.xftf_params)
        else:
            # Try ascii anyway
            try:
                group = self.load_ascii(filepath)
                if not group.array_labels:
                    # In later versions of larch, won't get a type error it
                    # will just fail to load any data
                    group = self.load_h5(filepath)
            except (UnicodeDecodeError, TypeError):
                # Indicates this isn't plaintext, try h5
                group = self.load_h5(filepath)
        return {"out": group}

    def load_ascii(self, dat_file):
        """Read a plaintext column file, recovering column labels from
        the last commented ("#") header line when it is tab-separated.
        """
        with open(dat_file) as f:
            labels = None
            last_line = None
            line = f.readline()
            while line:
                if not line.startswith("#"):
                    # First data line reached: the previous line (if any,
                    # and if tab-separated) holds the column labels.
                    # NOTE(review): a leading "#" on the first label is
                    # only stripped of whitespace here — confirm that
                    # read_ascii tolerates this.
                    if last_line is not None and last_line.find("\t") > 0:
                        labels = []
                        for label in last_line.split("\t"):
                            labels.append(label.strip())
                    break

                last_line = line
                line = f.readline()

        xas_data = read_ascii(filename=dat_file, labels=labels)
        xas_data = self.rename_cols(xas_data)
        return xas_data

    def load_h5(self, dat_file):
        """Read energy and mu arrays from an HDF5 file laid out with the
        qexafs entry structure used here.
        """
        # NOTE(review): the handle is never explicitly closed — relies on
        # garbage collection; confirm whether h5group exposes a close().
        h5_group = h5group(fname=dat_file, mode="r")
        energy = h5_group.entry1.instrument.qexafs_energy.qexafs_energy
        mu = h5_group.entry1.instrument.qexafs_counterTimer01.lnI0It
        xafs_group = Group(data=np.array([energy[:], mu[:]]))
        set_array_labels(xafs_group, ["energy", "mu"])
        return xafs_group

    def load_zipped_files(self) -> "dict[str, Group]":
        """Load every file under the "dat_files" directory tree, keyed by
        a zero-padded sequence number in load order.
        """
        def sorting_key(filename: str) -> int:
            # Cast to int so that e.g. "10" sorts after "2"; the previous
            # string key silently sorted digit runs lexicographically.
            return int(re.findall(r"\d+", filename)[-1])

        all_paths = list(os.walk("dat_files"))
        all_paths.sort(key=lambda x: x[0])
        file_total = sum([len(f) for _, _, f in all_paths])
        print(f"{file_total} files found")
        key_length = len(str(file_total))
        i = 0
        keyed_data = {}
        for dirpath, _, filenames in all_paths:
            try:
                filenames.sort(key=sorting_key)
            except IndexError as e:
                # Raised by sorting_key when a filename contains no digits
                print(
                    "WARNING: Unable to sort files numerically, "
                    f"defaulting to sorting alphabetically:\n{e}"
                )
                filenames.sort()

            for filename in filenames:
                key = str(i).zfill(key_length)
                filepath = os.path.join(dirpath, filename)
                xas_data = self.load_single_file(filepath)
                keyed_data[key] = xas_data["out"]
                i += 1

        return keyed_data

    def rename_cols(self, xafs_group: "Group") -> "Group":
        """Rename the configured (or auto-detected) columns to "energy"
        and "mu" so downstream larch functions can find them.

        Raises:
            ValueError: If a configured column name is not present.
        """
        # Comparison below is against the lowercased labels, so any
        # upper-case original labels also trigger a rename.
        labels = [label.lower() for label in xafs_group.array_labels]
        print(f"Read columns: {labels}")

        if "energy" in labels:
            print("'energy' present in column headers")
        elif self.energy_column is not None:
            if self.energy_column.lower() in labels:
                labels[labels.index(self.energy_column.lower())] = "energy"
            else:
                raise ValueError(f"{self.energy_column} not found in {labels}")
        else:
            for i, label in enumerate(labels):
                if label == "col1" or label.endswith("energy"):
                    labels[i] = "energy"
                    break

        if "mu" in labels:
            print("'mu' present in column headers")
        elif self.mu_column is not None:
            if self.mu_column.lower() in labels:
                labels[labels.index(self.mu_column.lower())] = "mu"
            else:
                raise ValueError(f"{self.mu_column} not found in {labels}")
        else:
            for i, label in enumerate(labels):
                if label in ["col2", "xmu", "lni0it", "ffi0"]:
                    labels[i] = "mu"
                    break

        if labels != xafs_group.array_labels:
            print(f"Renaming columns to: {labels}")
            return set_array_labels(xafs_group, labels)
        else:
            return xafs_group
188 | |
189 | |
def calibrate_energy(
    xafs_group: "Group",
    energy_0: float,
    energy_min: float,
    energy_max: float,
    energy_format: str,
):
    """Optionally recalibrate the edge energy and crop the energy range.

    Args:
        xafs_group: Group with `energy`, `mu` and `e0` set (pre-edge
            arrays such as `dmude`/`flat` are cropped too when present).
        energy_0: New edge energy, or None to leave the edge unchanged.
        energy_min: Lower crop bound, or None for no lower bound.
        energy_max: Upper crop bound, or None for no upper bound.
        energy_format: "relative" bounds are offsets from `e0`; anything
            else is treated as absolute.

    Returns:
        The (mutated) group.

    Raises:
        ValueError: If cropping removes every energy point.
    """
    if energy_0 is not None:
        print(f"Recalibrating energy edge from {xafs_group.e0} to {energy_0}")
        xafs_group.energy = xafs_group.energy + energy_0 - xafs_group.e0
        xafs_group.e0 = energy_0

    # `is None` (not truthiness) so an explicit bound of 0.0 — e.g. a
    # relative crop exactly at the edge — is honoured rather than skipped.
    if energy_min is None and energy_max is None:
        return xafs_group

    if energy_min is not None:
        if energy_format == "relative":
            energy_min += xafs_group.e0
        index_min = np.searchsorted(xafs_group.energy, energy_min)
    else:
        index_min = 0

    if energy_max is not None:
        if energy_format == "relative":
            energy_max += xafs_group.e0
        index_max = np.searchsorted(xafs_group.energy, energy_max)
    else:
        index_max = len(xafs_group.energy)

    print(
        f"Cropping energy range from {energy_min} to {energy_max}, "
        f"index {index_min} to {index_max}"
    )
    # Derived arrays only exist after pre_edge has run; skip if absent
    try:
        xafs_group.dmude = xafs_group.dmude[index_min:index_max]
        xafs_group.pre_edge = xafs_group.pre_edge[index_min:index_max]
        xafs_group.post_edge = xafs_group.post_edge[index_min:index_max]
        xafs_group.flat = xafs_group.flat[index_min:index_max]
    except AttributeError:
        pass

    xafs_group.energy = xafs_group.energy[index_min:index_max]
    xafs_group.mu = xafs_group.mu[index_min:index_max]

    # Sanity check
    if len(xafs_group.energy) == 0:
        raise ValueError("Energy cropping led to an empty array")

    return xafs_group
239 | |
240 | |
def main(
    xas_data: "Group",
    input_values: dict,
    path_key: str = "out",
):
    """Process one XAS group end-to-end and save the Athena project.

    Runs pre-edge fitting, optional energy calibration/cropping, optional
    rebinning, background removal (autobk) and the Fourier transform
    (xftf), then optionally writes plots and finally saves the group to
    "prj/{path_key}.prj".

    Args:
        xas_data: Group with `energy` and `mu` arrays set.
        input_values: Parsed tool parameters (see the `__main__` block).
        path_key: Basename used for the output .prj and .png files.

    Raises:
        ValueError: If autobk fails, typically because the edge is not
            within the (cropped) energy range.
    """
    energy_0 = input_values["variables"]["energy_0"]
    if energy_0 is None and hasattr(xas_data, "e0"):
        energy_0 = xas_data.e0

    energy_format = input_values["variables"]["energy_format"]
    pre1 = input_values["variables"]["pre1"]
    pre2 = input_values["variables"]["pre2"]
    pre1 = validate_pre(pre1, energy_0, energy_format)
    pre2 = validate_pre(pre2, energy_0, energy_format)

    pre_edge(
        energy=xas_data.energy,
        mu=xas_data.mu,
        group=xas_data,
        e0=energy_0,
        pre1=pre1,
        pre2=pre2,
    )

    energy_min = input_values["variables"]["energy_min"]
    energy_max = input_values["variables"]["energy_max"]
    xas_data = calibrate_energy(
        xas_data,
        energy_0,
        energy_min,
        energy_max,
        energy_format=energy_format,
    )

    if input_values["rebin"]:
        print(xas_data.energy, xas_data.mu)
        rebin_xafs(energy=xas_data.energy, mu=xas_data.mu, group=xas_data)
        xas_data = xas_data.rebinned
        pre_edge(energy=xas_data.energy, mu=xas_data.mu, group=xas_data)

    try:
        autobk(xas_data)
    except ValueError as e:
        raise ValueError(
            f"autobk failed with energy={xas_data.energy}, mu={xas_data.mu}.\n"
            "This may occur if the edge is not included in the above ranges."
        ) from e
    # Read xftf settings from input_values rather than relying on the
    # module-level `xftf_params` global set in the __main__ block (which
    # does not exist if this function is imported elsewhere).
    xftf(xas_data, **input_values["variables"]["xftf"])

    if input_values["plot_graph"]:
        plot_edge_fits(f"edge/{path_key}.png", xas_data)
        plot_flattened(f"flat/{path_key}.png", xas_data)
        plot_derivative(f"derivative/{path_key}.png", xas_data)

    xas_project = create_athena(f"prj/{path_key}.prj")
    xas_project.add_group(xas_data)
    if input_values["annotation"]:
        group = next(iter(xas_project.groups.values()))
        group.args["annotation"] = input_values["annotation"]
    xas_project.save()

    # Ensure that we do not run out of memory when running on large zips
    gc.collect()
304 | |
305 | |
def validate_pre(pre, energy_0, energy_format):
    """Return *pre* expressed relative to the edge energy.

    Absolute-format bounds are shifted by ``energy_0``; relative-format
    bounds (and None) pass through untouched.

    Raises:
        ValueError: If an absolute bound is given without an edge energy.
    """
    if pre is None or energy_format != "absolute":
        return pre

    if energy_0 is None:
        raise ValueError(
            "Edge energy must be set manually or be present in the "
            "existing Athena project if using absolute format."
        )

    return pre - energy_0
316 | |
317 | |
def plot_derivative(plot_path: str, xafs_group: "Group"):
    """Save a plot of the derivative (`dmude`) against energy as a PNG."""
    plt.figure()
    plt.plot(xafs_group.energy, xafs_group.dmude)
    plt.grid(color="r", linestyle=":", linewidth=1)
    plt.xlabel("Energy (eV)")
    # Raw string: "\m" is not a valid escape; this keeps the same bytes
    # for the TeX markup without needing a noqa suppression.
    plt.ylabel(r"Derivative normalised to x$\mu$(E)")
    plt.savefig(plot_path, format="png")
    plt.close("all")
326 | |
327 | |
def plot_edge_fits(plot_path: str, xafs_group: "Group"):
    """Save a plot of the pre/post-edge fits over the data as a PNG."""
    plt.figure()
    plt.plot(xafs_group.energy, xafs_group.pre_edge, "g", label="pre-edge")
    plt.plot(xafs_group.energy, xafs_group.post_edge, "r", label="post-edge")
    plt.plot(xafs_group.energy, xafs_group.mu, "b", label="fit data")
    plt.grid(color="r", linestyle=":", linewidth=1)
    plt.xlabel("Energy (eV)")
    # Raw strings: "\m" is not a valid escape; same bytes, no noqa needed
    plt.ylabel(r"x$\mu$(E)")
    plt.title(r"pre-edge and post_edge fitting to $\mu$")
    plt.legend()
    plt.savefig(plot_path, format="png")
    plt.close("all")
340 | |
341 | |
def plot_flattened(plot_path: str, xafs_group: "Group"):
    """Save a plot of the flattened, normalised data as a PNG."""
    plt.figure()
    plt.plot(xafs_group.energy, xafs_group.flat)
    plt.grid(color="r", linestyle=":", linewidth=1)
    plt.xlabel("Energy (eV)")
    # Raw string: "\m" is not a valid escape; same bytes, no noqa needed
    plt.ylabel(r"normalised x$\mu$(E)")
    plt.savefig(plot_path, format="png")
    plt.close("all")
350 | |
351 | |
if __name__ == "__main__":
    # larch imports set this to an interactive backend, so need to change it
    matplotlib.use("Agg")

    dat_file = sys.argv[1]
    # Use a context manager so the config file handle is closed promptly
    with open(sys.argv[2], "r", encoding="utf-8") as f:
        input_values = json.load(f)

    merge_section = input_values["merge_inputs"]
    format_section = merge_section["format"]

    merge_inputs = merge_section["merge_inputs"]
    data_format = format_section["format"]
    if "is_zipped" in format_section:
        is_zipped = bool(format_section["is_zipped"]["is_zipped"])
    else:
        is_zipped = False
    xftf_params = input_values["variables"]["xftf"]

    # Optional format settings: .get() leaves them as None when absent
    extract_group = format_section.get("extract_group")
    energy_column = format_section.get("energy_column")
    mu_column = format_section.get("mu_column")

    reader = Reader(
        energy_column=energy_column,
        mu_column=mu_column,
        xftf_params=xftf_params,
        data_format=data_format,
        extract_group=extract_group,
    )
    keyed_data = reader.load_data(
        dat_file=dat_file,
        merge_inputs=merge_inputs,
        is_zipped=is_zipped,
    )
    for key, group in keyed_data.items():
        main(
            group,
            input_values=input_values,
            path_key=key,
        )