Mercurial > repos > muon-spectroscopy-computational-project > larch_plot
view larch_plot.py @ 2:59d0d15a40ef draft
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_plot commit 0f66842e802430e887d1c6cb7be1cc5436408fd2
author | muon-spectroscopy-computational-project |
---|---|
date | Mon, 04 Mar 2024 11:43:46 +0000 |
parents | 002c18a3e642 |
children | 5b993aff09e3 |
line wrap: on
line source
import json
import sys

from common import read_groups

import matplotlib
import matplotlib.pyplot as plt

import numpy as np


# Axis labels keyed by the plotted y variable (settings["variable"]["variable"])
# and by the x-axis type chosen in main() ("energy" or "distance").
AXIS_LABELS = {
    "norm": r"x$\mu$(E), normalised",
    "dmude": r"d(x$\mu$(E))/dE, normalised",
    "chir_mag": r"|$\chi$(r)|",
    "energy": "Energy (eV)",
    "distance": "r (ang)",
}


def _group_label(group):
    """Return the legend label for ``group``.

    Uses ``group.athena_params``: its ``annotation`` if truthy, else its
    ``file`` if truthy, else its ``id`` (None when the attribute is absent).
    """
    params = group.athena_params
    annotation = getattr(params, "annotation", None)
    file = getattr(params, "file", None)
    params_id = getattr(params, "id", None)
    return annotation or file or params_id


def _x_window(x, x_min, x_max) -> slice:
    """Return the slice of ``x`` covering [x_min, x_max].

    One extra point beyond each given bound is kept, so the line is drawn up
    to the axis limits instead of stopping at the last interior sample. A
    bound of None leaves that side of the slice open.

    Assumes ``x`` is sorted ascending, as ``np.searchsorted`` requires.
    """
    start = None if x_min is None else max(np.searchsorted(x, x_min) - 1, 0)
    stop = None if x_max is None else min(np.searchsorted(x, x_max) + 1, len(x))
    return slice(start, stop)


def main(dat_files: "list[str]", plot_settings: "list[dict]"):
    """Render one PNG per entry of ``plot_settings``, overlaying all groups.

    Each plot is written to ``plots/{index}_{variable}.png``.

    Args:
        dat_files: Paths handed to ``read_groups``.
        plot_settings: One dict per plot. ``settings["variable"]`` must
            provide ``variable`` (an attribute of each group and a key of
            AXIS_LABELS) and ``x_limit_min``/``x_limit_max``/``y_limit_min``/
            ``y_limit_max``. Each limit is a number, or None for no limit.
    """
    groups = list(read_groups(dat_files))
    for i, settings in enumerate(plot_settings):
        var_settings = settings["variable"]
        y_variable = var_settings["variable"]
        x_min = var_settings["x_limit_min"]
        x_max = var_settings["x_limit_max"]
        y_min = var_settings["y_limit_min"]
        y_max = var_settings["y_limit_max"]
        # chi(r) is plotted against distance; everything else against energy.
        # Decided once per plot (not per group) so the x label is right even
        # when no groups were read.
        x_variable = "distance" if y_variable == "chir_mag" else "energy"
        has_x_limit = x_min is not None or x_max is not None
        plot_path = f"plots/{i}_{y_variable}.png"

        plt.figure()
        for group in groups:
            x = group.r if x_variable == "distance" else group.energy
            y = getattr(group, y_variable)
            if has_x_limit:
                # Trim to the x window (plus one point each side) so that
                # y autoscaling only considers data near the visible range.
                window = _x_window(x, x_min, x_max)
                x, y = x[window], y[window]
            plt.plot(x, y, label=_group_label(group))

        if has_x_limit:
            plt.xlim(x_min, x_max)
        plt.ylim(y_min, y_max)
        save_plot(x_variable, y_variable, plot_path)


def save_plot(x_type: str, y_type: str, plot_path: str):
    """Finish the current figure and save it to ``plot_path`` as a PNG.

    Adds a dotted red grid, axis labels taken from AXIS_LABELS and a legend,
    then closes all open figures.
    """
    plt.grid(color="r", linestyle=":", linewidth=1)
    plt.xlabel(AXIS_LABELS[x_type])
    plt.ylabel(AXIS_LABELS[y_type])
    plt.legend()
    plt.savefig(plot_path, format="png")
    plt.close("all")


if __name__ == "__main__":
    # larch imports set this to an interactive backend, so need to change it
    matplotlib.use("Agg")

    dat_files = sys.argv[1]
    # Use a context manager so the settings file handle is closed promptly.
    with open(sys.argv[2], "r", encoding="utf-8") as settings_file:
        input_values = json.load(settings_file)
    main(dat_files.split(","), input_values["plots"])