Mercurial > repos > muon-spectroscopy-computational-project > larch_athena
diff larch_athena.py @ 1:2b3115342fef draft
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_athena commit 1cf6d7160497ba58fe16a51f00d088a20934eba6
author | muon-spectroscopy-computational-project |
---|---|
date | Wed, 06 Dec 2023 13:03:55 +0000 |
parents | ae2f265ecf8e |
children | a1e26990131c |
line wrap: on
line diff
--- a/larch_athena.py Tue Nov 14 15:34:40 2023 +0000 +++ b/larch_athena.py Wed Dec 06 13:03:55 2023 +0000 @@ -4,7 +4,9 @@ import re import sys -from common import read_group +from common import ( + pre_edge_with_defaults, read_all_groups, read_group, xftf_with_defaults +) from larch.io import ( create_athena, @@ -14,7 +16,7 @@ set_array_labels, ) from larch.symboltable import Group -from larch.xafs import autobk, pre_edge, rebin_xafs, xftf +from larch.xafs import rebin_xafs import matplotlib import matplotlib.pyplot as plt @@ -27,13 +29,11 @@ self, energy_column: str, mu_column: str, - xftf_params: dict, data_format: str, - extract_group: str = None, + extract_group: "dict[str, str]" = None, ): self.energy_column = energy_column self.mu_column = mu_column - self.xftf_params = xftf_params self.data_format = data_format self.extract_group = extract_group @@ -72,13 +72,24 @@ self, filepath: str, is_zipped: bool = False, - ) -> "dict[str,Group]": + ) -> "tuple[dict, bool]": if is_zipped: return self.load_zipped_files() print(f"Attempting to read from {filepath}") if self.data_format == "athena": - group = read_group(filepath, self.extract_group, self.xftf_params) + if self.extract_group["extract_group"] == "single": + group = read_group(filepath, self.extract_group["group_name"]) + return {"out": group} + elif self.extract_group["extract_group"] == "multiple": + groups = {} + for repeat in self.extract_group["multiple"]: + name = repeat["group_name"] + groups[name] = read_group(filepath, name) + return groups + else: + return read_all_groups(filepath) + else: # Try ascii anyway try: @@ -90,7 +101,9 @@ except (UnicodeDecodeError, TypeError): # Indicates this isn't plaintext, try h5 group = self.load_h5(filepath) - return {"out": group} + pre_edge_with_defaults(group) + xftf_with_defaults(group) + return {"out": group} def load_ascii(self, dat_file): with open(dat_file) as f: @@ -156,27 +169,27 @@ if "energy" in labels: print("'energy' present in column headers") - elif self.energy_column is not None: + elif self.energy_column: if self.energy_column.lower() in labels: labels[labels.index(self.energy_column.lower())] = "energy" else: raise ValueError(f"{self.energy_column} not found in {labels}") else: for i, label in enumerate(labels): - if label == "col1" or label.endswith("energy"): + if label in ("col1", "ef") or label.endswith("energy"): labels[i] = "energy" break if "mu" in labels: print("'mu' present in column headers") - elif self.mu_column is not None: + elif self.mu_column: if self.mu_column.lower() in labels: labels[labels.index(self.mu_column.lower())] = "mu" else: raise ValueError(f"{self.mu_column} not found in {labels}") else: for i, label in enumerate(labels): - if label in ["col2", "xmu", "lni0it", "ffi0"]: + if label in ["col2", "xmu", "lni0it", "ffi0", "ff/i1"]: labels[i] = "mu" break @@ -189,29 +202,24 @@ def calibrate_energy( xafs_group: Group, - energy_0: float, - energy_min: float, - energy_max: float, - energy_format: str, + calibration_e0: float = None, + energy_min: float = None, + energy_max: float = None, ): - if energy_0 is not None: - print(f"Recalibrating energy edge from {xafs_group.e0} to {energy_0}") - xafs_group.energy = xafs_group.energy + energy_0 - xafs_group.e0 - xafs_group.e0 = energy_0 + if calibration_e0 is not None: + print(f"Recalibrating edge from {xafs_group.e0} to {calibration_e0}") + xafs_group.energy = xafs_group.energy + calibration_e0 - xafs_group.e0 + xafs_group.e0 = calibration_e0 if not (energy_min or energy_max): return xafs_group - if energy_min: - if energy_format == "relative": - energy_min += xafs_group.e0 + if energy_min is not None: index_min = np.searchsorted(xafs_group.energy, energy_min) else: index_min = 0 - if energy_max: - if energy_format == "relative": - energy_max += xafs_group.e0 + if energy_max is not None: index_max = np.searchsorted(xafs_group.energy, energy_max) else: index_max = len(xafs_group.energy) @@ -240,81 +248,57 @@ def main( xas_data: Group, - input_values: dict, + do_calibrate: bool, + calibrate_settings: dict, + do_rebin: bool, + do_pre_edge: bool, + pre_edge_settings: dict, + do_xftf: bool, + xftf_settings: dict, + plot_graph: bool, + annotation: str, path_key: str = "out", ): - energy_0 = input_values["variables"]["energy_0"] - if energy_0 is None and hasattr(xas_data, "e0"): - energy_0 = xas_data.e0 - - energy_format = input_values["variables"]["energy_format"] - pre1 = input_values["variables"]["pre1"] - pre2 = input_values["variables"]["pre2"] - pre1 = validate_pre(pre1, energy_0, energy_format) - pre2 = validate_pre(pre2, energy_0, energy_format) - - pre_edge( - energy=xas_data.energy, - mu=xas_data.mu, - group=xas_data, - e0=energy_0, - pre1=pre1, - pre2=pre2, - ) + if do_calibrate: + print(f"Calibrating energy with {calibrate_settings}") + xas_data = calibrate_energy(xas_data, **calibrate_settings) + # After re-calibrating, will need to redo pre-edge with new range + do_pre_edge = True - energy_min = input_values["variables"]["energy_min"] - energy_max = input_values["variables"]["energy_max"] - xas_data = calibrate_energy( - xas_data, - energy_0, - energy_min, - energy_max, - energy_format=energy_format, - ) + if do_rebin: + print("Re-binning data") + rebin_xafs( + energy=xas_data.energy, + mu=xas_data.mu, + group=xas_data, + **pre_edge_settings, + ) + xas_data = xas_data.rebinned + # After re-bin, will need to redo pre-edge + do_pre_edge = True - if input_values["rebin"]: - print(xas_data.energy, xas_data.mu) - rebin_xafs(energy=xas_data.energy, mu=xas_data.mu, group=xas_data) - xas_data = xas_data.rebinned - pre_edge(energy=xas_data.energy, mu=xas_data.mu, group=xas_data) + if do_pre_edge: + pre_edge_with_defaults(xas_data, pre_edge_settings) - try: - autobk(xas_data) - except ValueError as e: - raise ValueError( - f"autobk failed with energy={xas_data.energy}, mu={xas_data.mu}.\n" - "This may occur if the edge is not included in the above ranges." - ) from e - xftf(xas_data, **xftf_params) + if do_xftf: + xftf_with_defaults(xas_data, xftf_settings) - if input_values["plot_graph"]: + if plot_graph: plot_edge_fits(f"edge/{path_key}.png", xas_data) plot_flattened(f"flat/{path_key}.png", xas_data) plot_derivative(f"derivative/{path_key}.png", xas_data) xas_project = create_athena(f"prj/{path_key}.prj") xas_project.add_group(xas_data) - if input_values["annotation"]: + if annotation: group = next(iter(xas_project.groups.values())) - group.args["annotation"] = input_values["annotation"] + group.args["annotation"] = annotation xas_project.save() # Ensure that we do not run out of memory when running on large zips gc.collect() -def validate_pre(pre, energy_0, energy_format): - if pre is not None and energy_format == "absolute": - if energy_0 is None: - raise ValueError( - "Edge energy must be set manually or be present in the " - "existing Athena project if using absolute format." - ) - pre -= energy_0 - - return pre - - def plot_derivative(plot_path: str, xafs_group: Group): plt.figure() plt.plot(xafs_group.energy, xafs_group.dmude) @@ -363,9 +347,8 @@ ) else: is_zipped = False - xftf_params = input_values["variables"]["xftf"] + extract_group = None - if "extract_group" in input_values["merge_inputs"]["format"]: extract_group = input_values["merge_inputs"]["format"]["extract_group"] @@ -379,7 +362,6 @@ reader = Reader( energy_column=energy_column, mu_column=mu_column, - xftf_params=xftf_params, data_format=data_format, extract_group=extract_group, ) @@ -388,9 +370,35 @@ merge_inputs=merge_inputs, is_zipped=is_zipped, ) + + calibrate_items = input_values["processing"]["calibrate"].items() + calibrate_settings = {k: v for k, v in calibrate_items if v is not None} + do_calibrate = calibrate_settings.pop("calibrate") == "true" + + do_rebin = input_values["processing"].pop("rebin") + + pre_edge_items = input_values["processing"]["pre_edge"].items() + pre_edge_settings = {k: v for k, v in pre_edge_items if v is not None} + do_pre_edge = pre_edge_settings.pop("pre_edge") == "true" + + xftf_items = input_values["processing"]["xftf"].items() + xftf_settings = {k: v for k, v in xftf_items if v is not None} + do_xftf = xftf_settings.pop("xftf") == "true" + + plot_graph = input_values["plot_graph"] + annotation = input_values["annotation"] + for key, group in keyed_data.items(): main( group, - input_values=input_values, + do_calibrate=do_calibrate, + calibrate_settings=calibrate_settings, + do_rebin=do_rebin, + do_pre_edge=do_pre_edge, + pre_edge_settings=pre_edge_settings, + do_xftf=do_xftf, + xftf_settings=xftf_settings, + plot_graph=plot_graph, + annotation=annotation, path_key=key, )