diff larch_athena.py @ 1:2b3115342fef draft

planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_athena commit 1cf6d7160497ba58fe16a51f00d088a20934eba6
author muon-spectroscopy-computational-project
date Wed, 06 Dec 2023 13:03:55 +0000
parents ae2f265ecf8e
children a1e26990131c
line wrap: on
line diff
--- a/larch_athena.py	Tue Nov 14 15:34:40 2023 +0000
+++ b/larch_athena.py	Wed Dec 06 13:03:55 2023 +0000
@@ -4,7 +4,9 @@
 import re
 import sys
 
-from common import read_group
+from common import (
+    pre_edge_with_defaults, read_all_groups, read_group, xftf_with_defaults
+)
 
 from larch.io import (
     create_athena,
@@ -14,7 +16,7 @@
     set_array_labels,
 )
 from larch.symboltable import Group
-from larch.xafs import autobk, pre_edge, rebin_xafs, xftf
+from larch.xafs import rebin_xafs
 
 import matplotlib
 import matplotlib.pyplot as plt
@@ -27,13 +29,11 @@
         self,
         energy_column: str,
         mu_column: str,
-        xftf_params: dict,
         data_format: str,
-        extract_group: str = None,
+        extract_group: "dict[str, str]" = None,
     ):
         self.energy_column = energy_column
         self.mu_column = mu_column
-        self.xftf_params = xftf_params
         self.data_format = data_format
         self.extract_group = extract_group
 
@@ -72,13 +72,24 @@
         self,
         filepath: str,
         is_zipped: bool = False,
-    ) -> "dict[str,Group]":
+    ) -> "tuple[dict, bool]":
         if is_zipped:
             return self.load_zipped_files()
 
         print(f"Attempting to read from {filepath}")
         if self.data_format == "athena":
-            group = read_group(filepath, self.extract_group, self.xftf_params)
+            if self.extract_group["extract_group"] == "single":
+                group = read_group(filepath, self.extract_group["group_name"])
+                return {"out": group}
+            elif self.extract_group["extract_group"] == "multiple":
+                groups = {}
+                for repeat in self.extract_group["multiple"]:
+                    name = repeat["group_name"]
+                    groups[name] = read_group(filepath, name)
+                return groups
+            else:
+                return read_all_groups(filepath)
+
         else:
             # Try ascii anyway
             try:
@@ -90,7 +101,9 @@
             except (UnicodeDecodeError, TypeError):
                 # Indicates this isn't plaintext, try h5
                 group = self.load_h5(filepath)
-        return {"out": group}
+            pre_edge_with_defaults(group)
+            xftf_with_defaults(group)
+            return {"out": group}
 
     def load_ascii(self, dat_file):
         with open(dat_file) as f:
@@ -156,27 +169,27 @@
 
         if "energy" in labels:
             print("'energy' present in column headers")
-        elif self.energy_column is not None:
+        elif self.energy_column:
             if self.energy_column.lower() in labels:
                 labels[labels.index(self.energy_column.lower())] = "energy"
             else:
                 raise ValueError(f"{self.energy_column} not found in {labels}")
         else:
             for i, label in enumerate(labels):
-                if label == "col1" or label.endswith("energy"):
+                if label in ("col1", "ef") or label.endswith("energy"):
                     labels[i] = "energy"
                     break
 
         if "mu" in labels:
             print("'mu' present in column headers")
-        elif self.mu_column is not None:
+        elif self.mu_column:
             if self.mu_column.lower() in labels:
                 labels[labels.index(self.mu_column.lower())] = "mu"
             else:
                 raise ValueError(f"{self.mu_column} not found in {labels}")
         else:
             for i, label in enumerate(labels):
-                if label in ["col2", "xmu", "lni0it", "ffi0"]:
+                if label in ["col2", "xmu", "lni0it", "ffi0", "ff/i1"]:
                     labels[i] = "mu"
                     break
 
@@ -189,29 +202,24 @@
 
 def calibrate_energy(
     xafs_group: Group,
-    energy_0: float,
-    energy_min: float,
-    energy_max: float,
-    energy_format: str,
+    calibration_e0: float = None,
+    energy_min: float = None,
+    energy_max: float = None,
 ):
-    if energy_0 is not None:
-        print(f"Recalibrating energy edge from {xafs_group.e0} to {energy_0}")
-        xafs_group.energy = xafs_group.energy + energy_0 - xafs_group.e0
-        xafs_group.e0 = energy_0
+    if calibration_e0 is not None:
+        print(f"Recalibrating edge from {xafs_group.e0} to {calibration_e0}")
+        xafs_group.energy = xafs_group.energy + calibration_e0 - xafs_group.e0
+        xafs_group.e0 = calibration_e0
 
     if not (energy_min or energy_max):
         return xafs_group
 
-    if energy_min:
-        if energy_format == "relative":
-            energy_min += xafs_group.e0
+    if energy_min is not None:
         index_min = np.searchsorted(xafs_group.energy, energy_min)
     else:
         index_min = 0
 
-    if energy_max:
-        if energy_format == "relative":
-            energy_max += xafs_group.e0
+    if energy_max is not None:
         index_max = np.searchsorted(xafs_group.energy, energy_max)
     else:
         index_max = len(xafs_group.energy)
@@ -240,81 +248,57 @@
 
 def main(
     xas_data: Group,
-    input_values: dict,
+    do_calibrate: bool,
+    calibrate_settings: dict,
+    do_rebin: bool,
+    do_pre_edge: bool,
+    pre_edge_settings: dict,
+    do_xftf: bool,
+    xftf_settings: dict,
+    plot_graph: bool,
+    annotation: str,
     path_key: str = "out",
 ):
-    energy_0 = input_values["variables"]["energy_0"]
-    if energy_0 is None and hasattr(xas_data, "e0"):
-        energy_0 = xas_data.e0
-
-    energy_format = input_values["variables"]["energy_format"]
-    pre1 = input_values["variables"]["pre1"]
-    pre2 = input_values["variables"]["pre2"]
-    pre1 = validate_pre(pre1, energy_0, energy_format)
-    pre2 = validate_pre(pre2, energy_0, energy_format)
-
-    pre_edge(
-        energy=xas_data.energy,
-        mu=xas_data.mu,
-        group=xas_data,
-        e0=energy_0,
-        pre1=pre1,
-        pre2=pre2,
-    )
+    if do_calibrate:
+        print(f"Calibrating energy with {calibrate_settings}")
+        xas_data = calibrate_energy(xas_data, **calibrate_settings)
+        # After re-calibrating, will need to redo pre-edge with new range
+        do_pre_edge = True
 
-    energy_min = input_values["variables"]["energy_min"]
-    energy_max = input_values["variables"]["energy_max"]
-    xas_data = calibrate_energy(
-        xas_data,
-        energy_0,
-        energy_min,
-        energy_max,
-        energy_format=energy_format,
-    )
+    if do_rebin:
+        print("Re-binning data")
+        rebin_xafs(
+            energy=xas_data.energy,
+            mu=xas_data.mu,
+            group=xas_data,
+            **pre_edge_settings,
+        )
+        xas_data = xas_data.rebinned
+        # After re-bin, will need to redo pre-edge
+        do_pre_edge = True
 
-    if input_values["rebin"]:
-        print(xas_data.energy, xas_data.mu)
-        rebin_xafs(energy=xas_data.energy, mu=xas_data.mu, group=xas_data)
-        xas_data = xas_data.rebinned
-        pre_edge(energy=xas_data.energy, mu=xas_data.mu, group=xas_data)
+    if do_pre_edge:
+        pre_edge_with_defaults(xas_data, pre_edge_settings)
 
-    try:
-        autobk(xas_data)
-    except ValueError as e:
-        raise ValueError(
-            f"autobk failed with energy={xas_data.energy}, mu={xas_data.mu}.\n"
-            "This may occur if the edge is not included in the above ranges."
-        ) from e
-    xftf(xas_data, **xftf_params)
+    if do_xftf:
+        xftf_with_defaults(xas_data, xftf_settings)
 
-    if input_values["plot_graph"]:
+    if plot_graph:
         plot_edge_fits(f"edge/{path_key}.png", xas_data)
         plot_flattened(f"flat/{path_key}.png", xas_data)
         plot_derivative(f"derivative/{path_key}.png", xas_data)
 
     xas_project = create_athena(f"prj/{path_key}.prj")
     xas_project.add_group(xas_data)
-    if input_values["annotation"]:
+    if annotation:
         group = next(iter(xas_project.groups.values()))
-        group.args["annotation"] = input_values["annotation"]
+        group.args["annotation"] = annotation
     xas_project.save()
 
     # Ensure that we do not run out of memory when running on large zips
     gc.collect()
 
 
-def validate_pre(pre, energy_0, energy_format):
-    if pre is not None and energy_format == "absolute":
-        if energy_0 is None:
-            raise ValueError(
-                "Edge energy must be set manually or be present in the "
-                "existing Athena project if using absolute format."
-            )
-        pre -= energy_0
-
-    return pre
-
-
 def plot_derivative(plot_path: str, xafs_group: Group):
     plt.figure()
     plt.plot(xafs_group.energy, xafs_group.dmude)
@@ -363,9 +347,8 @@
         )
     else:
         is_zipped = False
-    xftf_params = input_values["variables"]["xftf"]
+
     extract_group = None
-
     if "extract_group" in input_values["merge_inputs"]["format"]:
         extract_group = input_values["merge_inputs"]["format"]["extract_group"]
 
@@ -379,7 +362,6 @@
     reader = Reader(
         energy_column=energy_column,
         mu_column=mu_column,
-        xftf_params=xftf_params,
         data_format=data_format,
         extract_group=extract_group,
     )
@@ -388,9 +370,35 @@
         merge_inputs=merge_inputs,
         is_zipped=is_zipped,
     )
+
+    calibrate_items = input_values["processing"]["calibrate"].items()
+    calibrate_settings = {k: v for k, v in calibrate_items if v is not None}
+    do_calibrate = calibrate_settings.pop("calibrate") == "true"
+
+    do_rebin = input_values["processing"].pop("rebin")
+
+    pre_edge_items = input_values["processing"]["pre_edge"].items()
+    pre_edge_settings = {k: v for k, v in pre_edge_items if v is not None}
+    do_pre_edge = pre_edge_settings.pop("pre_edge") == "true"
+
+    xftf_items = input_values["processing"]["xftf"].items()
+    xftf_settings = {k: v for k, v in xftf_items if v is not None}
+    do_xftf = xftf_settings.pop("xftf") == "true"
+
+    plot_graph = input_values["plot_graph"]
+    annotation = input_values["annotation"]
+
     for key, group in keyed_data.items():
         main(
             group,
-            input_values=input_values,
+            do_calibrate=do_calibrate,
+            calibrate_settings=calibrate_settings,
+            do_rebin=do_rebin,
+            do_pre_edge=do_pre_edge,
+            pre_edge_settings=pre_edge_settings,
+            do_xftf=do_xftf,
+            xftf_settings=xftf_settings,
+            plot_graph=plot_graph,
+            annotation=annotation,
             path_key=key,
         )