diff larch_plot.py @ 4:35d24102cefd draft

planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_plot commit 3fe6078868efd0fcea0fb5eea8dcd4b152d9c0a8
author muon-spectroscopy-computational-project
date Thu, 11 Apr 2024 09:02:24 +0000
parents 5b993aff09e3
children
line wrap: on
line diff
--- a/larch_plot.py	Fri Mar 22 14:23:33 2024 +0000
+++ b/larch_plot.py	Thu Apr 11 09:02:24 2024 +0000
@@ -3,6 +3,8 @@
 
 from common import read_groups
 
+from larch.symboltable import Group
+
 import matplotlib
 import matplotlib.pyplot as plt
 
@@ -10,67 +12,84 @@
 
 
 AXIS_LABELS = {
+    "energy": "Energy (eV)",
+    "distance": "r (ang)",
+    "sample": "Sample",
     "flat": r"x$\mu$(E), flattened",
     "dmude": r"d(x$\mu$(E))/dE, normalised",
     "chir_mag": r"|$\chi$(r)|",
-    "energy": "Energy (eV)",
-    "distance": "r (ang)",
+    "e0": "Edge Energy (eV)",
 }
 
 
+def sample_plot(groups: "list[Group]", y_variable: str):
+    x = [get_label(group) for group in groups]
+    y = [getattr(group, y_variable) for group in groups]
+    plt.scatter(x, y)
+
+
 def main(dat_files: "list[str]", plot_settings: "list[dict]"):
     groups = list(read_groups(dat_files))
 
     for i, settings in enumerate(plot_settings):
         data_list = []
-        x_variable = "energy"
         y_variable = settings["variable"]["variable"]
-        x_min = settings["variable"]["x_limit_min"]
-        x_max = settings["variable"]["x_limit_max"]
-        y_min = settings["variable"]["y_limit_min"]
-        y_max = settings["variable"]["y_limit_max"]
         plot_path = f"plots/{i}_{y_variable}.png"
         plt.figure()
 
-        for group in groups:
-            params = group.athena_params
-            annotation = getattr(params, "annotation", None)
-            file = getattr(params, "file", None)
-            params_id = getattr(params, "id", None)
-            label = annotation or file or params_id
-            if y_variable == "chir_mag":
-                x_variable = "distance"
-                x = group.r
-            else:
-                x = group.energy
-
-            y = getattr(group, y_variable)
-            if x_min is None and x_max is None:
-                plt.plot(x, y, label=label)
-            else:
-                data_list.append({"x": x, "y": y, "label": label})
+        if y_variable == "e0":
+            x_variable = "sample"
+            sample_plot(groups, y_variable)
+        else:
+            x_variable = "energy"
+            x_min = settings["variable"]["x_limit_min"]
+            x_max = settings["variable"]["x_limit_max"]
+            y_min = settings["variable"]["y_limit_min"]
+            y_max = settings["variable"]["y_limit_max"]
+            for group in groups:
+                label = get_label(group)
+                if y_variable == "chir_mag":
+                    x_variable = "distance"
+                    x = group.r
+                else:
+                    x = group.energy
 
-        if x_min is not None or x_max is not None:
-            for data in data_list:
-                index_min = None
-                index_max = None
-                x = data["x"]
-                if x_min is not None:
-                    index_min = max(np.searchsorted(x, x_min) - 1, 0)
-                if x_max is not None:
-                    index_max = min(np.searchsorted(x, x_max) + 1, len(x))
-                plt.plot(
-                    x[index_min:index_max],
-                    data["y"][index_min:index_max],
-                    label=data["label"],
-                )
+                y = getattr(group, y_variable)
+                if x_min is None and x_max is None:
+                    plt.plot(x, y, label=label)
+                else:
+                    data_list.append({"x": x, "y": y, "label": label})
 
-        plt.xlim(x_min, x_max)
-        plt.ylim(y_min, y_max)
+            if x_min is not None or x_max is not None:
+                for data in data_list:
+                    index_min = None
+                    index_max = None
+                    x = data["x"]
+                    if x_min is not None:
+                        index_min = max(np.searchsorted(x, x_min) - 1, 0)
+                    if x_max is not None:
+                        index_max = min(np.searchsorted(x, x_max) + 1, len(x))
+                    plt.plot(
+                        x[index_min:index_max],
+                        data["y"][index_min:index_max],
+                        label=data["label"],
+                    )
+
+            plt.xlim(x_min, x_max)
+            plt.ylim(y_min, y_max)
 
         save_plot(x_variable, y_variable, plot_path)
 
 
+def get_label(group: Group) -> str:
+    params = group.athena_params
+    annotation = getattr(params, "annotation", None)
+    file = getattr(params, "file", None)
+    params_id = getattr(params, "id", None)
+    label = annotation or file or params_id
+    return label
+
+
 def save_plot(x_type: str, y_type: str, plot_path: str):
     plt.grid(color="r", linestyle=":", linewidth=1)
     plt.xlabel(AXIS_LABELS[x_type])