comparison larch_plot.py @ 4:35d24102cefd draft

planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_plot commit 3fe6078868efd0fcea0fb5eea8dcd4b152d9c0a8
author muon-spectroscopy-computational-project
date Thu, 11 Apr 2024 09:02:24 +0000
parents 5b993aff09e3
children
comparison
equal deleted inserted replaced
3:5b993aff09e3 4:35d24102cefd
1 import json 1 import json
2 import sys 2 import sys
3 3
4 from common import read_groups 4 from common import read_groups
5
6 from larch.symboltable import Group
5 7
6 import matplotlib 8 import matplotlib
7 import matplotlib.pyplot as plt 9 import matplotlib.pyplot as plt
8 10
9 import numpy as np 11 import numpy as np
10 12
11 13
# Axis captions keyed by the variable name plotted on that axis.
# x-axis keys: "energy", "distance", "sample"; the remaining keys are the
# y-axis variables a plot may request (including "e0", plotted per sample).
AXIS_LABELS = dict(
    energy="Energy (eV)",
    distance="r (ang)",
    sample="Sample",
    flat=r"x$\mu$(E), flattened",
    dmude=r"d(x$\mu$(E))/dE, normalised",
    chir_mag=r"|$\chi$(r)|",
    e0="Edge Energy (eV)",
)
23
24
def sample_plot(groups: "list[Group]", y_variable: str):
    """Scatter a single scalar attribute of each group against its label.

    Args:
        groups: groups to plot, one point each.
        y_variable: name of the group attribute used as the y value
            (``main`` calls this for ``"e0"``).

    The x values are the strings returned by ``get_label``. All labels are
    collected before any attribute is read, and only then is the scatter
    drawn on the current figure.
    """
    sample_labels = list(map(get_label, groups))
    sample_values = [getattr(member, y_variable) for member in groups]
    plt.scatter(sample_labels, sample_values)
19 29
20 30
def main(dat_files: "list[str]", plot_settings: "list[dict]"):
    """Draw one figure per entry of ``plot_settings`` using every input group.

    Args:
        dat_files: paths handed to ``read_groups``; every resulting group is
            drawn on every figure.
        plot_settings: one dict per figure. Each is read as
            ``settings["variable"]`` with key ``"variable"`` (the group
            attribute to plot). For line plots it also has
            ``"x_limit_min"``, ``"x_limit_max"``, ``"y_limit_min"`` and
            ``"y_limit_max"`` (``None`` meaning "no limit").

    ``"e0"`` is drawn as one point per group against its label. Every other
    variable is drawn as a line per group against ``r`` (for ``"chir_mag"``)
    or ``energy``. Figure ``i`` is handed to ``save_plot`` with the path
    ``plots/{i}_{variable}.png``.
    """
    groups = list(read_groups(dat_files))

    for i, settings in enumerate(plot_settings):
        variable_settings = settings["variable"]
        y_variable = variable_settings["variable"]
        plot_path = f"plots/{i}_{y_variable}.png"
        plt.figure()

        if y_variable == "e0":
            x_variable = "sample"
            sample_plot(groups, y_variable)
        else:
            x_variable = _line_plot(groups, y_variable, variable_settings)

        save_plot(x_variable, y_variable, plot_path)


def _line_plot(groups: "list[Group]", y_variable: str, variable_settings: dict) -> str:
    """Plot ``y_variable`` as one line per group and return the x-axis key.

    The returned key ("distance" or "energy") indexes ``AXIS_LABELS``.
    """
    x_min = variable_settings["x_limit_min"]
    x_max = variable_settings["x_limit_max"]
    y_min = variable_settings["y_limit_min"]
    y_max = variable_settings["y_limit_max"]

    # The x axis depends only on the plotted variable, so decide it once
    # rather than per group. An empty group list still gets the right
    # axis label.
    if y_variable == "chir_mag":
        x_variable, x_attribute = "distance", "r"
    else:
        x_variable, x_attribute = "energy", "energy"

    limited = x_min is not None or x_max is not None
    for group in groups:
        label = get_label(group)
        x = getattr(group, x_attribute)
        y = getattr(group, y_variable)
        if limited:
            index_min, index_max = _limit_indices(x, x_min, x_max)
            x = x[index_min:index_max]
            y = y[index_min:index_max]
        plt.plot(x, y, label=label)

    # None for either bound leaves that side of the axis unchanged.
    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    return x_variable


def _limit_indices(x, x_min, x_max):
    """Return (start, stop) slice bounds covering [x_min, x_max] of ``x``.

    One extra sample is kept on each side, clamped to the array, so the
    line reaches the axis limits set afterwards. ``None`` for a limit gives
    ``None`` for that bound (an open slice). Assumes ``x`` is sorted in
    ascending order, as ``np.searchsorted`` requires.
    """
    index_min = None
    index_max = None
    if x_min is not None:
        index_min = max(np.searchsorted(x, x_min) - 1, 0)
    if x_max is not None:
        index_max = min(np.searchsorted(x, x_max) + 1, len(x))
    return index_min, index_max
82
83
def get_label(group: "Group", *, default: str = "") -> str:
    """Return a display label for ``group``.

    The first truthy value among ``athena_params.annotation``,
    ``athena_params.file`` and ``athena_params.id`` is used, converted to
    ``str``. If none is set (or the group has no ``athena_params``),
    ``default`` is returned. The result is therefore always a string,
    which the categorical x axis in ``sample_plot`` requires.

    Args:
        group: the group to label.
        default: label to use when no Athena metadata is available.

    Returns:
        The label string.
    """
    params = getattr(group, "athena_params", None)
    for attribute in ("annotation", "file", "id"):
        value = getattr(params, attribute, None)
        if value:
            return str(value)
    return default
72 91
73 92
74 def save_plot(x_type: str, y_type: str, plot_path: str): 93 def save_plot(x_type: str, y_type: str, plot_path: str):
75 plt.grid(color="r", linestyle=":", linewidth=1) 94 plt.grid(color="r", linestyle=":", linewidth=1)
76 plt.xlabel(AXIS_LABELS[x_type]) 95 plt.xlabel(AXIS_LABELS[x_type])