diff larch_athena.py @ 3:82e9dd980916 draft

planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_athena commit d4c7e090dc5c94395d7e1574845ac2c76f2e4f5f
author muon-spectroscopy-computational-project
date Fri, 22 Mar 2024 14:23:27 +0000
parents a1e26990131c
children a0d3b0fe0fa3
line wrap: on
line diff
--- a/larch_athena.py	Mon Mar 04 11:43:19 2024 +0000
+++ b/larch_athena.py	Fri Mar 22 14:23:27 2024 +0000
@@ -5,7 +5,10 @@
 import sys
 
 from common import (
-    pre_edge_with_defaults, read_all_groups, read_group, xftf_with_defaults
+    pre_edge_with_defaults,
+    read_all_groups,
+    read_group,
+    xftf_with_defaults,
 )
 
 from larch.io import (
@@ -45,12 +48,14 @@
     ) -> "dict[str, Group]":
         if merge_inputs:
             out_group = self.merge_files(
-                dat_files=dat_file, is_zipped=is_zipped
+                dat_files=dat_file,
+                is_zipped=is_zipped,
             )
             return {"out": out_group}
         else:
             return self.load_single_file(
-                filepath=dat_file, is_zipped=is_zipped
+                filepath=dat_file,
+                is_zipped=is_zipped,
             )
 
     def merge_files(
@@ -258,7 +263,7 @@
     pre_edge_settings: dict,
     do_xftf: bool,
     xftf_settings: dict,
-    plot_graph: bool,
+    plot_graph: list,
     annotation: str,
     path_key: str = "out",
 ):
@@ -287,9 +292,11 @@
         xftf_with_defaults(xas_data, xftf_settings)
 
     if plot_graph:
-        plot_edge_fits(f"edge/{path_key}.png", xas_data)
-        plot_flattened(f"flat/{path_key}.png", xas_data)
-        plot_derivative(f"derivative/{path_key}.png", xas_data)
+        plot_graphs(
+            plot_path=f"plot/{path_key}.png",
+            xas_data=xas_data,
+            plot_keys=plot_graph,
+        )
 
     xas_project = create_athena(f"prj/{path_key}.prj")
     xas_project.add_group(xas_data)
@@ -302,36 +309,43 @@
     gc.collect()
 
 
-def plot_derivative(plot_path: str, xafs_group: Group):
-    plt.figure()
-    plt.plot(xafs_group.energy, xafs_group.dmude)
-    plt.grid(color="r", linestyle=":", linewidth=1)
-    plt.xlabel("Energy (eV)")
-    plt.ylabel("Derivative normalised to x$\mu$(E)")  # noqa: W605
-    plt.savefig(plot_path, format="png")
-    plt.close("all")
-
+def plot_graphs(
+    plot_path: str,
+    xas_data: Group,
+    plot_keys: list,
+) -> None:
+    nrows = len(plot_keys)
+    index = 1
+    plt.figure(figsize=(6.4, nrows * 4.8))
+    if "edge" in plot_keys:
+        plt.subplot(nrows, 1, index)
+        plt.plot(xas_data.energy, xas_data.pre_edge, "g", label="pre-edge")
+        plt.plot(xas_data.energy, xas_data.post_edge, "r", label="post-edge")
+        plt.plot(xas_data.energy, xas_data.mu, "b", label="fit data")
+        plt.grid(color="r", linestyle=":", linewidth=1)
+        plt.xlabel("Energy (eV)")
+        plt.ylabel("x$\mu$(E)")  # noqa: W605
+        plt.title("Pre-edge and post_edge fitting to $\mu$")  # noqa: W605
+        plt.legend()
+        index += 1
 
-def plot_edge_fits(plot_path: str, xafs_group: Group):
-    plt.figure()
-    plt.plot(xafs_group.energy, xafs_group.pre_edge, "g", label="pre-edge")
-    plt.plot(xafs_group.energy, xafs_group.post_edge, "r", label="post-edge")
-    plt.plot(xafs_group.energy, xafs_group.mu, "b", label="fit data")
-    plt.grid(color="r", linestyle=":", linewidth=1)
-    plt.xlabel("Energy (eV)")
-    plt.ylabel("x$\mu$(E)")  # noqa: W605
-    plt.title("pre-edge and post_edge fitting to $\mu$")  # noqa: W605
-    plt.legend()
-    plt.savefig(plot_path, format="png")
-    plt.close("all")
+    if "flat" in plot_keys:
+        plt.subplot(nrows, 1, index)
+        plt.plot(xas_data.energy, xas_data.flat)
+        plt.grid(color="r", linestyle=":", linewidth=1)
+        plt.xlabel("Energy (eV)")
+        plt.ylabel("Flattened x$\mu$(E)")  # noqa: W605
+        index += 1
 
+    if "dmude" in plot_keys:
+        plt.subplot(nrows, 1, index)
+        plt.plot(xas_data.energy, xas_data.dmude)
+        plt.grid(color="r", linestyle=":", linewidth=1)
+        plt.xlabel("Energy (eV)")
+        plt.ylabel("Derivative normalised to x$\mu$(E)")  # noqa: W605
+        index += 1
 
-def plot_flattened(plot_path: str, xafs_group: Group):
-    plt.figure()
-    plt.plot(xafs_group.energy, xafs_group.flat)
-    plt.grid(color="r", linestyle=":", linewidth=1)
-    plt.xlabel("Energy (eV)")
-    plt.ylabel("normalised x$\mu$(E)")  # noqa: W605
+    plt.tight_layout(rect=(0, 0, 0.88, 1))
     plt.savefig(plot_path, format="png")
     plt.close("all")