Mercurial > repos > muon-spectroscopy-computational-project > larch_plot
diff larch_plot.py @ 0:886949a03377 draft
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_plot commit 5be486890442dedfb327289d597e1c8110240735
author | muon-spectroscopy-computational-project |
---|---|
date | Tue, 14 Nov 2023 15:35:36 +0000 |
parents | |
children | 002c18a3e642 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/larch_plot.py Tue Nov 14 15:35:36 2023 +0000 @@ -0,0 +1,101 @@ +import json +import sys + +from common import read_groups + +import matplotlib +import matplotlib.pyplot as plt + +import numpy as np + + +Y_LABELS = { + "norm": r"x$\mu$(E), normalised", + "dmude": r"d(x$\mu$(E))/dE, normalised", + "chir_mag": r"|$\chi$(r)|", +} + + +def main(dat_files: "list[str]", plot_settings: "list[dict]"): + groups = list(read_groups(dat_files)) + + for i, settings in enumerate(plot_settings): + data_list = [] + e0_min = None + e0_max = None + variable = settings["variable"]["variable"] + x_min = settings["variable"]["energy_min"] + x_max = settings["variable"]["energy_max"] + plot_path = f"plots/{i}_{variable}.png" + plt.figure() + + for group in groups: + label = group.athena_params.annotation or group.athena_params.id + if variable == "chir_mag": + x = group.r + energy_format = None + else: + x = group.energy + energy_format = settings["variable"]["energy_format"] + if energy_format == "relative": + e0 = group.athena_params.bkg.e0 + e0_min = find_relative_limit(e0_min, e0, min) + e0_max = find_relative_limit(e0_max, e0, max) + + y = getattr(group, variable) + if x_min is None and x_max is None: + plt.plot(x, y, label=label) + else: + data_list.append({"x": x, "y": y, "label": label}) + + if variable != "chir_mag" and energy_format == "relative": + if x_min is not None: + x_min += e0_min + if x_max is not None: + x_max += e0_max + + if x_min is not None or x_max is not None: + for data in data_list: + index_min = None + index_max = None + x = data["x"] + if x_min is not None: + index_min = max(np.searchsorted(x, x_min) - 1, 0) + if x_max is not None: + index_max = min(np.searchsorted(x, x_max) + 1, len(x)) + plt.plot( + x[index_min:index_max], + data["y"][index_min:index_max], + label=data["label"], + ) + + plt.xlim(x_min, x_max) + + save_plot(variable, plot_path) + + +def find_relative_limit(e0_min: "float|None", e0: float, function: callable): + if e0_min is None: + e0_min = e0 + else: + e0_min = function(e0_min, e0) + return e0_min + + +def save_plot(y_type: str, plot_path: str): + plt.grid(color="r", linestyle=":", linewidth=1) + plt.xlabel("Energy (eV)") + plt.ylabel(Y_LABELS[y_type]) + plt.legend() + plt.savefig(plot_path, format="png") + plt.close("all") + + +if __name__ == "__main__": + # larch imports set this to an interactive backend, so need to change it + matplotlib.use("Agg") + + dat_files = sys.argv[1] + input_values = json.load(open(sys.argv[2], "r", encoding="utf-8")) + + main(dat_files.split(","), input_values["plots"])