diff larch_athena.py @ 0:ae2f265ecf8e draft
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_athena commit 5be486890442dedfb327289d597e1c8110240735
author:   muon-spectroscopy-computational-project
date:     Tue, 14 Nov 2023 15:34:40 +0000
parents:  (none)
children: 2b3115342fef
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/larch_athena.py	Tue Nov 14 15:34:40 2023 +0000
@@ -0,0 +1,396 @@
+import gc
+import json
+import os
+import re
+import sys
+
+from common import read_group
+
+from larch.io import (
+    create_athena,
+    h5group,
+    merge_groups,
+    read_ascii,
+    set_array_labels,
+)
+from larch.symboltable import Group
+from larch.xafs import autobk, pre_edge, rebin_xafs, xftf
+
+import matplotlib
+import matplotlib.pyplot as plt
+
+import numpy as np
+
+
+class Reader:
+    def __init__(
+        self,
+        energy_column: str,
+        mu_column: str,
+        xftf_params: dict,
+        data_format: str,
+        extract_group: str = None,
+    ):
+        self.energy_column = energy_column
+        self.mu_column = mu_column
+        self.xftf_params = xftf_params
+        self.data_format = data_format
+        self.extract_group = extract_group
+
+    def load_data(
+        self,
+        dat_file: str,
+        merge_inputs: bool,
+        is_zipped: bool,
+    ) -> "dict[str, Group]":
+        if merge_inputs:
+            out_group = self.merge_files(
+                dat_files=dat_file, is_zipped=is_zipped
+            )
+            return {"out": out_group}
+        else:
+            return self.load_single_file(
+                filepath=dat_file, is_zipped=is_zipped
+            )
+
+    def merge_files(
+        self,
+        dat_files: str,
+        is_zipped: bool,
+    ) -> Group:
+        if is_zipped:
+            all_groups = list(self.load_zipped_files().values())
+        else:
+            all_groups = []
+            for filepath in dat_files.split(","):
+                group = self.load_single_file(filepath)["out"]
+                all_groups.append(group)
+
+        return merge_groups(all_groups, xarray="energy", yarray="mu")
+
+    def load_single_file(
+        self,
+        filepath: str,
+        is_zipped: bool = False,
+    ) -> "dict[str,Group]":
+        if is_zipped:
+            return self.load_zipped_files()
+
+        print(f"Attempting to read from {filepath}")
+        if self.data_format == "athena":
+            group = read_group(filepath, self.extract_group, self.xftf_params)
+        else:
+            # Try ascii anyway
+            try:
+                group = self.load_ascii(filepath)
+                if not group.array_labels:
+                    # In later versions of larch, won't get a type error it
+                    # will just fail to load any data
+                    group = self.load_h5(filepath)
+            except (UnicodeDecodeError, TypeError):
+                # Indicates this isn't plaintext, try h5
+                group = self.load_h5(filepath)
+        return {"out": group}
+
+    def load_ascii(self, dat_file):
+        with open(dat_file) as f:
+            labels = None
+            last_line = None
+            line = f.readline()
+            while line:
+                if not line.startswith("#"):
+                    if last_line is not None and last_line.find("\t") > 0:
+                        labels = []
+                        for label in last_line.split("\t"):
+                            labels.append(label.strip())
+                    break
+
+                last_line = line
+                line = f.readline()
+
+        xas_data = read_ascii(filename=dat_file, labels=labels)
+        xas_data = self.rename_cols(xas_data)
+        return xas_data
+
+    def load_h5(self, dat_file):
+        h5_group = h5group(fname=dat_file, mode="r")
+        energy = h5_group.entry1.instrument.qexafs_energy.qexafs_energy
+        mu = h5_group.entry1.instrument.qexafs_counterTimer01.lnI0It
+        xafs_group = Group(data=np.array([energy[:], mu[:]]))
+        set_array_labels(xafs_group, ["energy", "mu"])
+        return xafs_group
+
+    def load_zipped_files(self) -> "dict[str, Group]":
+        def sorting_key(filename: str) -> str:
+            return re.findall(r"\d+", filename)[-1]
+
+        all_paths = list(os.walk("dat_files"))
+        all_paths.sort(key=lambda x: x[0])
+        file_total = sum([len(f) for _, _, f in all_paths])
+        print(f"{file_total} files found")
+        key_length = len(str(file_total))
+        i = 0
+        keyed_data = {}
+        for dirpath, _, filenames in all_paths:
+            try:
+                filenames.sort(key=sorting_key)
+            except IndexError as e:
+                print(
+                    "WARNING: Unable to sort files numerically, "
+                    f"defaulting to sorting alphabetically:\n{e}"
+                )
+                filenames.sort()
+
+            for filename in filenames:
+                key = str(i).zfill(key_length)
+                filepath = os.path.join(dirpath, filename)
+                xas_data = self.load_single_file(filepath)
+                keyed_data[key] = xas_data["out"]
+                i += 1
+
+        return keyed_data
+
+    def rename_cols(self, xafs_group: Group) -> Group:
+        labels = [label.lower() for label in xafs_group.array_labels]
+        print(f"Read columns: {labels}")
+
+        if "energy" in labels:
+            print("'energy' present in column headers")
+        elif self.energy_column is not None:
+            if self.energy_column.lower() in labels:
+                labels[labels.index(self.energy_column.lower())] = "energy"
+            else:
+                raise ValueError(f"{self.energy_column} not found in {labels}")
+        else:
+            for i, label in enumerate(labels):
+                if label == "col1" or label.endswith("energy"):
+                    labels[i] = "energy"
+                    break
+
+        if "mu" in labels:
+            print("'mu' present in column headers")
+        elif self.mu_column is not None:
+            if self.mu_column.lower() in labels:
+                labels[labels.index(self.mu_column.lower())] = "mu"
+            else:
+                raise ValueError(f"{self.mu_column} not found in {labels}")
+        else:
+            for i, label in enumerate(labels):
+                if label in ["col2", "xmu", "lni0it", "ffi0"]:
+                    labels[i] = "mu"
+                    break
+
+        if labels != xafs_group.array_labels:
+            print(f"Renaming columns to: {labels}")
+            return set_array_labels(xafs_group, labels)
+        else:
+            return xafs_group
+
+
+def calibrate_energy(
+    xafs_group: Group,
+    energy_0: float,
+    energy_min: float,
+    energy_max: float,
+    energy_format: str,
+):
+    if energy_0 is not None:
+        print(f"Recalibrating energy edge from {xafs_group.e0} to {energy_0}")
+        xafs_group.energy = xafs_group.energy + energy_0 - xafs_group.e0
+        xafs_group.e0 = energy_0
+
+    if not (energy_min or energy_max):
+        return xafs_group
+
+    if energy_min:
+        if energy_format == "relative":
+            energy_min += xafs_group.e0
+        index_min = np.searchsorted(xafs_group.energy, energy_min)
+    else:
+        index_min = 0
+
+    if energy_max:
+        if energy_format == "relative":
+            energy_max += xafs_group.e0
+        index_max = np.searchsorted(xafs_group.energy, energy_max)
+    else:
+        index_max = len(xafs_group.energy)
+
+    print(
+        f"Cropping energy range from {energy_min} to {energy_max}, "
+        f"index {index_min} to {index_max}"
+    )
+    try:
+        xafs_group.dmude = xafs_group.dmude[index_min:index_max]
+        xafs_group.pre_edge = xafs_group.pre_edge[index_min:index_max]
+        xafs_group.post_edge = xafs_group.post_edge[index_min:index_max]
+        xafs_group.flat = xafs_group.flat[index_min:index_max]
+    except AttributeError:
+        pass
+
+    xafs_group.energy = xafs_group.energy[index_min:index_max]
+    xafs_group.mu = xafs_group.mu[index_min:index_max]
+
+    # Sanity check
+    if len(xafs_group.energy) == 0:
+        raise ValueError("Energy cropping led to an empty array")
+
+    return xafs_group
+
+
+def main(
+    xas_data: Group,
+    input_values: dict,
+    path_key: str = "out",
+):
+    energy_0 = input_values["variables"]["energy_0"]
+    if energy_0 is None and hasattr(xas_data, "e0"):
+        energy_0 = xas_data.e0
+
+    energy_format = input_values["variables"]["energy_format"]
+    pre1 = input_values["variables"]["pre1"]
+    pre2 = input_values["variables"]["pre2"]
+    pre1 = validate_pre(pre1, energy_0, energy_format)
+    pre2 = validate_pre(pre2, energy_0, energy_format)
+
+    pre_edge(
+        energy=xas_data.energy,
+        mu=xas_data.mu,
+        group=xas_data,
+        e0=energy_0,
+        pre1=pre1,
+        pre2=pre2,
+    )
+
+    energy_min = input_values["variables"]["energy_min"]
+    energy_max = input_values["variables"]["energy_max"]
+    xas_data = calibrate_energy(
+        xas_data,
+        energy_0,
+        energy_min,
+        energy_max,
+        energy_format=energy_format,
+    )
+
+    if input_values["rebin"]:
+        print(xas_data.energy, xas_data.mu)
+        rebin_xafs(energy=xas_data.energy, mu=xas_data.mu, group=xas_data)
+        xas_data = xas_data.rebinned
+        pre_edge(energy=xas_data.energy, mu=xas_data.mu, group=xas_data)
+
+    try:
+        autobk(xas_data)
+    except ValueError as e:
+        raise ValueError(
+            f"autobk failed with energy={xas_data.energy}, mu={xas_data.mu}.\n"
+            "This may occur if the edge is not included in the above ranges."
+        ) from e
+    xftf(xas_data, **xftf_params)
+
+    if input_values["plot_graph"]:
+        plot_edge_fits(f"edge/{path_key}.png", xas_data)
+        plot_flattened(f"flat/{path_key}.png", xas_data)
+        plot_derivative(f"derivative/{path_key}.png", xas_data)
+
+    xas_project = create_athena(f"prj/{path_key}.prj")
+    xas_project.add_group(xas_data)
+    if input_values["annotation"]:
+        group = next(iter(xas_project.groups.values()))
+        group.args["annotation"] = input_values["annotation"]
+    xas_project.save()
+
+    # Ensure that we do not run out of memory when running on large zips
+    gc.collect()
+
+
+def validate_pre(pre, energy_0, energy_format):
+    if pre is not None and energy_format == "absolute":
+        if energy_0 is None:
+            raise ValueError(
+                "Edge energy must be set manually or be present in the "
+                "existing Athena project if using absolute format."
+            )
+        pre -= energy_0
+
+    return pre
+
+
+def plot_derivative(plot_path: str, xafs_group: Group):
+    plt.figure()
+    plt.plot(xafs_group.energy, xafs_group.dmude)
+    plt.grid(color="r", linestyle=":", linewidth=1)
+    plt.xlabel("Energy (eV)")
+    plt.ylabel("Derivative normalised to x$\mu$(E)")  # noqa: W605
+    plt.savefig(plot_path, format="png")
+    plt.close("all")
+
+
+def plot_edge_fits(plot_path: str, xafs_group: Group):
+    plt.figure()
+    plt.plot(xafs_group.energy, xafs_group.pre_edge, "g", label="pre-edge")
+    plt.plot(xafs_group.energy, xafs_group.post_edge, "r", label="post-edge")
+    plt.plot(xafs_group.energy, xafs_group.mu, "b", label="fit data")
+    plt.grid(color="r", linestyle=":", linewidth=1)
+    plt.xlabel("Energy (eV)")
+    plt.ylabel("x$\mu$(E)")  # noqa: W605
+    plt.title("pre-edge and post_edge fitting to $\mu$")  # noqa: W605
+    plt.legend()
+    plt.savefig(plot_path, format="png")
+    plt.close("all")
+
+
+def plot_flattened(plot_path: str, xafs_group: Group):
+    plt.figure()
+    plt.plot(xafs_group.energy, xafs_group.flat)
+    plt.grid(color="r", linestyle=":", linewidth=1)
+    plt.xlabel("Energy (eV)")
+    plt.ylabel("normalised x$\mu$(E)")  # noqa: W605
+    plt.savefig(plot_path, format="png")
+    plt.close("all")
+
+
+if __name__ == "__main__":
+    # larch imports set this to an interactive backend, so need to change it
+    matplotlib.use("Agg")
+
+    dat_file = sys.argv[1]
+    input_values = json.load(open(sys.argv[2], "r", encoding="utf-8"))
+    merge_inputs = input_values["merge_inputs"]["merge_inputs"]
+    data_format = input_values["merge_inputs"]["format"]["format"]
+    if "is_zipped" in input_values["merge_inputs"]["format"]:
+        is_zipped = bool(
+            input_values["merge_inputs"]["format"]["is_zipped"]["is_zipped"]
+        )
+    else:
+        is_zipped = False
+    xftf_params = input_values["variables"]["xftf"]
+    extract_group = None
+
+    if "extract_group" in input_values["merge_inputs"]["format"]:
+        extract_group = input_values["merge_inputs"]["format"]["extract_group"]
+
+    energy_column = None
+    mu_column = None
+    if "energy_column" in input_values["merge_inputs"]["format"]:
+        energy_column = input_values["merge_inputs"]["format"]["energy_column"]
+    if "mu_column" in input_values["merge_inputs"]["format"]:
+        mu_column = input_values["merge_inputs"]["format"]["mu_column"]
+
+    reader = Reader(
+        energy_column=energy_column,
+        mu_column=mu_column,
+        xftf_params=xftf_params,
+        data_format=data_format,
+        extract_group=extract_group,
+    )
+    keyed_data = reader.load_data(
+        dat_file=dat_file,
+        merge_inputs=merge_inputs,
+        is_zipped=is_zipped,
+    )
+    for key, group in keyed_data.items():
+        main(
+            group,
+            input_values=input_values,
+            path_key=key,
+        )