Mercurial > repos > muon-spectroscopy-computational-project > larch_plot
comparison larch_plot.py @ 1:002c18a3e642 draft
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_plot commit 1cf6d7160497ba58fe16a51f00d088a20934eba6
author | muon-spectroscopy-computational-project |
---|---|
date | Wed, 06 Dec 2023 13:04:06 +0000 |
parents | 886949a03377 |
children | 59d0d15a40ef |
comparison
equal
deleted
inserted
replaced
0:886949a03377 | 1:002c18a3e642 |
---|---|
7 import matplotlib.pyplot as plt | 7 import matplotlib.pyplot as plt |
8 | 8 |
9 import numpy as np | 9 import numpy as np |
10 | 10 |
11 | 11 |
12 Y_LABELS = { | 12 AXIS_LABELS = { |
13 "norm": r"x$\mu$(E), normalised", | 13 "norm": r"x$\mu$(E), normalised", |
14 "dmude": r"d(x$\mu$(E))/dE, normalised", | 14 "dmude": r"d(x$\mu$(E))/dE, normalised", |
15 "chir_mag": r"|$\chi$(r)|", | 15 "chir_mag": r"|$\chi$(r)|", |
16 "energy": "Energy (eV)", | |
17 "distance": "r (ang)", | |
16 } | 18 } |
17 | 19 |
18 | 20 |
19 def main(dat_files: "list[str]", plot_settings: "list[dict]"): | 21 def main(dat_files: "list[str]", plot_settings: "list[dict]"): |
20 groups = list(read_groups(dat_files)) | 22 groups = list(read_groups(dat_files)) |
21 | 23 |
22 for i, settings in enumerate(plot_settings): | 24 for i, settings in enumerate(plot_settings): |
23 data_list = [] | 25 data_list = [] |
24 e0_min = None | 26 x_variable = "energy" |
25 e0_max = None | 27 y_variable = settings["variable"]["variable"] |
26 variable = settings["variable"]["variable"] | 28 x_min = settings["variable"]["x_limit_min"] |
27 x_min = settings["variable"]["energy_min"] | 29 x_max = settings["variable"]["x_limit_max"] |
28 x_max = settings["variable"]["energy_max"] | 30 y_min = settings["variable"]["y_limit_min"] |
29 plot_path = f"plots/{i}_{variable}.png" | 31 y_max = settings["variable"]["y_limit_max"] |
32 plot_path = f"plots/{i}_{y_variable}.png" | |
30 plt.figure() | 33 plt.figure() |
31 | 34 |
32 for group in groups: | 35 for group in groups: |
33 label = group.athena_params.annotation or group.athena_params.id | 36 params = group.athena_params |
34 if variable == "chir_mag": | 37 label = params.annotation or params.file or params.id |
38 if y_variable == "chir_mag": | |
39 x_variable = "distance" | |
35 x = group.r | 40 x = group.r |
36 energy_format = None | |
37 else: | 41 else: |
38 x = group.energy | 42 x = group.energy |
39 energy_format = settings["variable"]["energy_format"] | |
40 if energy_format == "relative": | |
41 e0 = group.athena_params.bkg.e0 | |
42 e0_min = find_relative_limit(e0_min, e0, min) | |
43 e0_max = find_relative_limit(e0_max, e0, max) | |
44 | 43 |
45 y = getattr(group, variable) | 44 y = getattr(group, y_variable) |
46 if x_min is None and x_max is None: | 45 if x_min is None and x_max is None: |
47 plt.plot(x, y, label=label) | 46 plt.plot(x, y, label=label) |
48 else: | 47 else: |
49 data_list.append({"x": x, "y": y, "label": label}) | 48 data_list.append({"x": x, "y": y, "label": label}) |
50 | |
51 if variable != "chir_mag" and energy_format == "relative": | |
52 if x_min is not None: | |
53 x_min += e0_min | |
54 if x_max is not None: | |
55 x_max += e0_max | |
56 | 49 |
57 if x_min is not None or x_max is not None: | 50 if x_min is not None or x_max is not None: |
58 for data in data_list: | 51 for data in data_list: |
59 index_min = None | 52 index_min = None |
60 index_max = None | 53 index_max = None |
68 data["y"][index_min:index_max], | 61 data["y"][index_min:index_max], |
69 label=data["label"], | 62 label=data["label"], |
70 ) | 63 ) |
71 | 64 |
72 plt.xlim(x_min, x_max) | 65 plt.xlim(x_min, x_max) |
66 plt.ylim(y_min, y_max) | |
73 | 67 |
74 save_plot(variable, plot_path) | 68 save_plot(x_variable, y_variable, plot_path) |
75 | 69 |
76 | 70 |
77 def find_relative_limit(e0_min: "float|None", e0: float, function: callable): | 71 def save_plot(x_type: str, y_type: str, plot_path: str): |
78 if e0_min is None: | |
79 e0_min = e0 | |
80 else: | |
81 e0_min = function(e0_min, e0) | |
82 return e0_min | |
83 | |
84 | |
85 def save_plot(y_type: str, plot_path: str): | |
86 plt.grid(color="r", linestyle=":", linewidth=1) | 72 plt.grid(color="r", linestyle=":", linewidth=1) |
87 plt.xlabel("Energy (eV)") | 73 plt.xlabel(AXIS_LABELS[x_type]) |
88 plt.ylabel(Y_LABELS[y_type]) | 74 plt.ylabel(AXIS_LABELS[y_type]) |
89 plt.legend() | 75 plt.legend() |
90 plt.savefig(plot_path, format="png") | 76 plt.savefig(plot_path, format="png") |
91 plt.close("all") | 77 plt.close("all") |
92 | 78 |
93 | 79 |