Mercurial > repos > muon-spectroscopy-computational-project > larch_plot
view larch_plot.py @ 6:0339eb694129 draft default tip
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_plot commit 7f52c8654581d23a2acffc818e0c197cf8e04504
author | muon-spectroscopy-computational-project |
---|---|
date | Tue, 03 Sep 2024 11:51:45 +0000 |
parents | 35d24102cefd |
children |
line wrap: on
line source
"""Plot selected variables from larch groups and save each plot as a PNG.

Run as a script with two arguments: a comma-separated list of data files
and the path to a JSON file whose ``"plots"`` entry lists the plot settings.
"""
import json
import sys

from common import read_groups

from larch.symboltable import Group

import matplotlib
import matplotlib.pyplot as plt

import numpy as np


# Axis label text, keyed by the variable name plotted on that axis.
AXIS_LABELS = {
    "energy": "Energy (eV)",
    "distance": "r (ang)",
    "sample": "Sample",
    "flat": r"x$\mu$(E), flattened",
    "dmude": r"d(x$\mu$(E))/dE, normalised",
    "chir_mag": r"|$\chi$(r)|",
    "e0": "Edge Energy (eV)",
}


def sample_plot(groups: "list[Group]", y_variable: str):
    """Scatter one value per group against the group labels.

    Args:
        groups: Groups to plot, one point each.
        y_variable: Name of the group attribute used as the y value.
    """
    x = [get_label(group) for group in groups]
    y = [getattr(group, y_variable) for group in groups]
    plt.scatter(x, y)


def _clip_to_x_limits(x, y, x_min, x_max):
    """Return ``x`` and ``y`` sliced to the requested x-range.

    One extra point is kept beyond each supplied limit (hence the ``- 1``
    and ``+ 1``), clamped to the bounds of the data. A limit of ``None``
    leaves that side unsliced.

    NOTE(review): ``np.searchsorted`` requires ``x`` to be sorted in
    ascending order -- confirm this holds for the data being plotted.
    """
    index_min = None
    index_max = None
    if x_min is not None:
        index_min = max(np.searchsorted(x, x_min) - 1, 0)
    if x_max is not None:
        index_max = min(np.searchsorted(x, x_max) + 1, len(x))
    return x[index_min:index_max], y[index_min:index_max]


def main(dat_files: "list[str]", plot_settings: "list[dict]"):
    """Create one PNG per entry in ``plot_settings``.

    Each plot is written to ``plots/{i}_{y_variable}.png``, where ``i`` is
    the position of the entry in ``plot_settings``.

    Args:
        dat_files: Paths passed to ``read_groups`` to load the groups.
        plot_settings: One dict per plot. ``settings["variable"]`` must
            hold ``"variable"`` (the group attribute to plot) and, for
            anything other than ``"e0"``, the keys ``"x_limit_min"``,
            ``"x_limit_max"``, ``"y_limit_min"`` and ``"y_limit_max"``
            (each a number or ``None``).
    """
    groups = list(read_groups(dat_files))

    for i, settings in enumerate(plot_settings):
        variable_settings = settings["variable"]
        y_variable = variable_settings["variable"]
        plot_path = f"plots/{i}_{y_variable}.png"
        plt.figure()

        if y_variable == "e0":
            # e0 is a single value per group, so plot it per sample.
            x_variable = "sample"
            sample_plot(groups, y_variable)
        else:
            # Decide the x axis up front so the axis label is right even
            # when there are no groups to plot.
            if y_variable == "chir_mag":
                x_variable = "distance"
            else:
                x_variable = "energy"

            x_min = variable_settings["x_limit_min"]
            x_max = variable_settings["x_limit_max"]
            y_min = variable_settings["y_limit_min"]
            y_max = variable_settings["y_limit_max"]

            for group in groups:
                label = get_label(group)
                x = group.r if x_variable == "distance" else group.energy
                y = getattr(group, y_variable)
                clipped_x, clipped_y = _clip_to_x_limits(x, y, x_min, x_max)
                plt.plot(clipped_x, clipped_y, label=label)

            # The limits only exist for this branch, so they are applied
            # here rather than after the if/else.
            plt.xlim(x_min, x_max)
            plt.ylim(y_min, y_max)

        save_plot(x_variable, y_variable, plot_path)


def get_label(group: Group) -> "str | None":
    """Return a display label for ``group``.

    Uses the first truthy value among the ``annotation``, ``file`` and
    ``id`` attributes of ``group.athena_params``. If none of them is set,
    the last value tried is returned as-is (``None`` when ``id`` is
    missing).
    """
    params = group.athena_params
    annotation = getattr(params, "annotation", None)
    file = getattr(params, "file", None)
    params_id = getattr(params, "id", None)
    return annotation or file or params_id


def save_plot(x_type: str, y_type: str, plot_path: str):
    """Decorate the current figure, write it to ``plot_path``, and close it.

    Args:
        x_type: Key into ``AXIS_LABELS`` for the x-axis label.
        y_type: Key into ``AXIS_LABELS`` for the y-axis label.
        plot_path: Destination path of the PNG file.

    Raises:
        KeyError: If ``x_type`` or ``y_type`` is not in ``AXIS_LABELS``.
    """
    plt.grid(color="r", linestyle=":", linewidth=1)
    plt.xlabel(AXIS_LABELS[x_type])
    plt.ylabel(AXIS_LABELS[y_type])
    plt.legend()
    plt.savefig(plot_path, format="png")
    # Close every figure so they do not accumulate across plots.
    plt.close("all")


if __name__ == "__main__":
    # larch imports set this to an interactive backend, so need to change it
    matplotlib.use("Agg")

    dat_files = sys.argv[1]
    # Use a context manager so the settings file is closed after parsing.
    with open(sys.argv[2], "r", encoding="utf-8") as settings_file:
        input_values = json.load(settings_file)

    main(dat_files.split(","), input_values["plots"])