Mercurial > repos > muon-spectroscopy-computational-project > larch_plot
diff larch_plot.py @ 4:35d24102cefd draft
planemo upload for repository https://github.com/MaterialsGalaxy/larch-tools/tree/main/larch_plot commit 3fe6078868efd0fcea0fb5eea8dcd4b152d9c0a8
author | muon-spectroscopy-computational-project |
---|---|
date | Thu, 11 Apr 2024 09:02:24 +0000 |
parents | 5b993aff09e3 |
children |
line wrap: on
line diff
--- a/larch_plot.py Fri Mar 22 14:23:33 2024 +0000 +++ b/larch_plot.py Thu Apr 11 09:02:24 2024 +0000 @@ -3,6 +3,8 @@ from common import read_groups +from larch.symboltable import Group + import matplotlib import matplotlib.pyplot as plt @@ -10,67 +12,84 @@ AXIS_LABELS = { + "energy": "Energy (eV)", + "distance": "r (ang)", + "sample": "Sample", "flat": r"x$\mu$(E), flattened", "dmude": r"d(x$\mu$(E))/dE, normalised", "chir_mag": r"|$\chi$(r)|", - "energy": "Energy (eV)", - "distance": "r (ang)", + "e0": "Edge Energy (eV)", } +def sample_plot(groups: "list[Group]", y_variable: str): + x = [get_label(group) for group in groups] + y = [getattr(group, y_variable) for group in groups] + plt.scatter(x, y) + + def main(dat_files: "list[str]", plot_settings: "list[dict]"): groups = list(read_groups(dat_files)) for i, settings in enumerate(plot_settings): data_list = [] - x_variable = "energy" y_variable = settings["variable"]["variable"] - x_min = settings["variable"]["x_limit_min"] - x_max = settings["variable"]["x_limit_max"] - y_min = settings["variable"]["y_limit_min"] - y_max = settings["variable"]["y_limit_max"] plot_path = f"plots/{i}_{y_variable}.png" plt.figure() - for group in groups: - params = group.athena_params - annotation = getattr(params, "annotation", None) - file = getattr(params, "file", None) - params_id = getattr(params, "id", None) - label = annotation or file or params_id - if y_variable == "chir_mag": - x_variable = "distance" - x = group.r - else: - x = group.energy - - y = getattr(group, y_variable) - if x_min is None and x_max is None: - plt.plot(x, y, label=label) - else: - data_list.append({"x": x, "y": y, "label": label}) + if y_variable == "e0": + x_variable = "sample" + sample_plot(groups, y_variable) + else: + x_variable = "energy" + x_min = settings["variable"]["x_limit_min"] + x_max = settings["variable"]["x_limit_max"] + y_min = settings["variable"]["y_limit_min"] + y_max = settings["variable"]["y_limit_max"] + for group in groups: + label = get_label(group) + if y_variable == "chir_mag": + x_variable = "distance" + x = group.r + else: + x = group.energy - if x_min is not None or x_max is not None: - for data in data_list: - index_min = None - index_max = None - x = data["x"] - if x_min is not None: - index_min = max(np.searchsorted(x, x_min) - 1, 0) - if x_max is not None: - index_max = min(np.searchsorted(x, x_max) + 1, len(x)) - plt.plot( - x[index_min:index_max], - data["y"][index_min:index_max], - label=data["label"], - ) + y = getattr(group, y_variable) + if x_min is None and x_max is None: + plt.plot(x, y, label=label) + else: + data_list.append({"x": x, "y": y, "label": label}) - plt.xlim(x_min, x_max) - plt.ylim(y_min, y_max) + if x_min is not None or x_max is not None: + for data in data_list: + index_min = None + index_max = None + x = data["x"] + if x_min is not None: + index_min = max(np.searchsorted(x, x_min) - 1, 0) + if x_max is not None: + index_max = min(np.searchsorted(x, x_max) + 1, len(x)) + plt.plot( + x[index_min:index_max], + data["y"][index_min:index_max], + label=data["label"], + ) + + plt.xlim(x_min, x_max) + plt.ylim(y_min, y_max) save_plot(x_variable, y_variable, plot_path) +def get_label(group: Group) -> str: + params = group.athena_params + annotation = getattr(params, "annotation", None) + file = getattr(params, "file", None) + params_id = getattr(params, "id", None) + label = annotation or file or params_id + return label + + def save_plot(x_type: str, y_type: str, plot_path: str): plt.grid(color="r", linestyle=":", linewidth=1) plt.xlabel(AXIS_LABELS[x_type])